# ray/rllib/examples/serve_and_rllib.py
"""This example script shows how one can use Ray Serve to serve an already
trained RLlib Policy (and its model) to serve action computations.
For a complete tutorial, also see:
https://docs.ray.io/en/master/serve/tutorials/rllib.html
"""
import argparse

import gym
import requests
from starlette.requests import Request

import ray
import ray.rllib.agents.dqn as dqn
from ray.rllib.env.wrappers.atari_wrappers import FrameStack, WarpFrame
from ray import serve

parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.")
parser.add_argument("--train-iters", type=int, default=1)
parser.add_argument("--no-render", action="store_true")
args = parser.parse_args()


class ServeRLlibPolicy:
    """Callable class used by Ray Serve to handle async requests.

    All the necessary serving logic is implemented in here:
    - Creation and restoring of the (already trained) RLlib Trainer.
    - Calls to trainer.compute_action upon receiving an action request
      (with a current observation).
    """

    def __init__(self, config, checkpoint_path):
        # Create the Trainer.
        self.trainer = dqn.DQNTrainer(config=config)
        # Load an already trained state for the trainer.
        self.trainer.restore(checkpoint_path)

    async def __call__(self, request: Request):
        json_input = await request.json()
        # Compute and return the action for the given observation.
        obs = json_input["observation"]
        action = self.trainer.compute_action(obs)
        return {"action": int(action)}


def train_rllib_policy(config):
    """Trains a DQNTrainer on MsPacman-v0 for n iterations.

    Saves the trained Trainer to disk and returns the checkpoint path.

    Returns:
        str: The path to the saved checkpoint from which the DQNTrainer
            can be restored.
    """
    # Create trainer from config.
    trainer = dqn.DQNTrainer(config=config)

    # Train for n iterations, then save the Trainer and return the
    # checkpoint path.
    for _ in range(args.train_iters):
        print(trainer.train())
    return trainer.save()
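
# For orientation: `trainer.save()` returns the checkpoint file path as a
# string. The exact directory layout varies across Ray versions, but it
# typically looks roughly like (hypothetical run id):
#
#   ~/ray_results/DQN_MsPacman-v0_<id>/checkpoint_000001/checkpoint-1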


if __name__ == "__main__":
    # Config for the served RLlib Policy/Trainer.
    config = {
        "framework": args.framework,
        # num_workers=0: No remote rollout workers; the Trainer only needs
        # its local worker (for training here, then for serving actions).
        "num_workers": 0,
        "env": "MsPacman-v0",
    }

    # Initialize Ray before creating any Trainer (creating a Trainer would
    # otherwise auto-initialize Ray and make this explicit call fail).
    ray.init(num_cpus=8)

    # Train the policy for some time, then save it and get the
    # checkpoint path.
    checkpoint_path = train_rllib_policy(config)

    # Start Ray Serve (create the RLlib Policy service defined by
    # our `ServeRLlibPolicy` class above).
    client = serve.start()
    client.create_backend("backend", ServeRLlibPolicy, config,
                          checkpoint_path)
    client.create_endpoint(
        "endpoint", backend="backend", route="/mspacman-rllib-policy")

    # Create the environment that we would like to receive served actions
    # for. WarpFrame (84x84 grayscale) plus FrameStack(4) mirror RLlib's
    # default Atari preprocessing, so the observations we send match what
    # the policy was trained on.
    env = FrameStack(WarpFrame(gym.make("MsPacman-v0"), 84), 4)
    obs = env.reset()

    while True:
        print("-> Requesting action for obs ...")
        # Send a request to Serve.
        resp = requests.get(
            "http://localhost:8000/mspacman-rllib-policy",
            json={"observation": obs.tolist()})
        response = resp.json()
        print("<- Received response {}".format(response))

        # Apply the returned action in the env.
        action = response["action"]
        obs, reward, done, _ = env.step(action)

        # If the episode is done, reset to get the initial observation of
        # the next episode.
        if done:
            obs = env.reset()

        # Render, unless --no-render was given.
        if not args.no_render:
            env.render()
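
# Example invocation (flags as defined at the top of this script; torch is
# just one of the supported framework choices):
#
#   python serve_and_rllib.py --framework=torch --train-iters=2 --no-render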