"""This example script shows how one can use Ray Serve to serve an already
|
|
trained RLlib Policy (and its model) to serve action computations.
|
|
|
|
For a complete tutorial, also see:
|
|
https://docs.ray.io/en/master/serve/tutorials/rllib.html
|
|
"""

import argparse

import gym
import requests
from starlette.requests import Request

import ray
import ray.rllib.agents.dqn as dqn
from ray.rllib.env.wrappers.atari_wrappers import FrameStack, WarpFrame
from ray import serve

parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.")
parser.add_argument("--train-iters", type=int, default=1)
parser.add_argument("--no-render", action="store_true")

args = parser.parse_args()


class ServeRLlibPolicy:
    """Callable class used by Ray Serve to handle async requests.

    All the necessary serving logic is implemented here:
    - Creation and restoring of the (already trained) RLlib Trainer.
    - Calls to trainer.compute_action upon receiving an action request
      (with a current observation).
    """

    def __init__(self, config, checkpoint_path):
        # Create the Trainer.
        self.trainer = dqn.DQNTrainer(config=config)
        # Load an already trained state for the trainer.
        self.trainer.restore(checkpoint_path)

    async def __call__(self, request: Request):
        json_input = await request.json()

        # Compute and return the action for the given observation.
        obs = json_input["observation"]
        action = self.trainer.compute_action(obs)

        # Cast to a native int: `compute_action` may return a numpy
        # integer, which is not JSON-serializable.
        return {"action": int(action)}


def train_rllib_policy(config):
    """Trains a DQNTrainer on MsPacman-v0 for n iterations.

    Saves the trained Trainer to disk and returns the checkpoint path.

    Returns:
        str: The checkpoint path from which to restore the trained
            DQNTrainer.
    """
    # Create trainer from config.
    trainer = dqn.DQNTrainer(config=config)

    # Train for n iterations, then save.
    for _ in range(args.train_iters):
        print(trainer.train())
    return trainer.save()


if __name__ == "__main__":

    # Config for the served RLlib Policy/Trainer.
    config = {
        "framework": args.framework,
        # Run everything on the local worker; no remote rollout workers
        # are needed just for computing (serving) actions.
        "num_workers": 0,
        "env": "MsPacman-v0",
    }
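
    # Note: The same config is passed to the Serve backend below, which
    # re-creates an identical Trainer there before restoring the trained
    # checkpoint into it.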

    # Train the policy for some time, then save it and get the
    # checkpoint path.
    checkpoint_path = train_rllib_policy(config)

    ray.init(num_cpus=8)

    # Start Ray Serve (create the RLlib Policy service defined by
    # our `ServeRLlibPolicy` class above).
    client = serve.start()
    client.create_backend("backend", ServeRLlibPolicy, config, checkpoint_path)
    client.create_endpoint(
        "endpoint", backend="backend", route="/mspacman-rllib-policy")

    # Create the environment that we would like to receive
    # served actions for.
    env = FrameStack(WarpFrame(gym.make("MsPacman-v0"), 84), 4)
    obs = env.reset()

    while True:
        print("-> Requesting action for obs ...")
        # Send a request to Serve.
        resp = requests.get(
            "http://localhost:8000/mspacman-rllib-policy",
            json={"observation": obs.tolist()})
        response = resp.json()
        print("<- Received response {}".format(response))

        # Apply the action in the env.
        action = response["action"]
        obs, reward, done, _ = env.step(action)

        # If episode done -> reset to get initial observation of new episode.
        if done:
            obs = env.reset()

        # Render if necessary.
        if not args.no_render:
            env.render()
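
# Example invocation of this script (the file name here is a placeholder;
# the flags are the ones defined via argparse at the top of this file):
#   python serve_rllib_policy_example.py --framework=torch --train-iters=2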