2020-03-20 12:43:57 -07:00
|
|
|
"""DEPRECATED: Please use rllib.env.PolicyClient instead."""
|
|
|
|
|
2018-10-21 23:43:57 -07:00
|
|
|
import logging
|
2018-06-20 13:22:39 -07:00
|
|
|
import pickle
|
|
|
|
|
2019-01-23 21:27:26 -08:00
|
|
|
from ray.rllib.utils.annotations import PublicAPI
|
2020-03-20 12:43:57 -07:00
|
|
|
from ray.rllib.utils.deprecation import deprecation_warning
|
2019-01-23 21:27:26 -08:00
|
|
|
|
2018-10-21 23:43:57 -07:00
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# `requests` is a third-party package (not part of the stdlib). Importing it
# is best-effort so that this module can still be imported without it; the
# name is bound to None in that case and a warning is logged.
try:
    import requests
except ImportError:
    requests = None
    logger.warning(
        "Couldn't import `requests` library. Be sure to install it on"
        " the client side.")
|
2018-06-20 13:22:39 -07:00
|
|
|
|
|
|
|
|
2019-01-23 21:27:26 -08:00
|
|
|
@PublicAPI
class PolicyClient:
    """DEPRECATED: Please use rllib.env.PolicyClient instead.

    Client that talks to a policy server over HTTP. Every public method
    builds a command dict, pickles it, POSTs it to the configured address
    and (where applicable) returns a field of the unpickled response.
    """

    # Command identifiers sent under the "command" key of each request.
    START_EPISODE = "START_EPISODE"
    GET_ACTION = "GET_ACTION"
    LOG_ACTION = "LOG_ACTION"
    LOG_RETURNS = "LOG_RETURNS"
    END_EPISODE = "END_EPISODE"

    @PublicAPI
    def __init__(self, address):
        """Create a client bound to a server address.

        Emits a deprecation warning on construction.

        Arguments:
            address: Target passed as the URL to ``requests.post`` for
                every request.
        """
        # The warning must name this class and its replacement (it used to
        # name the PolicyServer classes by mistake).
        deprecation_warning(
            "rllib.utils.PolicyClient", new="rllib.env.PolicyClient")
        self._address = address

    @PublicAPI
    def start_episode(self, episode_id=None, training_enabled=True):
        """Record the start of an episode.

        Arguments:
            episode_id (Optional[str]): Unique string id for the episode or
                None for it to be auto-assigned.
            training_enabled (bool): Whether to use experiences for this
                episode to improve the policy.

        Returns:
            episode_id (str): Unique string id for the episode.
        """

        return self._send({
            "episode_id": episode_id,
            "command": PolicyClient.START_EPISODE,
            "training_enabled": training_enabled,
        })["episode_id"]

    @PublicAPI
    def get_action(self, episode_id, observation):
        """Record an observation and get the on-policy action.

        Arguments:
            episode_id (str): Episode id returned from start_episode().
            observation (obj): Current environment observation.

        Returns:
            action (obj): Action from the env action space.
        """
        return self._send({
            "command": PolicyClient.GET_ACTION,
            "observation": observation,
            "episode_id": episode_id,
        })["action"]

    @PublicAPI
    def log_action(self, episode_id, observation, action):
        """Record an observation and (off-policy) action taken.

        Arguments:
            episode_id (str): Episode id returned from start_episode().
            observation (obj): Current environment observation.
            action (obj): Action for the observation.
        """
        self._send({
            "command": PolicyClient.LOG_ACTION,
            "observation": observation,
            "action": action,
            "episode_id": episode_id,
        })

    @PublicAPI
    def log_returns(self, episode_id, reward, info=None):
        """Record returns from the environment.

        The reward will be attributed to the previous action taken by the
        episode. Rewards accumulate until the next action. If no reward is
        logged before the next action, a reward of 0.0 is assumed.

        Arguments:
            episode_id (str): Episode id returned from start_episode().
            reward (float): Reward from the environment.
            info (Optional[obj]): Extra info sent to the server under the
                "info" key alongside the reward. Defaults to None.
        """
        self._send({
            "command": PolicyClient.LOG_RETURNS,
            "reward": reward,
            "info": info,
            "episode_id": episode_id,
        })

    @PublicAPI
    def end_episode(self, episode_id, observation):
        """Record the end of an episode.

        Arguments:
            episode_id (str): Episode id returned from start_episode().
            observation (obj): Current environment observation.
        """
        self._send({
            "command": PolicyClient.END_EPISODE,
            "observation": observation,
            "episode_id": episode_id,
        })

    def _send(self, data):
        """Pickle ``data``, POST it to the server and unpickle the reply.

        Arguments:
            data (dict): Command payload; must be picklable.

        Returns:
            The object obtained by unpickling the response body.

        Raises:
            ImportError: If the `requests` library is not installed.
            requests.HTTPError: If ``response.raise_for_status()`` rejects
                the response status.
        """
        # The module-level import of `requests` is best-effort and leaves the
        # name as None on failure; fail with an actionable message instead of
        # an AttributeError on None.
        if requests is None:
            raise ImportError(
                "Couldn't import `requests` library. Be sure to install it on"
                " the client side.")

        payload = pickle.dumps(data)
        response = requests.post(self._address, data=payload)
        if response.status_code != 200:
            # Log the server's reply together with the offending request
            # before raise_for_status() possibly raises.
            logger.error("Request failed %s: %s", response.text, data)
        response.raise_for_status()
        # NOTE(security): pickle.loads can execute arbitrary code while
        # deserializing. Only point this client at a trusted server.
        parsed = pickle.loads(response.content)
        return parsed
|