"cells": [
"cell_type": "markdown",
"id": "3471e19a",
"metadata": {},
"source": [
"# Online reinforcement learning with Ray AIR\n",
"In this example, we'll train a reinforcement learning agent using online training.\n",
"Online training means that new data is sampled from the environment while the algorithm is running. In contrast, offline training uses data that was collected beforehand and stored on disk."
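Here is a minimal sketch that makes the distinction concrete without any framework. `ToyEnv`, `collect_online`, `save_offline`, and `load_offline` are hypothetical names made up for illustration and are not part of Ray. The toy environment mimics the older gym API used in this notebook, where `step()` returns a 4-tuple. The online path interacts with the environment as it runs. The offline path only reads transitions that were written to disk earlier.

```python
import json
import random
import tempfile
from pathlib import Path
from typing import Callable, List, Tuple

# A transition is (observation, action, reward, next_observation, done).
Transition = Tuple[int, int, float, int, bool]


class ToyEnv:
    """Hypothetical stand-in for a gym environment (old 4-tuple step API).

    Each episode ends after `max_steps` steps and every step yields reward 1.0.
    """

    def __init__(self, max_steps: int = 5):
        self.max_steps = max_steps
        self.t = 0

    def reset(self) -> int:
        self.t = 0
        return self.t

    def step(self, action: int):
        self.t += 1
        done = self.t >= self.max_steps
        return self.t, 1.0, done, {}


def collect_online(env: ToyEnv, policy: Callable[[int], int], num_steps: int) -> List[Transition]:
    """Online: sample fresh transitions by interacting with the env right now."""
    batch = []
    obs = env.reset()
    for _ in range(num_steps):
        action = policy(obs)
        next_obs, reward, done, _ = env.step(action)
        batch.append((obs, action, reward, next_obs, done))
        # Start a new episode when the current one ends.
        obs = env.reset() if done else next_obs
    return batch


def save_offline(batch: List[Transition], path: Path) -> None:
    """Persist transitions as JSON lines, the way an offline dataset might be stored."""
    with path.open("w") as f:
        for transition in batch:
            f.write(json.dumps(transition) + "\n")


def load_offline(path: Path) -> List[Transition]:
    """Offline: read previously stored transitions without touching an env."""
    with path.open() as f:
        return [tuple(json.loads(line)) for line in f]


# Usage: collect 12 steps online, store them, then read them back "offline".
rng = random.Random(0)
online_batch = collect_online(ToyEnv(), lambda obs: rng.choice([0, 1]), num_steps=12)
path = Path(tempfile.mkdtemp()) / "transitions.jsonl"
save_offline(online_batch, path)
offline_batch = load_offline(path)
print(len(online_batch), offline_batch == online_batch)  # 12 True
```

In a real setup, the offline dataset would usually come from a different policy or from an earlier run. The point here is only that the offline path never calls `env.step()`.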
"cell_type": "markdown",
"id": "f5083f08",
"metadata": {},
"source": [
"Let's start with installing our dependencies:"
"cell_type": "code",
"execution_count": 1,
"id": "01f914d2",
"metadata": {},
"outputs": [],
"source": [
"!pip install -qU \"ray[rllib]\" gym"
"cell_type": "markdown",
"id": "980cea70",
"metadata": {},
"source": [
"Now we can run some imports:"
"cell_type": "code",
"execution_count": 2,
"id": "db0a45ff",
"metadata": {},
"outputs": [
"name": "stderr",
"output_type": "stream",
"text": [
"2022-05-19 13:54:16,520\tWARNING deprecation.py:47 -- DeprecationWarning: `ray.rllib.execution.buffers` has been deprecated. Use `ray.rllib.utils.replay_buffers` instead. This will raise an error in the future!\n",
"2022-05-19 13:54:16,531\tWARNING deprecation.py:47 -- DeprecationWarning: `ray.rllib.agents.marwil` has been deprecated. Use `ray.rllib.algorithms.marwil` instead. This will raise an error in the future!\n"
"source": [
"import argparse\n",
"import gym\n",
"import os\n",
"import numpy as np\n",
"import ray\n",
"from ray.air import Checkpoint\n",
"from ray.air.config import RunConfig\n",
"from ray.air.predictors.integrations.rl.rl_predictor import RLPredictor\n",
"from ray.air.train.integrations.rl.rl_trainer import RLTrainer\n",
"from ray.air.result import Result\n",
"from ray.tune.tuner import Tuner"
"cell_type": "markdown",
"id": "a13db7e4",
"metadata": {},
"source": [
"Here we define the training function. It creates an `RLTrainer` that uses the `PPO` algorithm and trains it on the `CartPole-v0` environment for 5 training iterations:"
"cell_type": "code",
"execution_count": 3,
"id": "87fca4b1",
"metadata": {},
"outputs": [],
"source": [
"def train_rl_ppo_online(num_workers: int, use_gpu: bool = False) -> Result:\n",
"    print(\"Starting online training\")\n",
"    trainer = RLTrainer(\n",
"        run_config=RunConfig(stop={\"training_iteration\": 5}),\n",
"        scaling_config={\n",
"            \"num_workers\": num_workers,\n",
"            \"use_gpu\": use_gpu,\n",
"        },\n",
"        algorithm=\"PPO\",\n",
"        config={\n",
"            \"env\": \"CartPole-v0\",\n",
"            \"framework\": \"tf\",\n",
"        },\n",
"    )\n",
"    # Todo (krfricke/xwjiang): Enable checkpoint config in RunConfig\n",
"    # result = trainer.fit()\n",
"    tuner = Tuner(\n",
"        trainer,\n",
"        _tuner_kwargs={\"checkpoint_at_end\": True},\n",
"    )\n",
"    result = tuner.fit()[0]\n",
"    return result"
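`RunConfig(stop={"training_iteration": 5})` ends training after five iterations. As we understand Tune's dict-based stopping criteria, a trial stops as soon as any metric listed in the dict reaches or exceeds its threshold in a reported result. The sketch below mimics that rule in plain Python. `should_stop` and `fake_training_results` are hypothetical helpers, not Ray's implementation.

```python
from typing import Dict, Iterator


def should_stop(result: Dict[str, float], stop: Dict[str, float]) -> bool:
    """Return True once any metric named in `stop` reaches its threshold."""
    return any(
        key in result and result[key] >= threshold
        for key, threshold in stop.items()
    )


def fake_training_results() -> Iterator[Dict[str, float]]:
    """Hypothetical stand-in for the results a trainer reports each iteration."""
    iteration = 0
    while True:
        iteration += 1
        yield {"training_iteration": iteration}


stop = {"training_iteration": 5}
reported = []
for result in fake_training_results():
    reported.append(result)
    if should_stop(result, stop):
        break

print(len(reported))  # 5
```

Because the check is "reaches or exceeds", the fifth reported result is the last one, which matches the `iter` value of 5 in the status table further below.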
"cell_type": "markdown",
"id": "f7a5d5c2",
"metadata": {},
"source": [
"Once we have trained our RL policy, we want to evaluate it on a fresh environment. For this, we define a utility function:"
"cell_type": "code",
"execution_count": 4,
"id": "2628f3b0",
"metadata": {},
"outputs": [],
"source": [
"def evaluate_using_checkpoint(checkpoint: Checkpoint, num_episodes: int) -> list:\n",
"    # Restore the trained policy and roll it out in a fresh CartPole environment.\n",
"    predictor = RLPredictor.from_checkpoint(checkpoint)\n",
"    env = gym.make(\"CartPole-v0\")\n",
"    rewards = []\n",
"    for _ in range(num_episodes):\n",
"        obs = env.reset()\n",
"        reward = 0.0\n",
"        done = False\n",
"        while not done:\n",
"            # Predict on a batch holding a single observation, then take its action.\n",
"            action = predictor.predict([obs])\n",
"            obs, r, done, _ = env.step(action[0])\n",
"            reward += r\n",
"        rewards.append(reward)\n",
"    return rewards"
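The rollout loop in `evaluate_using_checkpoint` can be exercised without Ray or gym installed. In this sketch, `FixedLengthEnv` and `always_zero` are hypothetical stand-ins for the CartPole environment and the restored predictor. As with the predictor call above, the policy receives a batch containing one observation, and we take the first action. Every episode lasts a fixed number of steps with reward 1.0 per step, which makes the expected returns easy to check.

```python
import statistics
from typing import Callable, List


class FixedLengthEnv:
    """Hypothetical env: every episode lasts `length` steps with reward 1.0 each.

    Uses the older gym API: reset() -> obs, step() -> (obs, reward, done, info).
    """

    def __init__(self, length: int):
        self.length = length
        self.t = 0

    def reset(self):
        self.t = 0
        return [0.0]

    def step(self, action):
        self.t += 1
        return [float(self.t)], 1.0, self.t >= self.length, {}


def evaluate(env, policy: Callable[[list], List[int]], num_episodes: int) -> List[float]:
    """Same loop structure as evaluate_using_checkpoint, with env and policy injected."""
    rewards = []
    for _ in range(num_episodes):
        obs = env.reset()
        episode_reward = 0.0
        done = False
        while not done:
            action = policy([obs])  # batch with a single observation
            obs, r, done, _ = env.step(action[0])
            episode_reward += r
        rewards.append(episode_reward)
    return rewards


def always_zero(batch):
    """Hypothetical policy: return action 0 for every observation in the batch."""
    return [0 for _ in batch]


rewards = evaluate(FixedLengthEnv(length=7), always_zero, num_episodes=3)
print(rewards)  # [7.0, 7.0, 7.0]
print(statistics.mean(rewards))  # 7.0
```

Averaging the returned list gives a single score for the policy. For CartPole, that score is also the average number of steps the pole stayed up.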
"cell_type": "markdown",
"id": "d226d6aa",
"metadata": {},
"source": [
"Let's put it all together. First, we run training:"
"cell_type": "code",
"execution_count": 5,
"id": "cae1337e",
"metadata": {},
"outputs": [
"name": "stderr",
"output_type": "stream",
"text": [
"2022-05-19 13:54:16,582\tWARNING deprecation.py:47 -- DeprecationWarning: `ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG` has been deprecated. Use `ray.rllib.agents.dqn.dqn.DQNConfig(...)` instead. This will raise an error in the future!\n"
"name": "stdout",
"output_type": "stream",
"text": [
"Starting online training\n"
"name": "stderr",
"output_type": "stream",
"text": [
"2022-05-19 13:54:19,326\tINFO services.py:1483 -- View the Ray dashboard at http://\n"
"data": {
"text/html": [
"== Status ==
Current time: 2022-05-19 13:54:57 (running for 00:00:35.99)
Memory usage on this node: 9.6/16.0 GiB
Using FIFO scheduling algorithm.
Resources requested: 0/16 CPUs, 0/0 GPUs, 0.0/4.54 GiB heap, 0.0/2.0 GiB objects
Result logdir: /Users/kai/ray_results/AIRPPOTrainer_2022-05-19_13-54-16
Number of trials: 1/1 (1 TERMINATED)
| Trial name | status | loc | iter | total time (s) | ts | reward | episode_reward_max | episode_reward_min | episode_len_mean |
|---|---|---|---|---|---|---|---|---|---|
| AIRPPOTrainer_cd8d6_00000 | TERMINATED | | 5 | 16.7029 | 20000 | 124.79 | 200 | 9 | 124.79 |
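Here is a quick sanity check on the status table. The `ts` column counts sampled timesteps across all iterations, so 20000 timesteps over 5 iterations is 4000 timesteps per iteration. That appears consistent with PPO's default `train_batch_size` of 4000 in RLlib at the time, although the notebook does not set it explicitly. CartPole gives a reward of +1 for every timestep the pole stays up. As a result, the mean episode reward equals the mean episode length (124.79 in both columns), and the `episode_reward_max` of 200 is the CartPole-v0 episode limit.

```python
# Values copied from the trial row of the status table above.
total_timesteps = 20000  # "ts" column
iterations = 5  # "iter" column
episode_reward_mean = 124.79  # "reward" column
episode_len_mean = 124.79  # "episode_len_mean" column

timesteps_per_iteration = total_timesteps // iterations
print(timesteps_per_iteration)  # 4000

# With +1 reward per timestep, mean return and mean episode length coincide.
print(episode_reward_mean == episode_len_mean)  # True
```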