{
"cells": [
{
"cell_type": "markdown",
"id": "0d401e53",
"metadata": {},
"source": [
"# Serving reinforcement learning policy models\n",
"In this example, we train a reinforcement learning model and serve it\n",
"using Ray Serve.\n",
"\n",
"We then instantiate an environment and step through it by querying the served model\n",
"for actions via HTTP."
]
},
{
"cell_type": "markdown",
"id": "cf8cd121",
"metadata": {},
"source": [
"Let's start with installing our dependencies:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "8dc6f862",
"metadata": {},
"outputs": [],
"source": [
"!pip install -qU \"ray[rllib,serve]\" gym"
]
},
{
"cell_type": "markdown",
"id": "13518458",
"metadata": {},
"source": [
"Now we can run some imports:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c4a621ee",
"metadata": {},
"outputs": [],
"source": [
"import gym\n",
"import numpy as np\n",
"import requests\n",
"\n",
"from ray.air.checkpoint import Checkpoint\n",
"from ray.air.config import RunConfig\n",
"from ray.train.rl.rl_trainer import RLTrainer\n",
"from ray.train.rl.rl_predictor import RLPredictor\n",
"from ray.air.result import Result\n",
"from ray.serve import PredictorDeployment\n",
"from ray import serve\n",
"from ray.tune.tuner import Tuner"
]
},
{
"cell_type": "markdown",
"id": "2781f448",
"metadata": {},
"source": [
"Since we'll be serving a reinforcement learning policy, we need to train one first. Thus we define a simple training function which will kick off online reinforcement learning of a PPO agent on the `CartPole-v0` environment."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "0f247ac6",
"metadata": {},
"outputs": [],
"source": [
"def train_rl_ppo_online(num_workers: int, use_gpu: bool = False) -> Result:\n",
" print(\"Starting online training\")\n",
" trainer = RLTrainer(\n",
" run_config=RunConfig(stop={\"training_iteration\": 5}),\n",
" scaling_config={\n",
" \"num_workers\": num_workers,\n",
" \"use_gpu\": use_gpu,\n",
" },\n",
" algorithm=\"PPO\",\n",
" config={\n",
" \"env\": \"CartPole-v0\",\n",
" \"framework\": \"tf\",\n",
" },\n",
" )\n",
" # Todo (krfricke/xwjiang): Enable checkpoint config in RunConfig\n",
" # result = trainer.fit()\n",
" tuner = Tuner(\n",
" trainer,\n",
" _tuner_kwargs={\"checkpoint_at_end\": True},\n",
" )\n",
" result = tuner.fit()[0]\n",
" return result"
]
},
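{
"cell_type": "markdown",
"id": "7b3f2a9c",
"metadata": {},
"source": [
"The `Result` returned by this function exposes the final training metrics and, importantly, the checkpoint of the trained policy. As a small sketch of how we will consume it (assuming the standard `Result.metrics` and `Result.checkpoint` attributes; the actual calls happen further below):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8c4a3b0d",
"metadata": {},
"outputs": [],
"source": [
"# Sketch only -- the actual training run happens later in this notebook.\n",
"# result = train_rl_ppo_online(num_workers=2, use_gpu=False)\n",
"# print(result.metrics[\"episode_reward_mean\"])  # mean episode reward of the last iteration\n",
"# checkpoint = result.checkpoint  # this is what we hand to Ray Serve below"
]
},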
{
"cell_type": "markdown",
"id": "00cb47c3",
"metadata": {},
"source": [
"Once we obtained a trained checkpoint, we will want to serve it using Ray Serve:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "caddaf3a",
"metadata": {},
"outputs": [],
"source": [
"def serve_rl_model(checkpoint: Checkpoint, name=\"RLModel\") -> str:\n",
" \"\"\"Serve a RL model and return deployment URI.\n",
"\n",
" This function will start Ray Serve and deploy a model wrapper\n",
" that loads the RL checkpoint into a RLPredictor.\n",
" \"\"\"\n",
" serve.start(detached=True)\n",
" deployment = PredictorDeployment.options(name=name)\n",
" deployment.deploy(RLPredictor, checkpoint)\n",
" return deployment.url"
]
},
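{
"cell_type": "markdown",
"id": "9d5b4c1e",
"metadata": {},
"source": [
"As a side note: the deployment keeps running detached from this script. If you want to tear everything down once you are done querying the model, a minimal sketch (using the standard `serve.shutdown()` and `ray.shutdown()` calls) looks like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0e6c5d2f",
"metadata": {},
"outputs": [],
"source": [
"# Sketch only -- run this after you are done querying the served model.\n",
"# import ray\n",
"# from ray import serve\n",
"#\n",
"# serve.shutdown()  # removes all deployments and stops the Serve instance\n",
"# ray.shutdown()    # disconnects from / stops the local Ray cluster"
]
},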
{
"cell_type": "markdown",
"id": "092f94ee",
"metadata": {},
"source": [
"And to make sure everything works well, we can kick off an evaluation run on a fresh environment. This will query the served policy model to obtain actions using HTTP."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "4579efd2",
"metadata": {},
"outputs": [],
"source": [
"def evaluate_served_policy(endpoint_uri: str, num_episodes: int = 3) -> list:\n",
" \"\"\"Evaluate a served RL policy on a local environment.\n",
"\n",
" This function will create an RL environment and step through it.\n",
" To obtain the actions, it will query the deployed RL model.\n",
" \"\"\"\n",
" env = gym.make(\"CartPole-v0\")\n",
"\n",
" rewards = []\n",
" for i in range(num_episodes):\n",
" obs = env.reset()\n",
" reward = 0.0\n",
" done = False\n",
" while not done:\n",
" action = query_action(endpoint_uri, obs)\n",
" obs, r, done, _ = env.step(action)\n",
" reward += r\n",
" rewards.append(reward)\n",
"\n",
" return rewards\n",
"\n",
"\n",
"def query_action(endpoint_uri: str, obs: np.ndarray):\n",
" \"\"\"Perform inference on a served RL model.\n",
"\n",
" This will send a HTTP request to the Ray Serve endpoint of the served\n",
" RL policy model and return the result.\n",
" \"\"\"\n",
" action_dict = requests.post(endpoint_uri, json={\"array\": obs.tolist()}).json()\n",
" return action_dict"
]
},
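{
"cell_type": "markdown",
"id": "1f7d6e3a",
"metadata": {},
"source": [
"For illustration, here is the request payload format that `query_action` assumes: the observation array is sent under the `\"array\"` key, matching what the deployed predictor wrapper expects in this example. The observation values below are made up; during evaluation they come from `env.reset()` and `env.step()`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2a8e7f4b",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical CartPole-v0 observation: [cart position, cart velocity,\n",
"# pole angle, pole angular velocity]. Only used to show the payload shape.\n",
"example_obs = np.array([0.01, -0.02, 0.03, 0.04])\n",
"\n",
"# This is the JSON body that query_action() posts to the Serve endpoint.\n",
"example_payload = {\"array\": example_obs.tolist()}\n",
"print(example_payload)"
]
},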
{
"cell_type": "markdown",
"id": "757489e1",
"metadata": {},
"source": [
"Let's put it all together. First, we train the model:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1ceaedc9",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-05-19 14:19:32,791\tWARNING deprecation.py:47 -- DeprecationWarning: `ray.rllib.execution.buffers` has been deprecated. Use `ray.rllib.utils.replay_buffers` instead. This will raise an error in the future!\n",
"2022-05-19 14:19:32,816\tWARNING deprecation.py:47 -- DeprecationWarning: `ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG` has been deprecated. Use `ray.rllib.agents.dqn.dqn.DQNConfig(...)` instead. This will raise an error in the future!\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Starting online training\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-05-19 14:19:35,724\tINFO services.py:1483 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8269\u001b[39m\u001b[22m\n"
]
},
{
"data": {
"text/html": [
"== Status ==
Current time: 2022-05-19 14:20:14 (running for 00:00:36.01)
Memory usage on this node: 9.7/16.0 GiB
Using FIFO scheduling algorithm.
Resources requested: 0/16 CPUs, 0/0 GPUs, 0.0/4.44 GiB heap, 0.0/2.0 GiB objects
Result logdir: /Users/kai/ray_results/AIRPPOTrainer_2022-05-19_14-19-32
Number of trials: 1/1 (1 TERMINATED)
Trial name | status | loc | iter | total time (s) | ts | reward | episode_reward_max | episode_reward_min | episode_len_mean |
---|---|---|---|---|---|---|---|---|---|
AIRPPOTrainer_55884_00000 | TERMINATED | 127.0.0.1:15610 | 5 | 16.4897 | 20000 | 131.8 | 200 | 16 | 131.8 |