2020-03-20 12:43:57 -07:00
|
|
|
"""REST client to interact with a policy server.
|
|
|
|
|
|
|
|
This client supports both local and remote policy inference modes. Local
|
|
|
|
inference is faster but causes more compute to be done on the client.
|
|
|
|
"""
|
|
|
|
|
|
|
|
import logging
|
|
|
|
import threading
|
|
|
|
import time
|
2020-06-19 13:09:05 -07:00
|
|
|
from typing import Union, Optional
|
2022-05-15 17:25:25 +02:00
|
|
|
from enum import Enum
|
2020-03-20 12:43:57 -07:00
|
|
|
|
|
|
|
import ray.cloudpickle as pickle
|
2021-06-23 09:09:01 +02:00
|
|
|
from ray.rllib.env.external_env import ExternalEnv
|
|
|
|
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
|
|
|
|
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
2020-05-30 22:48:34 +02:00
|
|
|
from ray.rllib.policy.sample_batch import MultiAgentBatch
|
2020-03-20 12:43:57 -07:00
|
|
|
from ray.rllib.utils.annotations import PublicAPI
|
2022-01-11 19:50:03 +01:00
|
|
|
from ray.rllib.utils.pre_checks.multi_agent import check_multi_agent
|
2020-08-15 13:24:22 +02:00
|
|
|
from ray.rllib.utils.typing import (
|
|
|
|
MultiAgentDict,
|
|
|
|
EnvInfoDict,
|
|
|
|
EnvObsType,
|
2020-06-19 13:09:05 -07:00
|
|
|
EnvActionType,
|
2022-01-29 18:41:57 -08:00
|
|
|
)
|
2020-03-20 12:43:57 -07:00
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)

# The HTTP library is a third-party package that only the client side needs.
# A missing install is reported here but is not fatal at import time.
try:
    import requests  # `requests` is not part of stdlib.
except ImportError:
    requests = None
    logger.warning(
        "Couldn't import `requests` library. Be sure to install it on the client side."
    )
|
|
|
|
|
|
|
|
|
|
|
|
@PublicAPI
class Commands(Enum):
    """Values used in the "command" field of requests to the policy server."""

    # --- Shared by 'local' and 'remote' inference modes ---
    ACTION_SPACE = "ACTION_SPACE"
    OBSERVATION_SPACE = "OBSERVATION_SPACE"

    # --- 'local' inference mode: the client runs its own rollout worker,
    # pulls weights from the server and pushes sampled batches back ---
    GET_WORKER_ARGS = "GET_WORKER_ARGS"
    GET_WEIGHTS = "GET_WEIGHTS"
    REPORT_SAMPLES = "REPORT_SAMPLES"

    # --- 'remote' inference mode: each episode event is a server request ---
    START_EPISODE = "START_EPISODE"
    GET_ACTION = "GET_ACTION"
    LOG_ACTION = "LOG_ACTION"
    LOG_RETURNS = "LOG_RETURNS"
    END_EPISODE = "END_EPISODE"
