---
jupytext:
  text_representation:
    extension: .md
    format_name: myst
kernelspec:
  display_name: Python 3
  language: python
  name: python3
---

(serve-rllib-tutorial)=

# Serving RLlib Models

In this guide, we will train and deploy a simple Ray RLlib model.
In particular, we show:

- How to train and store an RLlib model.
- How to load this model from a checkpoint.
- How to parse the JSON request and evaluate the payload in RLlib.

```{margin}
Check out the [Key Concepts](serve-key-concepts) page to learn more general information about Ray Serve.
```

We will train and checkpoint a simple PPO model with the `CartPole-v0` environment from `gym`.
In this tutorial we simply write to local disk, but in production you might want to consider using a cloud
storage solution like S3 or a shared file system.
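
For example, here is a minimal sketch of copying a saved checkpoint directory to S3 with `boto3`. The helper name, bucket, and key prefix are hypothetical, and we assume `boto3` is installed and AWS credentials are configured:

```python
import os

import boto3


def upload_checkpoint_to_s3(local_dir: str, bucket: str, prefix: str) -> None:
    """Recursively upload a local checkpoint directory to an S3 bucket."""
    s3 = boto3.client("s3")
    for root, _, files in os.walk(local_dir):
        for file_name in files:
            local_path = os.path.join(root, file_name)
            # Preserve the directory layout relative to `local_dir` in the key.
            key = f"{prefix}/{os.path.relpath(local_path, local_dir)}"
            s3.upload_file(local_path, bucket, key)


# Hypothetical usage after training:
# upload_checkpoint_to_s3("/tmp/rllib_checkpoint", "my-rllib-checkpoints", "cartpole-ppo")
```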

Let's get started by defining a `PPO` instance, training it for one iteration and then creating a checkpoint:

```{code-cell} python3
:tags: [remove-output]

import ray
import ray.rllib.algorithms.ppo as ppo
from ray import serve


def train_ppo_model():
    # Configure our PPO algorithm.
    config = ppo.PPOConfig()\
        .framework("torch")\
        .rollouts(num_rollout_workers=0)
    # Create a `PPO` instance from the config.
    algo = config.build(env="CartPole-v0")
    # Train for one iteration.
    algo.train()
    # Save state of the trained Algorithm in a checkpoint.
    algo.save("/tmp/rllib_checkpoint")
    return "/tmp/rllib_checkpoint/checkpoint_000001/checkpoint-1"


checkpoint_path = train_ppo_model()
```

You create deployments with Ray Serve by using the `@serve.deployment` decorator on a class that implements two methods:

- The `__init__` call creates the deployment instance and loads your data once.
  In the example below we restore our `PPO` Algorithm from the checkpoint we just created.
- The `__call__` method will be invoked on every request.
  For each incoming request, this method has access to a `request` object,
  which is a [Starlette Request](https://www.starlette.io/requests/).

We can load the request body as a JSON object and, assuming there is a key called `observation`,
use `request.json()["observation"]` in the deployment to retrieve observations (`obs`) and
pass them into the restored `Algorithm` using the `compute_single_action` method.

```{code-cell} python3
:tags: [hide-output]

from starlette.requests import Request


@serve.deployment(route_prefix="/cartpole-ppo")
class ServePPOModel:
    def __init__(self, checkpoint_path) -> None:
        # Re-create the originally used config.
        config = ppo.PPOConfig()\
            .framework("torch")\
            .rollouts(num_rollout_workers=0)
        # Build the Algorithm instance using the config.
        self.algorithm = config.build(env="CartPole-v0")
        # Restore the algo's state from the checkpoint.
        self.algorithm.restore(checkpoint_path)

    async def __call__(self, request: Request):
        json_input = await request.json()
        obs = json_input["observation"]

        action = self.algorithm.compute_single_action(obs)
        return {"action": int(action)}
```

:::{tip}
Although we used a single input and `Algorithm.compute_single_action(...)` here, you
can collect a batch of inputs using Ray Serve's [batching](serve-batching) feature
and process them in one call with `Algorithm.compute_actions(...)`.
A sketch of such a batched deployment follows this tip.
:::
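
As a rough, hypothetical illustration of that pattern, here is a sketch of a batched variant of the deployment above. It uses Ray Serve's `@serve.batch` decorator, which collects concurrent requests into a list before invoking the method. The class name, route, and batch parameters are made up, and instead of `Algorithm.compute_actions(...)` we call `compute_actions` on the underlying policy with a stacked `numpy` array:

```python
import numpy as np


@serve.deployment(route_prefix="/cartpole-ppo-batch")
class BatchedServePPOModel:
    def __init__(self, checkpoint_path) -> None:
        # Same restore logic as in `ServePPOModel` above.
        config = ppo.PPOConfig()\
            .framework("torch")\
            .rollouts(num_rollout_workers=0)
        self.algorithm = config.build(env="CartPole-v0")
        self.algorithm.restore(checkpoint_path)

    @serve.batch(max_batch_size=4, batch_wait_timeout_s=0.1)
    async def handle_batch(self, observations):
        # `observations` holds one observation per batched request.
        obs_batch = np.stack([np.array(obs) for obs in observations])
        actions, _, _ = self.algorithm.get_policy().compute_actions(obs_batch)
        # Return one result per request, in the same order.
        return [{"action": int(action)} for action in actions]

    async def __call__(self, request: Request):
        json_input = await request.json()
        return await self.handle_batch(json_input["observation"])
```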

Now that we've defined our `ServePPOModel` service, let's deploy it to Ray Serve.
The deployment will be exposed through the `/cartpole-ppo` route.

```{code-cell} python3
:tags: [hide-output]

serve.start()
ServePPOModel.deploy(checkpoint_path)
```

Note that the `checkpoint_path` that we passed to the `deploy()` method will be passed to
the `__init__` method of the `ServePPOModel` class that we defined above.

Now that the model is deployed, let's query it!

```{code-cell} python3
import gym
import requests


for _ in range(5):
    env = gym.make("CartPole-v0")
    obs = env.reset()

    print(f"-> Sending observation {obs}")
    resp = requests.get(
        "http://localhost:8000/cartpole-ppo", json={"observation": obs.tolist()}
    )
    print(f"<- Received response {resp.json()}")
```

You should see output like this (`observation` values will differ):

```text
<- Received response {'action': 1}
-> Sending observation [0.04228249 0.02289503 0.00690076 0.03095441]
<- Received response {'action': 0}
-> Sending observation [ 0.04819471 -0.04702759 -0.00477937 -0.00735569]
<- Received response {'action': 0}
```

:::{note}
In this example the client used the `requests` library to send a request to the server.
We defined a `json` object with an `observation` key and a Python list of observations (`obs.tolist()`).
Since `obs = env.reset()` is a `numpy.ndarray`, we used `tolist()` for conversion.
On the server side, we used `obs = json_input["observation"]` to retrieve the observations again, which is of `list` type.
For an RLlib algorithm with a simple observation space such as CartPole's, it's possible to pass this
`obs` list directly to the `Algorithm.compute_single_action(...)` method.
We could also have created a `numpy` array from it first and then passed it into the `Algorithm`.

In more complex cases with tuple or dict observation spaces, you will have to do some preprocessing of
your `json_input` before passing it to your `Algorithm` instance (see the sketch after this note).
The exact way to process your input depends on how you serialize your observations on the client.
:::
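
For instance, here is a minimal, hypothetical sketch for a dict observation space, where each entry of the JSON payload is converted back into a `numpy` array before the algorithm is called. The helper name and payload structure are assumptions for illustration, not part of the tutorial above:

```python
import numpy as np


def preprocess_dict_observation(json_obs: dict) -> dict:
    # Convert each JSON list back into a numpy array so the structure
    # matches the algorithm's dict observation space.
    return {key: np.array(value) for key, value in json_obs.items()}


# Hypothetical use inside `__call__`:
# obs = preprocess_dict_observation(json_input["observation"])
# action = self.algorithm.compute_single_action(obs)
```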

```{code-cell} python3
:tags: [remove-cell]
ray.shutdown()
```