[rllib] Add a simple REST policy server and client example (#2232)

* wip

* cls

* re

* wip

* wip

* a3c working

* torch support

* pg works

* lint

* rm v2

* consumer id

* clean up pg

* clean up more

* fix python 2.7

* tf session management

* docs

* dqn wip

* fix compile

* dqn

* apex runs

* up

* imports

* ddpg

* quotes

* fix tests

* fix last r

* fix tests

* lint

* pass checkpoint restore

* kwar

* nits

* policy graph

* fix yapf

* com

* class

* pyt

* vectorization

* update

* test cpe

* unit test

* fix ddpg2

* changes

* wip

* args

* faster test

* common

* fix

* add alg option

* batch mode and policy serving

* multi serving test

* todo

* wip

* serving test

* doc async env

* num envs

* comments

* thread

* remove init hook

* update

* policy serve

* spaces

* checkpoint

* no train

* fix ppo

* comments1

* fix

* updates

* add jenkins tests

* fix

* fix pytorch

* fix

* fixes

* fix a3c policy

* fix squeeze

* fix trunc on apex

* fix squeezing for real

* update

* remove horizon test for now

* fix race condition

* update

* com

* update

* add test

* Update run_multi_node_tests.sh

* use curl

* curl

* kill

* Update run_multi_node_tests.sh

* Update run_multi_node_tests.sh

* fix import

* update
Eric Liang 2018-06-20 13:22:39 -07:00 committed by GitHub
parent 418cd6804a
commit e5724a9cfe
15 changed files with 384 additions and 78 deletions

View file

@@ -1,5 +0,0 @@
# flake8: noqa
from ray.rllib.examples.multiagent_mountaincar_env \
import MultiAgentMountainCarEnv
from ray.rllib.examples.multiagent_pendulum_env \
import MultiAgentPendulumEnv

View file

@@ -21,7 +21,9 @@ def pass_params_to_gym(env_name):
register(
id=env_name,
entry_point='ray.rllib.examples:' + "MultiAgentMountainCarEnv",
entry_point=(
"ray.rllib.examples.legacy_multiagent.multiagent_mountaincar_env:"
"MultiAgentMountainCarEnv"),
max_episode_steps=200,
kwargs={}
)

View file

@@ -21,7 +21,9 @@ def pass_params_to_gym(env_name):
register(
id=env_name,
entry_point='ray.rllib.examples:' + "MultiAgentPendulumEnv",
entry_point=(
"ray.rllib.examples.legacy_multiagent.multiagent_pendulum_env:"
"MultiAgentPendulumEnv"),
max_episode_steps=100,
kwargs={}
)

View file

@@ -0,0 +1,55 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Example of querying a policy server. Copy this file for your use case.
To try this out, in two separate shells run:
$ python cartpole_server.py
$ python cartpole_client.py
"""
import argparse
import gym
from ray.rllib.utils.policy_client import PolicyClient
parser = argparse.ArgumentParser()
parser.add_argument(
"--no-train", action="store_true", help="Whether to disable training.")
parser.add_argument(
"--off-policy", action="store_true",
help="Whether to take random instead of on-policy actions.")
parser.add_argument(
"--stop-at-reward", type=int, default=9999,
help="Stop once the specified reward is reached.")
if __name__ == "__main__":
args = parser.parse_args()
env = gym.make("CartPole-v0")
client = PolicyClient("http://localhost:8900")
eid = client.start_episode(training_enabled=not args.no_train)
obs = env.reset()
rewards = 0
while True:
if args.off_policy:
action = env.action_space.sample()
client.log_action(eid, obs, action)
else:
action = client.get_action(eid, obs)
obs, reward, done, info = env.step(action)
rewards += reward
client.log_returns(eid, reward, info=info)
if done:
print("Total reward:", rewards)
if rewards >= args.stop_at_reward:
print("Target reward achieved, exiting")
exit(0)
rewards = 0
client.end_episode(eid, obs)
obs = env.reset()
eid = client.start_episode(training_enabled=not args.no_train)
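Because start_episode() now returns an explicit episode id that every later call takes as its first argument, a single client process can also interleave several episodes against one server. A rough sketch (not part of the committed file; assumes the server below is running on localhost:8900):

import gym
from ray.rllib.utils.policy_client import PolicyClient

env_a, env_b = gym.make("CartPole-v0"), gym.make("CartPole-v0")
client = PolicyClient("http://localhost:8900")
eid_a, eid_b = client.start_episode(), client.start_episode()
obs_a, obs_b = env_a.reset(), env_b.reset()
done_a = done_b = False
while not (done_a and done_b):
    if not done_a:
        obs_a, rew_a, done_a, _ = env_a.step(client.get_action(eid_a, obs_a))
        client.log_returns(eid_a, rew_a)
        if done_a:
            client.end_episode(eid_a, obs_a)
    if not done_b:
        obs_b, rew_b, done_b, _ = env_b.step(client.get_action(eid_b, obs_b))
        client.log_returns(eid_b, rew_b)
        if done_b:
            client.end_episode(eid_b, obs_b)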

View file

@@ -0,0 +1,66 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Example of running a policy server. Copy this file for your use case.
To try this out, in two separate shells run:
$ python cartpole_server.py
$ python cartpole_client.py
"""
import os
from gym import spaces
import ray
from ray.rllib.dqn import DQNAgent
from ray.rllib.utils.serving_env import ServingEnv
from ray.rllib.utils.policy_server import PolicyServer
from ray.tune.logger import pretty_print
from ray.tune.registry import register_env
SERVER_ADDRESS = "localhost"
SERVER_PORT = 8900
CHECKPOINT_FILE = "last_checkpoint.out"
class CartpoleServing(ServingEnv):
def __init__(self):
ServingEnv.__init__(
self, spaces.Discrete(2), spaces.Box(low=-10, high=10, shape=(4,)))
def run(self):
print("Starting policy server at {}:{}".format(
SERVER_ADDRESS, SERVER_PORT))
server = PolicyServer(self, SERVER_ADDRESS, SERVER_PORT)
server.serve_forever()
if __name__ == "__main__":
ray.init()
register_env("srv", lambda _: CartpoleServing())
# We use DQN since it supports off-policy actions, but you can choose and
# configure any agent.
dqn = DQNAgent(env="srv", config={
# Use a single process to avoid needing to set up a load balancer
"num_workers": 0,
# Configure the agent to run short iterations for debugging
"exploration_fraction": 0.01,
"learning_starts": 100,
"timesteps_per_iteration": 200,
})
# Attempt to restore from checkpoint if possible.
if os.path.exists(CHECKPOINT_FILE):
checkpoint_path = open(CHECKPOINT_FILE).read()
print("Restoring from checkpoint path", checkpoint_path)
dqn.restore(checkpoint_path)
# Serving and training loop
while True:
print(pretty_print(dqn.train()))
checkpoint_path = dqn.save()
print("Last checkpoint", checkpoint_path)
with open(CHECKPOINT_FILE, "w") as f:
f.write(checkpoint_path)
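The server keeps training on whatever experiences clients log, and writes its latest checkpoint path to last_checkpoint.out each iteration so a restart resumes from it. A client that only wants to query the policy can opt out of training per episode; a minimal sketch (not part of the committed file; assumes the server above is running):

from ray.rllib.utils.policy_client import PolicyClient

client = PolicyClient("http://localhost:8900")
# Experiences from this episode are tagged training_enabled=False, so the
# sampler computes actions for it but never adds it to a training batch.
eid = client.start_episode(training_enabled=False)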

View file

@@ -0,0 +1,12 @@
#!/bin/bash
pkill -f cartpole_server.py
(python cartpole_server.py 2>&1 | grep -v 200) &
pid=$!
while ! curl localhost:8900; do
sleep 1
done
python cartpole_client.py --stop-at-reward=100
kill $pid
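The smoke test kills any stale server, starts cartpole_server.py in the background (filtering its output through grep -v 200, presumably to drop the per-request HTTP 200 log lines), polls with curl until the server accepts connections, and then runs the client until it reaches a reward of 100.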

View file

@@ -24,16 +24,16 @@ class SimpleServing(ServingEnv):
self.env = env
def run(self):
self.start_episode()
eid = self.start_episode()
obs = self.env.reset()
while True:
action = self.get_action(obs)
action = self.get_action(eid, obs)
obs, reward, done, info = self.env.step(action)
self.log_returns(reward, info=info)
self.log_returns(eid, reward, info=info)
if done:
self.end_episode(obs)
self.end_episode(eid, obs)
obs = self.env.reset()
self.start_episode()
eid = self.start_episode()
class PartOffPolicyServing(ServingEnv):
@@ -43,20 +43,20 @@ class PartOffPolicyServing(ServingEnv):
self.off_pol_frac = off_pol_frac
def run(self):
self.start_episode()
eid = self.start_episode()
obs = self.env.reset()
while True:
if random.random() < self.off_pol_frac:
action = self.env.action_space.sample()
self.log_action(obs, action)
self.log_action(eid, obs, action)
else:
action = self.get_action(obs)
action = self.get_action(eid, obs)
obs, reward, done, info = self.env.step(action)
self.log_returns(reward, info=info)
self.log_returns(eid, reward, info=info)
if done:
self.end_episode(obs)
self.end_episode(eid, obs)
obs = self.env.reset()
self.start_episode()
eid = self.start_episode()
class SimpleOffPolicyServing(ServingEnv):
@@ -65,18 +65,18 @@ class SimpleOffPolicyServing(ServingEnv):
self.env = env
def run(self):
self.start_episode()
eid = self.start_episode()
obs = self.env.reset()
while True:
# Take random actions
action = self.env.action_space.sample()
self.log_action(obs, action)
self.log_action(eid, obs, action)
obs, reward, done, info = self.env.step(action)
self.log_returns(reward, info=info)
self.log_returns(eid, reward, info=info)
if done:
self.end_episode(obs)
self.end_episode(eid, obs)
obs = self.env.reset()
self.start_episode()
eid = self.start_episode()
class MultiServing(ServingEnv):
@@ -98,14 +98,13 @@ class MultiServing(ServingEnv):
self.start_episode(episode_id=eids[i])
cur_obs[i] = envs[i].reset()
actions = [
self.get_action(
cur_obs[i], episode_id=eids[i]) for i in active]
self.get_action(eids[i], cur_obs[i]) for i in active]
for i, action in zip(active, actions):
obs, reward, done, _ = envs[i].step(action)
cur_obs[i] = obs
self.log_returns(reward, episode_id=eids[i])
self.log_returns(eids[i], reward)
if done:
self.end_episode(obs, episode_id=eids[i])
self.end_episode(eids[i], obs)
del cur_obs[i]
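These tests exercise the signature change throughout ServingEnv: start_episode() returns the episode id, and get_action(), log_action(), log_returns(), and end_episode() now take that id as their first positional argument instead of an optional episode_id keyword.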

View file

@@ -0,0 +1,116 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pickle
try:
import requests # `requests` is not part of stdlib.
except ImportError:
requests = None
print("Couldn't import `requests` library. Be sure to install it on"
" the client side.")
class PolicyClient(object):
"""Client to interact with a RLlib policy server."""
START_EPISODE = "START_EPISODE"
GET_ACTION = "GET_ACTION"
LOG_ACTION = "LOG_ACTION"
LOG_RETURNS = "LOG_RETURNS"
END_EPISODE = "END_EPISODE"
def __init__(self, address):
self._address = address
def start_episode(self, episode_id=None, training_enabled=True):
"""Record the start of an episode.
Arguments:
episode_id (str): Unique string id for the episode or None for
it to be auto-assigned.
training_enabled (bool): Whether to use experiences for this
episode to improve the policy.
Returns:
episode_id (str): Unique string id for the episode.
"""
return self._send({
"episode_id": episode_id,
"command": PolicyClient.START_EPISODE,
"training_enabled": training_enabled,
})["episode_id"]
def get_action(self, episode_id, observation):
"""Record an observation and get the on-policy action.
Arguments:
episode_id (str): Episode id returned from start_episode().
observation (obj): Current environment observation.
Returns:
action (obj): Action from the env action space.
"""
return self._send({
"command": PolicyClient.GET_ACTION,
"observation": observation,
"episode_id": episode_id,
})["action"]
def log_action(self, episode_id, observation, action):
"""Record an observation and (off-policy) action taken.
Arguments:
episode_id (str): Episode id returned from start_episode().
observation (obj): Current environment observation.
action (obj): Action for the observation.
"""
self._send({
"command": PolicyClient.LOG_ACTION,
"observation": observation,
"action": action,
"episode_id": episode_id,
})
def log_returns(self, episode_id, reward, info=None):
"""Record returns from the environment.
The reward will be attributed to the previous action taken by the
episode. Rewards accumulate until the next action. If no reward is
logged before the next action, a reward of 0.0 is assumed.
Arguments:
episode_id (str): Episode id returned from start_episode().
reward (float): Reward from the environment.
"""
self._send({
"command": PolicyClient.LOG_RETURNS,
"reward": reward,
"info": info,
"episode_id": episode_id,
})
def end_episode(self, episode_id, observation):
"""Record the end of an episode.
Arguments:
episode_id (str): Episode id returned from start_episode().
observation (obj): Current environment observation.
"""
self._send({
"command": PolicyClient.END_EPISODE,
"observation": observation,
"episode_id": episode_id,
})
def _send(self, data):
payload = pickle.dumps(data)
response = requests.post(self._address, data=payload)
if response.status_code != 200:
print("Request failed", data)
print(response.text)
response.raise_for_status()
parsed = pickle.loads(response.content)
return parsed
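As the log_returns() docstring notes, rewards logged between actions accumulate and are attributed to the previous action. A small sketch of what that means for a caller (illustrative values only; assumes a running server on localhost:8900):

from ray.rllib.utils.policy_client import PolicyClient

client = PolicyClient("http://localhost:8900")
eid = client.start_episode()
obs = [0.0, 0.0, 0.0, 0.0]  # placeholder observation from the caller's env
action = client.get_action(eid, obs)
client.log_returns(eid, 0.5)
client.log_returns(eid, 0.25)
# The next request attributes the accumulated 0.75 to the action above.
action = client.get_action(eid, obs)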

View file

@@ -0,0 +1,62 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pickle
import sys
import traceback
from ray.rllib.utils.policy_client import PolicyClient
if sys.version_info[0] == 2:
from SimpleHTTPServer import SimpleHTTPRequestHandler
from SocketServer import TCPServer as HTTPServer
from SocketServer import ThreadingMixIn
elif sys.version_info[0] == 3:
from http.server import SimpleHTTPRequestHandler, HTTPServer
from socketserver import ThreadingMixIn
class PolicyServer(ThreadingMixIn, HTTPServer):
def __init__(self, serving_env, address, port):
handler = _make_handler(serving_env)
HTTPServer.__init__(self, (address, port), handler)
def _make_handler(serving_env):
class Handler(SimpleHTTPRequestHandler):
def do_POST(self):
content_len = int(self.headers.get('Content-Length'), 0)
raw_body = self.rfile.read(content_len)
parsed_input = pickle.loads(raw_body)
try:
response = self.execute_command(parsed_input)
self.send_response(200)
self.end_headers()
self.wfile.write(pickle.dumps(response))
except Exception:
self.send_error(500, traceback.format_exc())
def execute_command(self, args):
command = args["command"]
response = {}
if command == PolicyClient.START_EPISODE:
response["episode_id"] = serving_env.start_episode(
args["episode_id"], args["training_enabled"])
elif command == PolicyClient.GET_ACTION:
response["action"] = serving_env.get_action(
args["episode_id"], args["observation"])
elif command == PolicyClient.LOG_ACTION:
serving_env.log_action(
args["episode_id"], args["observation"], args["action"])
elif command == PolicyClient.LOG_RETURNS:
serving_env.log_returns(
args["episode_id"], args["reward"], args["info"])
elif command == PolicyClient.END_EPISODE:
serving_env.end_episode(
args["episode_id"], args["observation"])
else:
raise Exception("Unknown command: {}".format(command))
return response
return Handler
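The handler round-trips pickled dicts over plain HTTP POST, so the wire format is just what PolicyClient._send() produces. A rough sketch of issuing a START_EPISODE command by hand (assumes the example server is listening on localhost:8900):

import pickle

import requests

payload = pickle.dumps({
    "command": "START_EPISODE",   # same value as PolicyClient.START_EPISODE
    "episode_id": None,           # None lets the ServingEnv auto-assign an id
    "training_enabled": True,
})
response = requests.post("http://localhost:8900", data=payload)
episode_id = pickle.loads(response.content)["episode_id"]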

View file

@@ -188,7 +188,7 @@ def _env_runner(
while True:
# Get observations from ready envs
unfiltered_obs, rewards, dones, _, off_policy_actions = \
unfiltered_obs, rewards, dones, infos, off_policy_actions = \
async_vector_env.poll()
ready_eids = []
ready_obs = []
@@ -216,24 +216,25 @@
else:
done = False
episode.batch_builder.add_values(
obs=episode.last_observation,
actions=episode.last_action_flat(),
rewards=rewards[eid],
dones=done,
new_obs=filtered_obs,
**episode.last_pi_info)
if infos[eid].get("training_enabled", True):
episode.batch_builder.add_values(
obs=episode.last_observation,
actions=episode.last_action_flat(),
rewards=rewards[eid],
dones=done,
new_obs=filtered_obs,
**episode.last_pi_info)
# Cut the batch if we're not packing multiple episodes into one,
# or if we've exceeded the requested batch size.
if (done and not pack) or \
episode.batch_builder.count >= num_local_steps:
yield episode.batch_builder.build_and_reset(
policy.postprocess_trajectory)
elif done:
# Make sure postprocessor never goes across episode boundaries
episode.batch_builder.postprocess_batch_so_far(
policy.postprocess_trajectory)
# Cut the batch if we're not packing multiple episodes into
# one, or if we've exceeded the requested batch size.
if (done and not pack) or \
episode.batch_builder.count >= num_local_steps:
yield episode.batch_builder.build_and_reset(
policy.postprocess_trajectory)
elif done:
# Make sure postprocessor never crosses episode boundaries
episode.batch_builder.postprocess_batch_so_far(
policy.postprocess_trajectory)
if done:
# Handle episode termination
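With this change the sampler still computes actions for every episode, but experiences whose infos carry training_enabled == False (set by _Episode in serving_env.py below) are skipped when building sample batches, so they never reach the optimizer.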

View file

@@ -4,6 +4,7 @@ from __future__ import print_function
from six.moves import queue
import threading
import uuid
from ray.rllib.utils.async_vector_env import AsyncVectorEnv
@@ -26,8 +27,6 @@ class ServingEnv(threading.Thread):
This env is thread-safe, but individual episodes must be executed serially.
TODO: Provide a HTTP server/client example based on ServingEnv.
Examples:
>>> register_env("my_env", lambda config: YourServingEnv(config))
>>> agent = DQNAgent(env="my_env")
@@ -51,8 +50,6 @@ class ServingEnv(threading.Thread):
self.observation_space = observation_space
self._episodes = {}
self._finished = set()
self._num_episodes = 0
self._cur_default_episode_id = None
self._results_avail_condition = threading.Condition()
self._max_concurrent_episodes = max_concurrent
@@ -70,24 +67,21 @@
"""
raise NotImplementedError
def start_episode(self, episode_id=None):
def start_episode(self, episode_id=None, training_enabled=True):
"""Record the start of an episode.
Arguments:
episode_id (str): Unique string id for the episode or None for
it to be auto-assigned. Auto-assignment only works if there
is at most one active episode at a time.
it to be auto-assigned.
training_enabled (bool): Whether to use experiences for this
episode to improve the policy.
Returns:
episode_id (str): Unique string id for the episode.
"""
if episode_id is None:
if self._cur_default_episode_id:
raise ValueError(
"An existing episode is still active. You must pass "
"`episode_id` if there are going to be multiple active "
"episodes at once.")
episode_id = "default_{}".format(self._num_episodes)
self._cur_default_episode_id = episode_id
self._num_episodes += 1
episode_id = uuid.uuid4().hex
if episode_id in self._finished:
raise ValueError(
@@ -98,14 +92,16 @@
"Episode {} is already started".format(episode_id))
self._episodes[episode_id] = _Episode(
episode_id, self._results_avail_condition)
episode_id, self._results_avail_condition, training_enabled)
def get_action(self, observation, episode_id=None):
return episode_id
def get_action(self, episode_id, observation):
"""Record an observation and get the on-policy action.
Arguments:
episode_id (str): Episode id returned from start_episode().
observation (obj): Current environment observation.
episode_id (str): Episode id passed to start_episode() or None.
Returns:
action (obj): Action from the env action space.
@@ -114,19 +110,19 @@
episode = self._get(episode_id)
return episode.wait_for_action(observation)
def log_action(self, observation, action, episode_id=None):
def log_action(self, episode_id, observation, action):
"""Record an observation and (off-policy) action taken.
Arguments:
episode_id (str): Episode id returned from start_episode().
observation (obj): Current environment observation.
action (obj): Action for the observation.
episode_id (str): Episode id passed to start_episode() or None.
"""
episode = self._get(episode_id)
episode.log_action(observation, action)
def log_returns(self, reward, info=None, episode_id=None):
def log_returns(self, episode_id, reward, info=None):
"""Record returns from the environment.
The reward will be attributed to the previous action taken by the
@@ -134,34 +130,31 @@
logged before the next action, a reward of 0.0 is assumed.
Arguments:
episode_id (str): Episode id passed to start_episode() or None.
episode_id (str): Episode id returned from start_episode().
reward (float): Reward from the environment.
info (dict): Optional info dict.
"""
episode = self._get(episode_id)
episode.cur_reward += reward
if info:
episode.cur_info = info
episode.cur_info = info or {}
def end_episode(self, observation, episode_id=None):
def end_episode(self, episode_id, observation):
"""Record the end of an episode.
Arguments:
episode_id (str): Episode id passed by start_episode() or None.
episode_id (str): Episode id returned from start_episode().
observation (obj): Current environment observation.
"""
episode = self._get(episode_id)
self._finished.add(episode.episode_id)
self._cur_default_episode_id = None
episode.done(observation)
def _get(self, episode_id=None):
def _get(self, episode_id):
"""Get a started episode or raise an error."""
if episode_id is None:
episode_id = self._cur_default_episode_id
if episode_id in self._finished:
raise ValueError(
"Episode {} has already completed.".format(episode_id))
@@ -217,9 +210,10 @@ class _ServingEnvToAsync(AsyncVectorEnv):
class _Episode(object):
"""Tracked state for each active episode."""
def __init__(self, episode_id, results_avail_condition):
def __init__(self, episode_id, results_avail_condition, training_enabled):
self.episode_id = episode_id
self.results_avail_condition = results_avail_condition
self.training_enabled = training_enabled
self.data_queue = queue.Queue()
self.action_queue = queue.Queue()
self.new_observation = None
@@ -258,6 +252,8 @@
}
if self.new_action is not None:
item["off_policy_action"] = self.new_action
if not self.training_enabled:
item["info"]["training_enabled"] = False
self.new_observation = None
self.new_action = None
self.cur_reward = 0.0

View file

@@ -253,7 +253,7 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
--smoke-test
docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/examples/multiagent_mountaincar.py
python /ray/python/ray/rllib/examples/legacy_multiagent/multiagent_mountaincar.py
docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/examples/multiagent_pendulum.py
python /ray/python/ray/rllib/examples/legacy_multiagent/multiagent_pendulum.py