2020-03-20 12:43:57 -07:00
|
|
|
#!/usr/bin/env python
|
2021-06-23 09:09:01 +02:00
|
|
|
"""
|
|
|
|
Example of running an RLlib policy server, allowing connections from
external environment running clients. The server listens on one or more
HTTP-speaking ports for the experiences that the clients collect from their
own environments (a simple CartPole env in this case).
See `cartpole_client.py` in this same directory for how
to start any number of clients (after this server has been started).
|
|
|
|
|
|
|
|
This script will not create any actual env to illustrate that RLlib can
|
|
|
|
run w/o needing an internalized environment.
|
|
|
|
|
|
|
|
Setup:
|
|
|
|
1) Start this server:
|
|
|
|
$ python cartpole_server.py --num-workers [n] [other options]
|
|
|
|
Use --help for help.
|
|
|
|
2) Run n policy clients:
|
|
|
|
See `cartpole_client.py` on how to do this.
|
2018-06-20 13:22:39 -07:00
|
|
|
|
2021-06-23 09:09:01 +02:00
|
|
|
The `num-workers` setting will allow you to distribute the incoming feed over n
listen sockets (in this example, ports 9900 to 990n with n = num_workers - 1).
|
|
|
|
You may connect more than one policy client to any open listen port.
|
2018-06-20 13:22:39 -07:00
|
|
|
"""
|
|
|
|
|
2020-03-20 12:43:57 -07:00
|
|
|
import argparse
|
2021-06-23 09:09:01 +02:00
|
|
|
import gym
|
2018-06-20 13:22:39 -07:00
|
|
|
import os
|
|
|
|
|
|
|
|
import ray
|
2019-04-07 00:36:18 -07:00
|
|
|
from ray.rllib.agents.dqn import DQNTrainer
|
2020-03-20 12:43:57 -07:00
|
|
|
from ray.rllib.agents.ppo import PPOTrainer
|
|
|
|
from ray.rllib.env.policy_server_input import PolicyServerInput
|
2021-02-08 15:02:19 +01:00
|
|
|
from ray.rllib.examples.custom_metrics_and_callbacks import MyCallbacks
|
2018-06-20 13:22:39 -07:00
|
|
|
from ray.tune.logger import pretty_print
|
|
|
|
|
|
|
|
# Host name handed to every `PolicyServerInput` created by this script.
SERVER_ADDRESS = "localhost"

# Lowest listen port. When the policy server is started with n workers,
# the ports 9900 - 990n (n = num_workers - 1) are opened, and different
# clients may connect to each one of them.
SERVER_BASE_PORT = 9900  # + worker-idx - 1

# Name template (filled with the algo name) of the file that stores the
# location of the most recent checkpoint.
CHECKPOINT_FILE = "last_checkpoint_{}.out"

# Command line interface of the server script.
parser = argparse.ArgumentParser()

# Which algorithm to train with.
parser.add_argument(
    "--run",
    default="DQN",
    choices=["DQN", "PPO"],
    type=str,
)

# Which deep learning framework to build the model with.
parser.add_argument(
    "--framework",
    default="tf",
    choices=["tf", "torch"],
    help="The DL framework specifier.",
)

# Start from scratch, even if a checkpoint location file exists.
parser.add_argument(
    "--no-restore",
    action="store_true",
    help=(
        "Do not restore from a previously saved checkpoint (location of "
        "which is saved in `last_checkpoint_[algo-name].out`)."
    ),
)

# One listening socket is opened per worker.
parser.add_argument(
    "--num-workers",
    default=2,
    type=int,
    help=(
        "The number of workers to use. Each worker will create "
        "its own listening socket for incoming experiences."
    ),
)

# Optional verbose callbacks.
parser.add_argument(
    "--chatty-callbacks",
    action="store_true",
    help=(
        "Activates info-messages for different events on "
        "server/client (episode steps, postprocessing, etc..)."
    ),
)
|
2018-06-20 13:22:39 -07:00
|
|
|
|
|
|
|
if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()

    # `InputReader` generator (returns None if no input reader is needed on
    # the respective worker).
    def _input(ioctx):
        """Create the input reader for the worker described by `ioctx`.

        Args:
            ioctx: The IO context handed in by RLlib. Only its
                `worker_index` and `worker.num_workers` are read here.

        Returns:
            A `PolicyServerInput` bound to `SERVER_ADDRESS` and a port
            derived from the worker index, or None if this worker does not
            open a listen socket.
        """
        # We are remote worker or we are local worker with num_workers=0:
        # Create a PolicyServerInput.
        if ioctx.worker_index > 0 or ioctx.worker.num_workers == 0:
            # Remote workers are 1-indexed, so worker 1 maps onto the base
            # port. The local worker (index 0) uses the base port as-is.
            port_offset = ioctx.worker_index - (
                1 if ioctx.worker_index > 0 else 0)
            return PolicyServerInput(
                ioctx, SERVER_ADDRESS, SERVER_BASE_PORT + port_offset)
        # No InputReader (PolicyServerInput) needed.
        else:
            return None

    # Trainer config. Note that this config is sent to the client only in case
    # the client needs to create its own policy copy for local inference.
    config = {
        # Indicate that the Trainer we setup here doesn't need an actual env.
        # Allow spaces to be determined by user (see below).
        "env": None,

        # TODO: (sven) make these settings unnecessary and get the information
        # about the env spaces from the client.
        "observation_space": gym.spaces.Box(
            float("-inf"), float("inf"), (4, )),
        "action_space": gym.spaces.Discrete(2),

        # Use the `PolicyServerInput` to generate experiences.
        "input": _input,
        # Use n worker processes to listen on different ports.
        "num_workers": args.num_workers,
        # Disable OPE, since the rollouts are coming from online clients.
        "input_evaluation": [],
        # Create a "chatty" client/server or not.
        "callbacks": MyCallbacks if args.chatty_callbacks else None,
    }

    # DQN.
    if args.run == "DQN":
        # Example of using DQN (supports off-policy actions).
        trainer = DQNTrainer(
            config=dict(
                config, **{
                    "learning_starts": 100,
                    "timesteps_per_iteration": 200,
                    "model": {
                        "fcnet_hiddens": [64],
                        "fcnet_activation": "linear",
                    },
                    "n_step": 3,
                    "framework": args.framework,
                }))
    # PPO.
    else:
        # Example of using PPO (does NOT support off-policy actions).
        trainer = PPOTrainer(
            config=dict(
                config, **{
                    "rollout_fragment_length": 1000,
                    "train_batch_size": 4000,
                    "framework": args.framework,
                }))

    # File that stores the location of the most recent checkpoint. This name
    # must stay untouched for the whole run: the training loop below rewrites
    # this file after every iteration.
    checkpoint_pointer_file = CHECKPOINT_FILE.format(args.run)

    # Attempt to restore from checkpoint, if possible.
    if not args.no_restore and os.path.exists(checkpoint_pointer_file):
        # Keep the restored location in its own variable. Reusing the pointer
        # file variable here would make the loop below overwrite the restored
        # checkpoint itself instead of updating the pointer file.
        with open(checkpoint_pointer_file) as f:
            # Tolerate stray whitespace/newlines (e.g. from manual edits).
            restore_path = f.read().strip()
        print("Restoring from checkpoint path", restore_path)
        trainer.restore(restore_path)

    # Serving and training loop.
    while True:
        print(pretty_print(trainer.train()))
        checkpoint = trainer.save()
        print("Last checkpoint", checkpoint)
        # Remember the newest checkpoint location for the next server start.
        with open(checkpoint_pointer_file, "w") as f:
            f.write(checkpoint)
|