ray/rllib/examples/serving/cartpole_server.py

#!/usr/bin/env python
"""Example of running a policy server. Copy this file for your use case.
To try this out, in two separate shells run:
$ python cartpole_server.py
$ python cartpole_client.py --inference-mode=local|remote
"""
import argparse
import os

import ray
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.env.policy_server_input import PolicyServerInput
from ray.tune.logger import pretty_print

SERVER_ADDRESS = "localhost"
SERVER_PORT = 9900
CHECKPOINT_FILE = "last_checkpoint_{}.out"

parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="DQN")
parser.add_argument(
    "--framework", type=str, choices=["tf", "torch"], default="tf")
if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()

    env = "CartPole-v0"

    connector_config = {
        # Use the connector server to generate experiences.
        "input": (
            lambda ioctx: PolicyServerInput(ioctx, SERVER_ADDRESS, SERVER_PORT)
        ),
        # Use a single worker process to run the server.
        "num_workers": 0,
        # Disable OPE, since the rollouts are coming from online clients.
        "input_evaluation": [],
    }
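    # Note: PolicyServerInput binds an HTTP server to SERVER_ADDRESS:SERVER_PORT
    # and feeds the experiences reported by connected PolicyClients into the
    # local rollout worker, so trainer.train() below effectively waits for
    # client traffic before it can complete an iteration.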
    if args.run == "DQN":
        # Example of using DQN (supports off-policy actions).
        trainer = DQNTrainer(
            env=env,
            config=dict(
                connector_config, **{
                    "exploration_config": {
                        "type": "EpsilonGreedy",
                        "initial_epsilon": 1.0,
                        "final_epsilon": 0.02,
                        "epsilon_timesteps": 1000,
                    },
                    "learning_starts": 100,
                    "timesteps_per_iteration": 200,
                    "log_level": "INFO",
                    "framework": args.framework,
                }))
    elif args.run == "PPO":
        # Example of using PPO (does NOT support off-policy actions).
        trainer = PPOTrainer(
            env=env,
            config=dict(
                connector_config, **{
                    "sample_batch_size": 1000,
                    "train_batch_size": 4000,
                    "framework": args.framework,
                }))
    else:
        raise ValueError("--run must be DQN or PPO")
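    # Client-side note: with DQN, clients may compute actions themselves and
    # report them via PolicyClient.log_action(eid, obs, action); with PPO,
    # actions should be requested from the server via
    # PolicyClient.get_action(eid, obs) so that training stays on-policy.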
    checkpoint_path = CHECKPOINT_FILE.format(args.run)

    # Attempt to restore from a previous checkpoint, if available. The file
    # only stores the path of the last checkpoint written by trainer.save().
    if os.path.exists(checkpoint_path):
        restore_path = open(checkpoint_path).read()
        print("Restoring from checkpoint path", restore_path)
        trainer.restore(restore_path)

    # Serving and training loop.
    while True:
        print(pretty_print(trainer.train()))
        checkpoint = trainer.save()
        print("Last checkpoint", checkpoint)
        with open(checkpoint_path, "w") as f:
            f.write(checkpoint)