|
|
|
|
|
2022-05-15 17:25:25 +02:00
|
|
|
|
|
|
|
@PublicAPI
class PolicyClient:
    """REST client to interact with an RLlib policy server."""

    @PublicAPI
    def __init__(
        self,
        address: str,
        inference_mode: str = "local",
        update_interval: Optional[float] = 10.0,
    ):
        """Create a PolicyClient instance.

        Args:
            address: Server URL to connect to, including the scheme
                (e.g., "http://localhost:9090"); it is passed verbatim to
                `requests.post`.
            inference_mode: Whether to use 'local' or 'remote' policy
                inference for computing actions.
            update_interval: If using 'local' inference mode, the policy is
                refreshed after this many seconds have passed, or None for
                manual control via `update_policy_weights()`.

        Raises:
            ValueError: If `inference_mode` is neither 'local' nor 'remote'.
        """
        self.address = address
        # Populated by `_setup_local_rollout_worker()` in local mode only;
        # stays None in remote mode.
        self.env: Optional[ExternalEnv] = None
        if inference_mode == "local":
            self.local = True
            self._setup_local_rollout_worker(update_interval)
        elif inference_mode == "remote":
            self.local = False
        else:
            raise ValueError("inference_mode must be either 'local' or 'remote'")

    @PublicAPI
    def start_episode(
        self, episode_id: Optional[str] = None, training_enabled: bool = True
    ) -> str:
        """Record the start of one or more episode(s).

        Args:
            episode_id: Unique string id for the episode or None for it to
                be auto-assigned.
            training_enabled: Whether to use experiences for this
                episode to improve the policy.

        Returns:
            episode_id: Unique string id for the episode.
        """
        if self.local:
            self._update_local_policy()
            return self.env.start_episode(episode_id, training_enabled)

        return self._send(
            {
                "episode_id": episode_id,
                "command": Commands.START_EPISODE,
                "training_enabled": training_enabled,
            }
        )["episode_id"]

    @PublicAPI
    def get_action(
        self, episode_id: str, observation: Union[EnvObsType, MultiAgentDict]
    ) -> Union[EnvActionType, MultiAgentDict]:
        """Record an observation and get the on-policy action.

        Args:
            episode_id: Episode id returned from start_episode(). In local
                mode a list/tuple of ids is also accepted, in which case
                `observation` must be indexable by each of those ids.
            observation: Current environment observation.

        Returns:
            action: Action from the env action space (in local mode with a
                list/tuple of ids: a dict mapping each id to its action).
        """
        if self.local:
            self._update_local_policy()
            if isinstance(episode_id, (list, tuple)):
                actions = {
                    eid: self.env.get_action(eid, observation[eid])
                    for eid in episode_id
                }
                return actions
            else:
                return self.env.get_action(episode_id, observation)
        else:
            return self._send(
                {
                    "command": Commands.GET_ACTION,
                    "observation": observation,
                    "episode_id": episode_id,
                }
            )["action"]

    @PublicAPI
    def log_action(
        self,
        episode_id: str,
        observation: Union[EnvObsType, MultiAgentDict],
        action: Union[EnvActionType, MultiAgentDict],
    ) -> None:
        """Record an observation and (off-policy) action taken.

        Args:
            episode_id: Episode id returned from start_episode().
            observation: Current environment observation.
            action: Action for the observation.
        """
        if self.local:
            self._update_local_policy()
            return self.env.log_action(episode_id, observation, action)

        self._send(
            {
                "command": Commands.LOG_ACTION,
                "observation": observation,
                "action": action,
                "episode_id": episode_id,
            }
        )

    @PublicAPI
    def log_returns(
        self,
        episode_id: str,
        reward: float,
        info: Union[EnvInfoDict, MultiAgentDict] = None,
        multiagent_done_dict: Optional[MultiAgentDict] = None,
    ) -> None:
        """Record returns from the environment.

        The reward will be attributed to the previous action taken by the
        episode. Rewards accumulate until the next action. If no reward is
        logged before the next action, a reward of 0.0 is assumed.

        Args:
            episode_id: Episode id returned from start_episode().
            reward: Reward from the environment (a dict in the multi-agent
                case).
            info: Extra info dict.
            multiagent_done_dict: Multi-agent done information. In local
                mode, passing this requires `reward` to be a dict.
        """
        if self.local:
            self._update_local_policy()
            if multiagent_done_dict is not None:
                assert isinstance(reward, dict)
                return self.env.log_returns(
                    episode_id, reward, info, multiagent_done_dict
                )
            return self.env.log_returns(episode_id, reward, info)

        self._send(
            {
                "command": Commands.LOG_RETURNS,
                "reward": reward,
                "info": info,
                "episode_id": episode_id,
                "done": multiagent_done_dict,
            }
        )

    @PublicAPI
    def end_episode(
        self, episode_id: str, observation: Union[EnvObsType, MultiAgentDict]
    ) -> None:
        """Record the end of an episode.

        Args:
            episode_id: Episode id returned from start_episode().
            observation: Current environment observation.
        """
        if self.local:
            self._update_local_policy()
            return self.env.end_episode(episode_id, observation)

        self._send(
            {
                "command": Commands.END_EPISODE,
                "observation": observation,
                "episode_id": episode_id,
            }
        )

    @PublicAPI
    def update_policy_weights(self) -> None:
        """Query the server for new policy weights, if local inference is enabled.

        In 'remote' inference mode the policy lives on the server, so there
        is nothing to refresh and this is a no-op. (Remote mode never creates
        `self.inference_thread`, so calling `_update_local_policy()` there
        would raise AttributeError.)
        """
        if self.local:
            self._update_local_policy(force=True)

    def _send(self, data):
        """Pickle `data`, POST it to the server and return the unpickled reply.

        Raises:
            ImportError: If the `requests` library is not installed.
            requests.HTTPError: If the server answers with an HTTP error
                status (via `raise_for_status()`).
        """
        if requests is None:
            raise ImportError(
                "The `requests` library is required to talk to the policy "
                "server. Be sure to install it on the client side."
            )
        payload = pickle.dumps(data)
        response = requests.post(self.address, data=payload)
        if response.status_code != 200:
            logger.error("Request failed {}: {}".format(response.text, data))
        response.raise_for_status()
        # SECURITY NOTE: unpickling can execute arbitrary code; only connect
        # this client to policy servers you trust.
        parsed = pickle.loads(response.content)
        return parsed

    def _setup_local_rollout_worker(self, update_interval):
        """Build the embedded rollout worker and sampling thread (local mode).

        Fetches the rollout worker constructor args from the server, creates
        a local worker plus a background thread that reports samples back,
        and exposes the worker's env as `self.env`.
        """
        self.update_interval = update_interval
        # 0 guarantees the first `_update_local_policy()` call (with a
        # positive interval) fetches weights.
        self.last_updated = 0

        logger.info("Querying server for rollout worker settings.")
        kwargs = self._send(
            {
                "command": Commands.GET_WORKER_ARGS,
            }
        )["worker_args"]
        (self.rollout_worker, self.inference_thread) = _create_embedded_rollout_worker(
            kwargs, self._send
        )
        self.env = self.rollout_worker.env

    def _update_local_policy(self, force=False):
        """Pull fresh weights and global vars from the server when due.

        A refresh happens if `force` is True, or if `update_interval` is
        truthy and more than that many seconds have passed since the last
        refresh. A falsy interval (None or 0) disables time-based refreshes.
        """
        # The sampling thread must still be running; if it died, its error
        # was logged by the thread itself.
        assert self.inference_thread.is_alive()
        if (
            self.update_interval
            and time.time() - self.last_updated > self.update_interval
        ) or force:
            logger.info("Querying server for new policy weights.")
            resp = self._send(
                {
                    "command": Commands.GET_WEIGHTS,
                }
            )
            weights = resp["weights"]
            global_vars = resp["global_vars"]
            logger.info(
                "Updating rollout worker weights and global vars {}.".format(
                    global_vars
                )
            )
            self.rollout_worker.set_weights(weights, global_vars)
            self.last_updated = time.time()
