from six.moves import queue
import gym
import threading
import uuid
from typing import Optional

from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.typing import EnvActionType, EnvObsType, EnvInfoDict


@PublicAPI
class ExternalEnv(threading.Thread):
    """An environment that interfaces with external agents.

    Unlike simulator envs, control is inverted: The environment queries the
    policy to obtain actions and in return logs observations and rewards for
    training. This is in contrast to gym.Env, where the algorithm drives the
    simulation through env.step() calls.

    You can use ExternalEnv as the backend for policy serving (by serving HTTP
    requests in the run loop), for ingesting offline logs data (by reading
    offline transitions in the run loop), or other custom use cases not easily
    expressed through gym.Env.

    ExternalEnv supports both on-policy actions (through self.get_action()),
    and off-policy actions (through self.log_action()).

    This env is thread-safe, but individual episodes must be executed
    serially.

    Examples:
        >>> register_env("my_env", lambda config: YourExternalEnv(config))
        >>> trainer = DQNTrainer(env="my_env")
        >>> while True:
        >>>     print(trainer.train())
    """

    @PublicAPI
    def __init__(self,
                 action_space: gym.Space,
                 observation_space: gym.Space,
                 max_concurrent: int = 100):
        """Initializes an ExternalEnv instance.

        Args:
            action_space: Action space of the env.
            observation_space: Observation space of the env.
            max_concurrent: Max number of active episodes to allow at once.
                Exceeding this limit raises an error.
        """
        threading.Thread.__init__(self)

        self.daemon = True
        self.action_space = action_space
        self.observation_space = observation_space
        self._episodes = {}
        self._finished = set()
        self._results_avail_condition = threading.Condition()
        self._max_concurrent_episodes = max_concurrent

    @PublicAPI
    def run(self):
        """Override this to implement the run loop.

        Your loop should continuously:
            1. Call self.start_episode(episode_id)
            2. Call self.get_action(episode_id, obs)
                    -or-
                    self.log_action(episode_id, obs, action)
            3. Call self.log_returns(episode_id, reward)
            4. Call self.end_episode(episode_id, obs)
            5. Wait if nothing to do.

        Multiple episodes may be started at the same time.
        """
        raise NotImplementedError

    @PublicAPI
    def start_episode(self,
                      episode_id: Optional[str] = None,
                      training_enabled: bool = True) -> str:
        """Record the start of an episode.

        Args:
            episode_id: Unique string id for the episode or None for it to be
                auto-assigned and returned.
            training_enabled: Whether to use experiences for this episode to
                improve the policy.

        Returns:
            Unique string id for the episode.
        """
        if episode_id is None:
            episode_id = uuid.uuid4().hex

        if episode_id in self._finished:
            raise ValueError(
                "Episode {} has already completed.".format(episode_id))

        if episode_id in self._episodes:
            raise ValueError(
                "Episode {} is already started".format(episode_id))

        self._episodes[episode_id] = _ExternalEnvEpisode(
            episode_id, self._results_avail_condition, training_enabled)

        return episode_id

    @PublicAPI
    def get_action(self, episode_id: str,
                   observation: EnvObsType) -> EnvActionType:
        """Record an observation and get the on-policy action.

        Args:
            episode_id: Episode id returned from start_episode().
            observation: Current environment observation.

        Returns:
            Action from the env action space.
        """
        episode = self._get(episode_id)
        return episode.wait_for_action(observation)

    @PublicAPI
    def log_action(self, episode_id: str, observation: EnvObsType,
                   action: EnvActionType) -> None:
        """Record an observation and (off-policy) action taken.
        Args:
            episode_id: Episode id returned from start_episode().
            observation: Current environment observation.
            action: Action for the observation.
        """
        episode = self._get(episode_id)
        episode.log_action(observation, action)

    @PublicAPI
    def log_returns(self,
                    episode_id: str,
                    reward: float,
                    info: Optional[EnvInfoDict] = None) -> None:
        """Records returns (rewards and infos) from the environment.

        The reward will be attributed to the previous action taken by the
        episode. Rewards accumulate until the next action. If no reward is
        logged before the next action, a reward of 0.0 is assumed.

        Args:
            episode_id: Episode id returned from start_episode().
            reward: Reward from the environment.
            info: Optional info dict.
        """
        episode = self._get(episode_id)
        episode.cur_reward += reward

        if info:
            episode.cur_info = info or {}

    @PublicAPI
    def end_episode(self, episode_id: str, observation: EnvObsType) -> None:
        """Records the end of an episode.

        Args:
            episode_id: Episode id returned from start_episode().
            observation: Current environment observation.
        """
        episode = self._get(episode_id)
        self._finished.add(episode.episode_id)
        episode.done(observation)

    def _get(self, episode_id: str) -> "_ExternalEnvEpisode":
        """Get a started episode by its ID or raise an error."""
        if episode_id in self._finished:
            raise ValueError(
                "Episode {} has already completed.".format(episode_id))

        if episode_id not in self._episodes:
            raise ValueError("Episode {} not found.".format(episode_id))

        return self._episodes[episode_id]


class _ExternalEnvEpisode:
    """Tracked state for each active episode."""

    def __init__(self,
                 episode_id: str,
                 results_avail_condition: threading.Condition,
                 training_enabled: bool,
                 multiagent: bool = False):
        self.episode_id = episode_id
        self.results_avail_condition = results_avail_condition
        self.training_enabled = training_enabled
        self.multiagent = multiagent
        self.data_queue = queue.Queue()
        self.action_queue = queue.Queue()
        if multiagent:
            self.new_observation_dict = None
            self.new_action_dict = None
            self.cur_reward_dict = {}
            self.cur_done_dict = {"__all__": False}
            self.cur_info_dict = {}
        else:
            self.new_observation = None
            self.new_action = None
            self.cur_reward = 0.0
            self.cur_done = False
            self.cur_info = {}

    def get_data(self):
        # Return the next queued sample item without blocking, or None if
        # nothing is available yet.
        if self.data_queue.empty():
            return None
        return self.data_queue.get_nowait()

    def log_action(self, observation, action):
        if self.multiagent:
            self.new_observation_dict = observation
            self.new_action_dict = action
        else:
            self.new_observation = observation
            self.new_action = action
        self._send()
        # Block until the policy has processed this observation; the computed
        # on-policy action is discarded since an off-policy action was logged.
        self.action_queue.get(True, timeout=60.0)

    def wait_for_action(self, observation):
        if self.multiagent:
            self.new_observation_dict = observation
        else:
            self.new_observation = observation
        self._send()
        # Block until the policy puts an action on the queue.
        return self.action_queue.get(True, timeout=300.0)

    def done(self, observation):
        if self.multiagent:
            self.new_observation_dict = observation
            self.cur_done_dict = {"__all__": True}
        else:
            self.new_observation = observation
            self.cur_done = True
        self._send()

    def _send(self):
        # Package the pending observation, accumulated reward, done flag, and
        # info into a single item and hand it to the sampler via data_queue.
        if self.multiagent:
            if not self.training_enabled:
                for agent_id in self.cur_info_dict:
                    self.cur_info_dict[agent_id]["training_enabled"] = False
            item = {
                "obs": self.new_observation_dict,
                "reward": self.cur_reward_dict,
                "done": self.cur_done_dict,
                "info": self.cur_info_dict,
            }
            if self.new_action_dict is not None:
                item["off_policy_action"] = self.new_action_dict
            self.new_observation_dict = None
            self.new_action_dict = None
            self.cur_reward_dict = {}
        else:
            item = {
                "obs": self.new_observation,
                "reward": self.cur_reward,
                "done": self.cur_done,
                "info": self.cur_info,
            }
            if self.new_action is not None:
                item["off_policy_action"] = self.new_action
            self.new_observation = None
            self.new_action = None
            self.cur_reward = 0.0
            if not self.training_enabled:
                item["info"]["training_enabled"] = False

        with self.results_avail_condition:
            self.data_queue.put_nowait(item)
            self.results_avail_condition.notify()
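

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a minimal run-loop
# implementation following the protocol documented in `run()` above. The class
# name `_CartPoleServing` and the use of `gym.make("CartPole-v1")` are
# assumptions for illustration only, and the example assumes the classic gym
# API (reset() returns an observation, step() returns a 4-tuple).
# ---------------------------------------------------------------------------
class _CartPoleServing(ExternalEnv):
    """Hypothetical ExternalEnv subclass that drives episodes locally."""

    def __init__(self):
        env = gym.make("CartPole-v1")
        super().__init__(env.action_space, env.observation_space)
        self._env = env

    def run(self):
        while True:
            # 1. Start a new episode and get its auto-assigned id.
            episode_id = self.start_episode()
            obs = self._env.reset()
            done = False
            while not done:
                # 2. On-policy path: block until the policy computes an action
                #    for the logged observation.
                action = self.get_action(episode_id, obs)
                obs, reward, done, _ = self._env.step(action)
                # 3. Attribute the resulting reward to that action.
                self.log_returns(episode_id, reward)
            # 4. Record the terminal observation and close the episode.
            self.end_episode(episode_id, obs)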