ray/rllib/utils/policy_client.py

import logging
import pickle

from ray.rllib.utils.annotations import PublicAPI

logger = logging.getLogger(__name__)

try:
    import requests  # `requests` is not part of stdlib.
except ImportError:
    requests = None
    logger.warning(
        "Couldn't import `requests` library. Be sure to install it on"
        " the client side.")


@PublicAPI
class PolicyClient:
    """REST client to interact with an RLlib policy server."""

    START_EPISODE = "START_EPISODE"
    GET_ACTION = "GET_ACTION"
    LOG_ACTION = "LOG_ACTION"
    LOG_RETURNS = "LOG_RETURNS"
    END_EPISODE = "END_EPISODE"

    @PublicAPI
    def __init__(self, address):
        self._address = address

    @PublicAPI
    def start_episode(self, episode_id=None, training_enabled=True):
        """Record the start of an episode.

        Arguments:
            episode_id (str): Unique string id for the episode or None for
                it to be auto-assigned.
            training_enabled (bool): Whether to use experiences for this
                episode to improve the policy.

        Returns:
            episode_id (str): Unique string id for the episode.
        """
        return self._send({
            "episode_id": episode_id,
            "command": PolicyClient.START_EPISODE,
            "training_enabled": training_enabled,
        })["episode_id"]

    @PublicAPI
    def get_action(self, episode_id, observation):
        """Record an observation and get the on-policy action.

        Arguments:
            episode_id (str): Episode id returned from start_episode().
            observation (obj): Current environment observation.

        Returns:
            action (obj): Action from the env action space.
        """
        return self._send({
            "command": PolicyClient.GET_ACTION,
            "observation": observation,
            "episode_id": episode_id,
        })["action"]

    @PublicAPI
    def log_action(self, episode_id, observation, action):
        """Record an observation and (off-policy) action taken.

        Arguments:
            episode_id (str): Episode id returned from start_episode().
            observation (obj): Current environment observation.
            action (obj): Action for the observation.
        """
        self._send({
            "command": PolicyClient.LOG_ACTION,
            "observation": observation,
            "action": action,
            "episode_id": episode_id,
        })

    @PublicAPI
    def log_returns(self, episode_id, reward, info=None):
        """Record returns from the environment.

        The reward will be attributed to the previous action taken by the
        episode. Rewards accumulate until the next action. If no reward is
        logged before the next action, a reward of 0.0 is assumed.

        Arguments:
            episode_id (str): Episode id returned from start_episode().
            reward (float): Reward from the environment.
            info (dict): Optional extra info dict to log along with the
                reward.
        """
        self._send({
            "command": PolicyClient.LOG_RETURNS,
            "reward": reward,
            "info": info,
            "episode_id": episode_id,
        })

    @PublicAPI
    def end_episode(self, episode_id, observation):
        """Record the end of an episode.

        Arguments:
            episode_id (str): Episode id returned from start_episode().
            observation (obj): Current environment observation.
        """
        self._send({
            "command": PolicyClient.END_EPISODE,
            "observation": observation,
            "episode_id": episode_id,
        })

    def _send(self, data):
        # Commands and their arguments are pickled and POSTed to the policy
        # server; the server's reply body is expected to be a pickled dict.
        payload = pickle.dumps(data)
        response = requests.post(self._address, data=payload)
        if response.status_code != 200:
            logger.error("Request failed {}: {}".format(response.text, data))
        response.raise_for_status()
        parsed = pickle.loads(response.content)
        return parsed
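

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original module). It shows
# the intended episode lifecycle: start_episode(), then alternating
# get_action()/log_returns() calls, then end_episode(). The server address,
# episode length, and the random dummy observations are placeholders, not
# RLlib defaults; a real client would point at its own policy server and feed
# genuine environment observations and rewards.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import random

    client = PolicyClient("http://localhost:9900")  # assumed server address

    episode_id = client.start_episode(training_enabled=True)
    observation = [random.random() for _ in range(4)]  # dummy 4-dim obs
    for _ in range(10):
        action = client.get_action(episode_id, observation)
        # A real environment would be stepped with `action` here; instead we
        # fabricate the next observation and log a constant reward.
        observation = [random.random() for _ in range(4)]
        client.log_returns(episode_id, reward=1.0)
    client.end_episode(episode_id, observation)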