|
|
|
|
|
|
|
|
|
|
|
class _LocalInferenceThread(threading.Thread):
    """Thread that handles experience generation (worker.sample() loop).

    Repeatedly samples a batch from the embedded rollout worker and sends it,
    together with the worker's metrics, to the server via `send_fn`.
    """

    def __init__(self, rollout_worker, send_fn):
        """Create the (not yet started) sampling thread.

        Args:
            rollout_worker: The local rollout worker to call `sample()` and
                `get_metrics()` on.
            send_fn: Callable taking a request dict; used to report each
                batch to the server.
        """
        super().__init__()
        # Daemon thread, so the endless sampling loop does not keep the
        # process alive on exit.
        self.daemon = True
        self.rollout_worker = rollout_worker
        self.send_fn = send_fn

    def run(self):
        try:
            while True:
                logger.info("Generating new batch of experiences.")
                samples = self.rollout_worker.sample()
                metrics = self.rollout_worker.get_metrics()
                if isinstance(samples, MultiAgentBatch):
                    logger.info(
                        "Sending batch of {} env steps ({} agent steps) to "
                        "server.".format(samples.env_steps(), samples.agent_steps())
                    )
                else:
                    logger.info(
                        "Sending batch of {} steps back to server.".format(
                            samples.count
                        )
                    )
                self.send_fn(
                    {
                        "command": Commands.REPORT_SAMPLES,
                        "samples": samples,
                        "metrics": metrics,
                    }
                )
        except Exception:
            # `logger.exception` logs at ERROR level and attaches the
            # traceback. Passing the exception as an extra positional arg to
            # a message without a placeholder (as before) breaks %-formatting
            # inside logging and loses the actual error.
            logger.exception("Error: inference worker thread died!")
