# ray/rllib/utils/policy_client.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import pickle

from ray.rllib.utils.annotations import PublicAPI

logger = logging.getLogger(__name__)

try:
    import requests  # `requests` is not part of stdlib.
except ImportError:
    requests = None
    logger.warning(
        "Couldn't import `requests` library. Be sure to install it on"
        " the client side.")


@PublicAPI
class PolicyClient(object):
"""REST client to interact with a RLlib policy server."""
START_EPISODE = "START_EPISODE"
GET_ACTION = "GET_ACTION"
LOG_ACTION = "LOG_ACTION"
LOG_RETURNS = "LOG_RETURNS"
END_EPISODE = "END_EPISODE"

    @PublicAPI
    def __init__(self, address):
        self._address = address

    @PublicAPI
    def start_episode(self, episode_id=None, training_enabled=True):
        """Record the start of an episode.

        Arguments:
            episode_id (str): Unique string id for the episode, or None for
                it to be auto-assigned.
            training_enabled (bool): Whether to use experiences for this
                episode to improve the policy.

        Returns:
            episode_id (str): Unique string id for the episode.
        """
        return self._send({
            "episode_id": episode_id,
            "command": PolicyClient.START_EPISODE,
            "training_enabled": training_enabled,
        })["episode_id"]

    @PublicAPI
    def get_action(self, episode_id, observation):
        """Record an observation and get the on-policy action.

        Arguments:
            episode_id (str): Episode id returned from start_episode().
            observation (obj): Current environment observation.

        Returns:
            action (obj): Action from the env action space.
        """
        return self._send({
            "command": PolicyClient.GET_ACTION,
            "observation": observation,
            "episode_id": episode_id,
        })["action"]

    @PublicAPI
    def log_action(self, episode_id, observation, action):
        """Record an observation and the (off-policy) action taken.

        Arguments:
            episode_id (str): Episode id returned from start_episode().
            observation (obj): Current environment observation.
            action (obj): Action for the observation.
        """
        self._send({
            "command": PolicyClient.LOG_ACTION,
            "observation": observation,
            "action": action,
            "episode_id": episode_id,
        })

    @PublicAPI
    def log_returns(self, episode_id, reward, info=None):
        """Record returns from the environment.

        The reward will be attributed to the previous action taken in the
        episode. Rewards accumulate until the next action. If no reward is
        logged before the next action, a reward of 0.0 is assumed.

        Arguments:
            episode_id (str): Episode id returned from start_episode().
            reward (float): Reward from the environment.
            info (dict): Optional info dict to attach to the reward.
        """
        self._send({
            "command": PolicyClient.LOG_RETURNS,
            "reward": reward,
            "info": info,
            "episode_id": episode_id,
        })

    @PublicAPI
    def end_episode(self, episode_id, observation):
        """Record the end of an episode.

        Arguments:
            episode_id (str): Episode id returned from start_episode().
            observation (obj): Current environment observation.
        """
        self._send({
            "command": PolicyClient.END_EPISODE,
            "observation": observation,
            "episode_id": episode_id,
        })

    def _send(self, data):
        # Fail fast with a clear error if `requests` was never imported,
        # instead of raising an opaque AttributeError on `requests.post`.
        if requests is None:
            raise ImportError(
                "The `requests` library is required on the client side.")
        payload = pickle.dumps(data)
        response = requests.post(self._address, data=payload)
        if response.status_code != 200:
            logger.error("Request failed {}: {}".format(response.text, data))
        response.raise_for_status()
        parsed = pickle.loads(response.content)
        return parsed
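

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal client-side episode loop, assuming a policy server is listening
# at http://localhost:9900 and that `gym` with `CartPole-v0` is installed.
# The address, the env name, and the pre-0.26 gym step API
# (obs, reward, done, info) are all assumptions for illustration.
if __name__ == "__main__":
    import gym

    env = gym.make("CartPole-v0")
    client = PolicyClient("http://localhost:9900")

    # Start a new episode; passing episode_id=None lets the server assign one.
    episode_id = client.start_episode(episode_id=None, training_enabled=True)
    obs = env.reset()
    done = False
    while not done:
        # Ask the server for an on-policy action for the current observation.
        action = client.get_action(episode_id, obs)
        obs, reward, done, info = env.step(action)
        # Attribute the reward to the action just taken; rewards accumulate
        # until the next get_action()/log_action() call.
        client.log_returns(episode_id, reward, info=info)
        # Off-policy variant: if actions come from an external controller,
        # skip get_action() and record the chosen action instead, e.g.:
        #   client.log_action(episode_id, obs, external_action)
        # where `external_action` is a hypothetical externally computed action.
    # Report the terminal observation so the server can close the episode.
    client.end_episode(episode_id, obs)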