ray/rllib/examples/serving/cartpole_server.py

#!/usr/bin/env python
"""Example of running a policy server. Copy this file for your use case.

To try this out, in two separate shells run:
    $ python cartpole_server.py
    $ python cartpole_client.py --inference-mode=local|remote
"""

import argparse
import os

import ray
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.env.policy_server_input import PolicyServerInput
from ray.tune.logger import pretty_print

SERVER_ADDRESS = "localhost"
SERVER_PORT = 9900
CHECKPOINT_FILE = "last_checkpoint_{}.out"

parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="DQN")
parser.add_argument(
    "--framework", type=str, choices=["tf", "torch"], default="tf")
parser.add_argument(
    "--no-restore",
    action="store_true",
    help="Do not restore from a previously saved checkpoint (location of "
    "which is saved in `last_checkpoint_[algo-name].out`).")

if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()
    env = "CartPole-v0"

    connector_config = {
        # Use the connector server to generate experiences.
        "input": (
            lambda ioctx: PolicyServerInput(ioctx, SERVER_ADDRESS, SERVER_PORT)
        ),
        # Use a single worker process to run the server.
        "num_workers": 0,
        # Disable OPE (off-policy estimation), since the rollouts are coming
        # from online clients.
        "input_evaluation": [],
    }
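
    # `PolicyServerInput` starts an HTTP server on SERVER_ADDRESS:SERVER_PORT;
    # `PolicyClient` instances (see the sketch at the top of this file) connect
    # to it to request actions and to report back observations and rewards,
    # which the trainer then consumes as ordinary sample batches.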

    if args.run == "DQN":
        # Example of using DQN (supports off-policy actions).
        trainer = DQNTrainer(
            env=env,
            config=dict(
                connector_config, **{
                    # Collect 100 timesteps before starting to learn.
                    "learning_starts": 100,
                    # Run (at least) 200 env timesteps per `train()` call.
                    "timesteps_per_iteration": 200,
                    "framework": args.framework,
                }))
    elif args.run == "PPO":
        # Example of using PPO (does NOT support off-policy actions).
        trainer = PPOTrainer(
            env=env,
            config=dict(
                connector_config, **{
                    # Collect sample fragments of 1000 timesteps and train on
                    # batches of 4000 timesteps.
                    "rollout_fragment_length": 1000,
                    "train_batch_size": 4000,
                    "framework": args.framework,
                }))
    else:
        raise ValueError("--run must be DQN or PPO")
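
    # Side note on the off-policy distinction above: with DQN, the trainer can
    # also learn from actions the client chose itself and reported via
    # `PolicyClient.log_action()`; with PPO, actions must be computed on-policy
    # via `PolicyClient.get_action()`.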

    checkpoint_path = CHECKPOINT_FILE.format(args.run)

    # Attempt to restore from a previous checkpoint, if possible. Note that
    # `checkpoint_path` must keep pointing at the pointer file, so the restored
    # checkpoint's location is read into a separate variable; otherwise the
    # training loop below would overwrite the checkpoint itself instead of
    # updating the pointer file.
    if not args.no_restore and os.path.exists(checkpoint_path):
        saved_checkpoint = open(checkpoint_path).read()
        print("Restoring from checkpoint path", saved_checkpoint)
        trainer.restore(saved_checkpoint)

    # Serving and training loop.
    while True:
        print(pretty_print(trainer.train()))
        checkpoint = trainer.save()
        print("Last checkpoint", checkpoint)
        # Record the latest checkpoint's location in the pointer file.
        with open(checkpoint_path, "w") as f:
            f.write(checkpoint)