|
|
|
|
|
|
|
|
|
2020-06-19 13:09:05 -07:00
|
|
|
def _auto_wrap_external(real_env_creator):
    """Wrap an environment in the ExternalEnv interface if needed.

    Args:
        real_env_creator: Create an env given the env_config.

    Returns:
        A creator with the same call signature. Envs that already are
        ExternalEnv / ExternalMultiAgentEnv instances are returned unchanged.
        Anything else is wrapped in ExternalMultiAgentEnv (for MultiAgentEnv
        instances) or ExternalEnv, using the real env's observation and
        action spaces.
    """

    def wrapped_creator(env_config):
        real_env = real_env_creator(env_config)
        if not isinstance(real_env, (ExternalEnv, ExternalMultiAgentEnv)):
            logger.info(
                "The env you specified is not a supported (sub-)type of "
                "ExternalEnv. Attempting to convert it automatically to "
                "ExternalEnv."
            )

            if isinstance(real_env, MultiAgentEnv):
                external_cls = ExternalMultiAgentEnv
            else:
                external_cls = ExternalEnv

            class _ExternalEnvWrapper(external_cls):
                def __init__(self, real_env):
                    super().__init__(
                        observation_space=real_env.observation_space,
                        action_space=real_env.action_space,
                    )

                def run(self):
                    # Since we are calling methods on this class in the
                    # client, run doesn't need to do anything except stay
                    # alive. A single long sleep would eventually return and
                    # end this run() (NOTE(review): presumably the ExternalEnv
                    # machinery treats a finished serving thread as dead;
                    # confirm), so idle forever instead.
                    while True:
                        time.sleep(999999)

            return _ExternalEnvWrapper(real_env)
        return real_env

    return wrapped_creator
|
|
|
|
|
|
|
|
|
2020-06-19 13:09:05 -07:00
|
|
|
def _create_embedded_rollout_worker(kwargs, send_fn):
    """Create a local rollout worker and a thread that samples from it.

    Args:
        kwargs: args for the RolloutWorker constructor. Not mutated; a copy
            is adjusted before use.
        send_fn: function to send a (pickled) request dict to the server.

    Returns:
        Tuple of (rollout_worker, inference_thread). The thread has already
        been started.
    """
    # Work on copies: the top-level dict is edited below, and so is the
    # nested policy config (`kwargs.copy()` alone is shallow and would leak
    # the "env" change back into the caller's dict).
    kwargs = kwargs.copy()
    kwargs["policy_config"] = kwargs["policy_config"].copy()

    # Since the server acts as an input datasource, we have to reset the
    # input config to the default, which runs env rollouts.
    kwargs.pop("input_creator", None)

    # Since the server also acts as an output writer, we might have to reset
    # the output config to the default, i.e. "output": None, otherwise a
    # local rollout worker might write to an unknown output directory
    kwargs.pop("output_creator", None)

    # If server has no env (which is the expected case):
    # Generate a dummy ExternalEnv here using RandomEnv and the
    # given observation/action spaces.
    if kwargs["policy_config"].get("env") is None:
        from ray.rllib.examples.env.random_env import RandomEnv, RandomMultiAgentEnv

        config = {
            "action_space": kwargs["policy_config"]["action_space"],
            "observation_space": kwargs["policy_config"]["observation_space"],
        }
        _, is_ma = check_multi_agent(kwargs["policy_config"])
        kwargs["env_creator"] = _auto_wrap_external(
            lambda _: (RandomMultiAgentEnv if is_ma else RandomEnv)(config)
        )
        kwargs["policy_config"]["env"] = True
    # Otherwise, use the env specified by the server args.
    else:
        real_env_creator = kwargs["env_creator"]
        kwargs["env_creator"] = _auto_wrap_external(real_env_creator)

    logger.info("Creating rollout worker with kwargs={}".format(kwargs))
    from ray.rllib.evaluation.rollout_worker import RolloutWorker

    rollout_worker = RolloutWorker(**kwargs)

    inference_thread = _LocalInferenceThread(rollout_worker, send_fn)
    inference_thread.start()

    return rollout_worker, inference_thread
|