ray/rllib/examples/connectors/run_connector_policy.py

55 lines
1.4 KiB
Python

"""This example script shows how to load a connector-enabled policy
and use it in a serving/inference setting.
"""
import argparse
import gym
from ray.rllib.utils.policy import (
load_policies_from_checkpoint,
local_policy_inference,
)
# Command-line flags for this example.
parser = argparse.ArgumentParser()
# This should be a checkpoint created with connectors enabled.
parser.add_argument(
    "--checkpoint_file",
    help="Path to an RLlib checkpoint file.",
)
parser.add_argument(
    "--policy_id",
    default="default_policy",
    help="ID of policy to load.",
)
args = parser.parse_args()
# Validate with parser.error() instead of `assert`: asserts are stripped when
# Python runs with -O. parser.error() prints the usage line plus this message
# and exits with status 2.
if not args.checkpoint_file:
    parser.error("Must specify flag --checkpoint_file.")
def run():
    """Restore a connector-enabled policy and roll it out for one CartPole episode.

    Reads the checkpoint path and policy ID from the module-level ``args``
    (parsed from the command line). It then steps a ``CartPole-v0``
    environment until the environment reports ``done``, printing the step
    number, observation and chosen action at every step. The environment is
    closed when the episode ends or if an error is raised.

    Raises:
        RuntimeError: if ``local_policy_inference`` does not return exactly
            one output for the single queried agent.
    """
    # Restore policy.
    policies = load_policies_from_checkpoint(args.checkpoint_file, [args.policy_id])
    policy = policies[args.policy_id]

    # Run CartPole.
    env = gym.make("CartPole-v0")
    try:
        # This loop assumes the older gym API: reset() returns only the
        # observation, and step() returns a 4-tuple (obs, reward, done, info).
        obs = env.reset()
        done = False
        step = 0
        while not done:
            step += 1
            # Use local_policy_inference() to run inference, so we do not have
            # to provide policy states or extra fetch dictionaries.
            # NOTE(review): "env_1" / "agent_1" are fixed IDs for this
            # single-agent rollout; presumably they key per-agent connector
            # state -- confirm against local_policy_inference.
            policy_outputs = local_policy_inference(policy, "env_1", "agent_1", obs)
            # Explicit check instead of `assert`, so it still runs under -O.
            if len(policy_outputs) != 1:
                raise RuntimeError(
                    f"Expected exactly 1 policy output, got {len(policy_outputs)}."
                )
            action, _, _ = policy_outputs[0]
            print(f"step {step}", obs, action)
            # Step environment forward one more step.
            # NOTE(review): `action` appears to be batched (one entry per
            # queried agent), so the single entry is taken -- TODO confirm.
            obs, _, done, _ = env.step(action[0])
    finally:
        # Always release the environment, even if inference or stepping raises.
        env.close()


if __name__ == "__main__":
    run()