#!/usr/bin/env python
"""Example of running a policy server. Copy this file for your use case.

To try this out, in two separate shells run:
    $ python cartpole_server.py
    $ python cartpole_client.py --inference-mode=local|remote
"""
import argparse
import os

import ray
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.env.policy_server_input import PolicyServerInput
from ray.tune.logger import pretty_print
# Address and port the policy server listens on; clients connect here.
SERVER_ADDRESS = "localhost"
SERVER_PORT = 9900
# Pointer file recording the latest checkpoint path, one per algorithm
# (the "{}" is filled in with the --run value).
CHECKPOINT_FILE = "last_checkpoint_{}.out"

# Command-line interface: pick which algorithm drives the server.
parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="DQN")
if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()
    env = "CartPole-v0"

    # Config shared by both algorithms: feed experiences in from remote
    # clients via the policy server instead of sampling a local env.
    connector_config = {
        # Use the connector server to generate experiences.
        "input": (
            lambda ioctx: PolicyServerInput(ioctx, SERVER_ADDRESS, SERVER_PORT)
        ),
        # Use a single worker process to run the server.
        "num_workers": 0,
        # Disable OPE, since the rollouts are coming from online clients.
        "input_evaluation": [],
    }

    if args.run == "DQN":
        # Example of using DQN (supports off-policy actions).
        trainer = DQNTrainer(
            env=env,
            config=dict(
                connector_config, **{
                    "exploration_config": {
                        "type": "EpsilonGreedy",
                        "initial_epsilon": 1.0,
                        "final_epsilon": 0.02,
                        "epsilon_timesteps": 1000,
                    },
                    "learning_starts": 100,
                    "timesteps_per_iteration": 200,
                    "log_level": "INFO",
                }))
    elif args.run == "PPO":
        # Example of using PPO (does NOT support off-policy actions).
        trainer = PPOTrainer(
            env=env,
            config=dict(
                connector_config, **{
                    "sample_batch_size": 1000,
                    "train_batch_size": 4000,
                }))
    else:
        raise ValueError("--run must be DQN or PPO")

    # Pointer file that stores the path of the most recent checkpoint so
    # the server can resume training after a restart.
    checkpoint_pointer_file = CHECKPOINT_FILE.format(args.run)

    # Attempt to restore from checkpoint if possible.
    if os.path.exists(checkpoint_pointer_file):
        # BUG FIX: read the stored checkpoint path with a context manager
        # (the original leaked the file handle) and keep it in a separate
        # variable. The original clobbered the pointer-file name here, so
        # the loop below would overwrite the checkpoint location itself
        # instead of updating the pointer file.
        with open(checkpoint_pointer_file) as f:
            restore_path = f.read()
        print("Restoring from checkpoint path", restore_path)
        trainer.restore(restore_path)

    # Serving and training loop: train, checkpoint, record the new
    # checkpoint path in the pointer file, repeat forever.
    while True:
        print(pretty_print(trainer.train()))
        checkpoint = trainer.save()
        print("Last checkpoint", checkpoint)
        with open(checkpoint_pointer_file, "w") as f:
            f.write(checkpoint)