import logging
import time
from collections import defaultdict, namedtuple
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Union

import numpy as np
import tree  # pip install dm_tree

from ray.rllib.env.base_env import ASYNC_RESET_RETURN, BaseEnv
from ray.rllib.env.wrappers.atari_wrappers import MonitorEnv, get_wrapper_by_cls
from ray.rllib.evaluation.collectors.simple_list_collector import _PolicyCollectorGroup
from ray.rllib.evaluation.episode_v2 import EpisodeV2
from ray.rllib.evaluation.metrics import RolloutMetrics
from ray.rllib.models.preprocessors import Preprocessor
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch, concat_samples
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.filter import Filter
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.spaces.space_utils import unbatch
from ray.rllib.utils.typing import (
    ActionConnectorDataType,
    AgentConnectorDataType,
    AgentID,
    EnvActionType,
    EnvID,
    EnvObsType,
    MultiAgentDict,
    MultiEnvDict,
    PolicyID,
    PolicyOutputType,
    SampleBatchType,
    StateBatches,
    TensorStructType,
)
from ray.util.debug import log_once

if TYPE_CHECKING:
    from gym.envs.classic_control.rendering import SimpleImageViewer

    from ray.rllib.algorithms.callbacks import DefaultCallbacks
    from ray.rllib.evaluation.rollout_worker import RolloutWorker

logger = logging.getLogger(__name__)


MIN_LARGE_BATCH_THRESHOLD = 1000
DEFAULT_LARGE_BATCH_THRESHOLD = 5000
# Multiplier used to convert seconds into milliseconds.
MS_TO_SEC = 1000.0


_PolicyEvalData = namedtuple("_PolicyEvalData", ["env_id", "agent_id", "sample_batch"])


class _PerfStats:
    """Sampler perf stats that will be included in rollout metrics."""

    def __init__(self, ema_coef: Optional[float] = None):
        # If not None, enable Exponential Moving Average mode.
        # The way we update stats is by:
        #     updated = (1 - ema_coef) * old + ema_coef * new
        # In general provides more responsive stats about sampler performance.
        # TODO(jungong) : make ema the default (only) mode if it works well.
        self.ema_coef = ema_coef

        self.iters = 0
        self.raw_obs_processing_time = 0.0
        self.inference_time = 0.0
        self.action_processing_time = 0.0
        self.env_wait_time = 0.0
        self.env_render_time = 0.0

    def incr(self, field: str, value: Union[int, float]):
        """Accumulates `value` (seconds, or an iteration count) into `field`.

        `iters` is always summed. All other fields are either summed (global
        average mode) or blended in as an exponential moving average (EMA mode).
        """
        if field == "iters":
            self.iters += value
            return

        # All the other fields support either global average or ema mode.
        if self.ema_coef is None:
            # Global average.
            self.__dict__[field] += value
        else:
            self.__dict__[field] = (1.0 - self.ema_coef) * self.__dict__[
                field
            ] + self.ema_coef * value

    def _get_avg(self):
        # Mean multiplicator (1000 = sec -> ms).
        # Guard against division by zero if stats are requested before the
        # first iteration; all accumulated times are still 0.0 at that point.
        factor = MS_TO_SEC / self.iters if self.iters else 0.0
        return {
            # Raw observation preprocessing.
            "mean_raw_obs_processing_ms": self.raw_obs_processing_time * factor,
            # Computing actions through policy.
            "mean_inference_ms": self.inference_time * factor,
            # Processing actions (to be sent to env, e.g. clipping).
            "mean_action_processing_ms": self.action_processing_time * factor,
            # Waiting for environment (during poll).
            "mean_env_wait_ms": self.env_wait_time * factor,
            # Environment rendering (False by default).
            "mean_env_render_ms": self.env_render_time * factor,
        }

    def _get_ema(self):
        # In EMA mode, stats are already (exponentially) averaged,
        # hence we only need to do the sec -> ms conversion here.
        return {
            # Raw observation preprocessing.
            "mean_raw_obs_processing_ms": self.raw_obs_processing_time * MS_TO_SEC,
            # Computing actions through policy.
            "mean_inference_ms": self.inference_time * MS_TO_SEC,
            # Processing actions (to be sent to env, e.g. clipping).
            "mean_action_processing_ms": self.action_processing_time * MS_TO_SEC,
            # Waiting for environment (during poll).
            "mean_env_wait_ms": self.env_wait_time * MS_TO_SEC,
            # Environment rendering (False by default).
            "mean_env_render_ms": self.env_render_time * MS_TO_SEC,
        }

    def get(self):
        """Returns the perf stats (in ms) according to the configured mode."""
        if self.ema_coef is None:
            return self._get_avg()
        else:
            return self._get_ema()


class _NewDefaultDict(defaultdict):
    """A defaultdict whose `default_factory` is called with the missing key."""

    def __missing__(self, env_id):
        ret = self[env_id] = self.default_factory(env_id)
        return ret


def _build_multi_agent_batch(
    episode_id: int,
    batch_builder: _PolicyCollectorGroup,
    large_batch_threshold: int,
    multiple_episodes_in_batch: bool,
) -> MultiAgentBatch:
    """Build MultiAgentBatch from a dict of _PolicyCollectors.

    Args:
        episode_id: ID of the episode the collected data belongs to (only used
            in the large-batch warning message).
        batch_builder: The _PolicyCollectorGroup holding the collected
            per-policy data.
        large_batch_threshold: If the builder holds more agent steps than this,
            a (one-time) warning is logged.
        multiple_episodes_in_batch: Whether batches may span several episodes
            (only used to tailor the warning message).

    Returns:
        Always returns a sample batch in MultiAgentBatch format.
    """
    ma_batch = {}
    for pid, collector in batch_builder.policy_collectors.items():
        if collector.agent_steps <= 0:
            continue

        if batch_builder.agent_steps > large_batch_threshold and log_once(
            "large_batch_warning"
        ):
            logger.warning(
                "More than {} observations in {} env steps for "
                "episode {} ".format(
                    batch_builder.agent_steps, batch_builder.env_steps, episode_id
                )
                + "are buffered in the sampler. If this is more than you "
                "expected, check that you set a horizon on your "
                "environment correctly and that it terminates at some "
                "point. Note: In multi-agent environments, "
                "`rollout_fragment_length` sets the batch size based on "
                "(across-agents) environment steps, not the steps of "
                "individual agents, which can result in unexpectedly "
                "large batches. "
                + (
                    "Also, you may be waiting for your Env to "
                    "terminate (batch_mode=`complete_episodes`). Make sure "
                    "it does at some point."
                    if not multiple_episodes_in_batch
                    else ""
                )
            )

        # Turn this policy's collected data into a SampleBatch.
        ma_batch[pid] = collector.build()

    # Create the multi agent batch.
    return MultiAgentBatch(policy_batches=ma_batch, env_steps=batch_builder.env_steps)


def _batch_inference_sample_batches(eval_data: List[SampleBatch]) -> SampleBatch:
    """Batch a list of input SampleBatches into a single SampleBatch.

    Args:
        eval_data: list of SampleBatches.

    Returns:
        single batched SampleBatch.
    """
    inference_batch = concat_samples(eval_data)
    if "state_in_0" in inference_batch:
        # Recurrent policy: every row is a single-timestep sequence.
        batch_size = len(eval_data)
        inference_batch[SampleBatch.SEQ_LENS] = np.ones(batch_size, dtype=np.int32)
    return inference_batch


@DeveloperAPI
class EnvRunnerV2:
    """Collect experiences from user environment using Connectors."""

    def __init__(
        self,
        worker: "RolloutWorker",
        base_env: BaseEnv,
        horizon: Optional[int],
        multiple_episodes_in_batch: bool,
        callbacks: "DefaultCallbacks",
        perf_stats: _PerfStats,
        soft_horizon: bool,
        no_done_at_end: bool,
        rollout_fragment_length: int = 200,
        count_steps_by: str = "env_steps",
        render: Optional[bool] = None,
    ):
        """
        Args:
            worker: Reference to the current rollout worker.
            base_env: Env implementing BaseEnv.
            horizon: Horizon of the episode.
            multiple_episodes_in_batch: Whether to pack multiple
                episodes into each batch. This guarantees batches will be exactly
                `rollout_fragment_length` in size.
            callbacks: User callbacks to run on episode events.
            perf_stats: Record perf stats into this object.
            soft_horizon: Calculate rewards but don't reset the
                environment when the horizon is hit.
            no_done_at_end: Ignore the done=True at the end of the episode
                and instead record done=False.
            rollout_fragment_length: The length of a fragment to collect
                before building a SampleBatch from the data and resetting
                the SampleBatchBuilder object.
            count_steps_by: One of "env_steps" (default) or "agent_steps".
                Use "agent_steps", if you want rollout lengths to be counted
                by individual agent steps. In a multi-agent env,
                a single env_step contains one or more agent_steps, depending
                on how many agents are present at any given time in the
                ongoing episode.
            render: Whether to try to render the environment after each
                step.
        """
        self._worker = worker
        self._base_env = base_env
        self._multiple_episodes_in_batch = multiple_episodes_in_batch
        self._callbacks = callbacks
        self._perf_stats = perf_stats
        self._soft_horizon = soft_horizon
        self._no_done_at_end = no_done_at_end
        self._rollout_fragment_length = rollout_fragment_length
        self._count_steps_by = count_steps_by
        self._render = render

        self._horizon = self._get_horizon(horizon)
        # May be populated for image rendering.
        self._simple_image_viewer: Optional[
            "SimpleImageViewer"
        ] = self._get_simple_image_viewer()

        # Keeps track of active episodes.
        self._active_episodes: Dict[EnvID, EpisodeV2] = _NewDefaultDict(
            self._new_episode
        )
        self._batch_builders: Dict[EnvID, _PolicyCollectorGroup] = _NewDefaultDict(
            self._new_batch_builder
        )

        self._large_batch_threshold: int = (
            max(MIN_LARGE_BATCH_THRESHOLD, self._rollout_fragment_length * 10)
            if self._rollout_fragment_length != float("inf")
            else DEFAULT_LARGE_BATCH_THRESHOLD
        )

    def _get_horizon(self, horizon: Optional[int]):
        """Try figuring out the proper horizon to use for rollout.

        Args:
            horizon: Horizon of the episode as configured by the user
                (None or 0 if not set).

        Returns:
            The user's `horizon` if given, otherwise the first sub-env's
            `spec.max_episode_steps` if available, otherwise float("inf").

        Raises:
            ValueError: If `horizon` exceeds the env's own step limit and that
                limit cannot be overridden.
        """
        # Try to get Env's `max_episode_steps` prop. If it doesn't exist, ignore
        # error and continue with max_episode_steps=None.
        max_episode_steps = None
        try:
            max_episode_steps = self._base_env.get_sub_environments()[
                0
            ].spec.max_episode_steps
        except Exception:
            pass

        # Trainer has a given `horizon` setting.
        if horizon:
            # `horizon` is larger than env's limit.
            if max_episode_steps and horizon > max_episode_steps:
                # Try to override the env's own max-step setting with our horizon.
                # If this won't work, throw an error.
                try:
                    self._base_env.get_sub_environments()[
                        0
                    ].spec.max_episode_steps = horizon
                    self._base_env.get_sub_environments()[
                        0
                    ]._max_episode_steps = horizon
                except Exception as e:
                    raise ValueError(
                        "Your `horizon` setting ({}) is larger than the Env's own "
                        "timestep limit ({}), which seems to be unsettable! Try "
                        "to increase the Env's built-in limit to be at least as "
                        "large as your wanted `horizon`.".format(
                            horizon, max_episode_steps
                        )
                    ) from e
        # Otherwise, set Trainer's horizon to env's max-steps.
        elif max_episode_steps:
            horizon = max_episode_steps
            logger.debug(
                "No episode horizon specified, setting it to Env's limit ({}).".format(
                    max_episode_steps
                )
            )
        # No horizon/max_episode_steps -> Episodes may be infinitely long.
        else:
            horizon = float("inf")
            logger.debug("No episode horizon specified, assuming inf.")

        return horizon

    def _get_simple_image_viewer(self):
        """Maybe construct a SimpleImageViewer instance for episode rendering."""
        # Try to render the env, if required.
        if not self._render:
            return None

        try:
            from gym.envs.classic_control.rendering import SimpleImageViewer

            return SimpleImageViewer()
        except (ImportError, ModuleNotFoundError):
            self._render = False  # disable rendering
            logger.warning(
                "Could not import gym.envs.classic_control."
                "rendering! Try `pip install gym[all]`."
            )

        return None

    def _new_episode(self, env_id) -> EpisodeV2:
        """Create a new episode."""
        episode = EpisodeV2(
            env_id,
            self._worker.policy_map,
            self._worker.policy_mapping_fn,
            worker=self._worker,
            callbacks=self._callbacks,
        )

        # Call each policy's Exploration.on_episode_start method.
        # Note: This may break the exploration (e.g. ParameterNoise) of
        # policies in the `policy_map` that have not been recently used
        # (and are therefore stashed to disk). However, we certainly do not
        # want to loop through all (even stashed) policies here as that
        # would counter the purpose of the LRU policy caching.
        for p in self._worker.policy_map.cache.values():
            if getattr(p, "exploration", None) is not None:
                p.exploration.on_episode_start(
                    policy=p,
                    environment=self._base_env,
                    episode=episode,
                    tf_sess=p.get_session(),
                )
        # Call on_episode_start callbacks.
        self._callbacks.on_episode_start(
            worker=self._worker,
            base_env=self._base_env,
            policies=self._worker.policy_map,
            episode=episode,
            env_index=env_id,
        )

        return episode

    def _new_batch_builder(self, _) -> _PolicyCollectorGroup:
        """Create a new batch builder.

        We create a _PolicyCollectorGroup based on the full policy_map
        as the batch builder.
        """
        return _PolicyCollectorGroup(self._worker.policy_map)

    def run(self) -> Iterator[SampleBatchType]:
        """Samples and yields training episodes continuously.

        Yields:
            Object containing state, action, reward, terminal condition,
            and other fields as dictated by `policy`.
        """
        while True:
            self._perf_stats.incr("iters", 1)

            t0 = time.time()
            # Get observations from all ready agents.
            # types: MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, ...
            (
                unfiltered_obs,
                rewards,
                dones,
                infos,
                off_policy_actions,
            ) = self._base_env.poll()
            env_poll_time = time.time() - t0

            # Process observations and prepare for policy evaluation.
            t1 = time.time()
            # types: Set[EnvID], Dict[PolicyID, List[_PolicyEvalData]],
            #       List[Union[RolloutMetrics, SampleBatchType]]
            to_eval, outputs = self._process_observations(
                unfiltered_obs=unfiltered_obs,
                rewards=rewards,
                dones=dones,
                infos=infos,
            )
            self._perf_stats.incr("raw_obs_processing_time", time.time() - t1)

            for o in outputs:
                yield o

            # Do batched policy eval (across vectorized envs).
            t2 = time.time()
            # types: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]
            eval_results = self._do_policy_eval(to_eval=to_eval)
            self._perf_stats.incr("inference_time", time.time() - t2)

            # Process results and update episode state.
            t3 = time.time()
            actions_to_send: Dict[
                EnvID, Dict[AgentID, EnvActionType]
            ] = self._process_policy_eval_results(
                to_eval=to_eval,
                eval_results=eval_results,
                off_policy_actions=off_policy_actions,
            )
            self._perf_stats.incr("action_processing_time", time.time() - t3)

            # Return computed actions to ready envs. We also send to envs that have
            # taken off-policy actions; those envs are free to ignore the action.
            t4 = time.time()
            self._base_env.send_actions(actions_to_send)
            self._perf_stats.incr("env_wait_time", env_poll_time + time.time() - t4)

            self._maybe_render()

    def _get_rollout_metrics(self, episode: EpisodeV2) -> List[RolloutMetrics]:
        """Get rollout metrics from completed episode."""
        # TODO(jungong) : why do we need to handle atari metrics differently?
        # Can we unify atari and normal env metrics?
        atari_metrics: Optional[List[RolloutMetrics]] = _fetch_atari_metrics(
            self._base_env
        )
        if atari_metrics is not None:
            # `_replace()` returns a NEW namedtuple (it does not mutate `m`),
            # so the updated copies must be collected and returned.
            return [
                m._replace(custom_metrics=episode.custom_metrics)
                for m in atari_metrics
            ]
        # Otherwise, return RolloutMetrics for the episode.
        return [
            RolloutMetrics(
                episode.length,
                episode.total_reward,
                dict(episode.agent_rewards),
                episode.custom_metrics,
                {},
                episode.hist_data,
                episode.media,
            )
        ]

    def _process_observations(
        self,
        unfiltered_obs: MultiEnvDict,
        rewards: MultiEnvDict,
        dones: MultiEnvDict,
        infos: MultiEnvDict,
    ) -> Tuple[
        Dict[PolicyID, List[_PolicyEvalData]],
        List[Union[RolloutMetrics, SampleBatchType]],
    ]:
        """Process raw obs from env.

        Group data for active agents by policy. Reset environments that are done.

        Args:
            unfiltered_obs: obs
            rewards: rewards
            dones: dones
            infos: infos

        Returns:
            A tuple of:
                _PolicyEvalData for active agents for policy evaluation.
                SampleBatches and RolloutMetrics for completed agents for output.
        """
        # Output objects.
        to_eval: Dict[PolicyID, List[_PolicyEvalData]] = defaultdict(list)
        outputs: List[Union[RolloutMetrics, SampleBatchType]] = []

        # For each (vectorized) sub-environment.
        # types: EnvID, Dict[AgentID, EnvObsType]
        for env_id, env_obs in unfiltered_obs.items():
            # Check for env_id having returned an error instead of a multi-agent
            # obs dict. This is how our BaseEnv can tell the caller to `poll()` that
            # one of its sub-environments is faulty and should be restarted (and the
            # ongoing episode should not be used for training).
            if isinstance(env_obs, Exception):
                assert dones[env_id]["__all__"] is True, (
                    f"ERROR: When a sub-environment (env-id {env_id}) returns an error "
                    "as observation, the dones[__all__] flag must also be set to True!"
                )
                # all_agents_obs is an Exception here.
                # Drop this episode and skip to next.
                self.end_episode(env_id, env_obs)
                # Tell the sampler we have got a faulty episode.
                # Note: `append`, not `extend` -- RolloutMetrics is a namedtuple,
                # so `extend` would splice its individual fields into `outputs`.
                outputs.append(RolloutMetrics(episode_faulty=True))
                continue

            episode: EpisodeV2 = self._active_episodes[env_id]

            # Episode length after this step.
            next_episode_length = episode.length + 1
            # Check episode termination conditions.
            if dones[env_id]["__all__"] or next_episode_length >= self._horizon:
                hit_horizon = (
                    next_episode_length >= self._horizon
                    and not dones[env_id]["__all__"]
                )
                all_agents_done = True
                # Add rollout metrics.
                outputs.extend(self._get_rollout_metrics(episode))
            else:
                hit_horizon = False
                all_agents_done = False

            # Special handling of common info dict.
            episode.set_last_info("__common__", infos[env_id].get("__common__", {}))

            # Agent sample batches grouped by policy. Each set of sample batches will
            # go through agent connectors together.
            sample_batches_by_policy = defaultdict(list)
            # Whether an agent is done, regardless of no_done_at_end or soft_horizon.
            agent_dones = {}
            for agent_id, obs in env_obs.items():
                assert agent_id != "__all__"

                policy_id: PolicyID = episode.policy_for(agent_id)

                agent_done = bool(all_agents_done or dones[env_id].get(agent_id))
                agent_dones[agent_id] = agent_done

                # A completely new agent is already done -> Skip entirely.
                if not episode.has_init_obs(agent_id) and agent_done:
                    continue

                values_dict = {
                    SampleBatch.T: episode.length - 1,
                    SampleBatch.ENV_ID: env_id,
                    SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
                    # Last action (SampleBatch.ACTIONS) column will be populated by
                    # StateBufferConnector.
                    # Reward received after taking action at timestep t.
                    SampleBatch.REWARDS: rewards[env_id].get(agent_id, 0.0),
                    # After taking action=a, did we reach terminal?
                    SampleBatch.DONES: (
                        False
                        if (
                            self._no_done_at_end
                            or (hit_horizon and self._soft_horizon)
                        )
                        else agent_done
                    ),
                    SampleBatch.INFOS: infos[env_id].get(agent_id, {}),
                    SampleBatch.NEXT_OBS: obs,
                }

                # Queue this obs sample for connector preprocessing.
                sample_batches_by_policy[policy_id].append((agent_id, values_dict))

            # The entire episode is done.
            if all_agents_done:
                # Let's check to see if there are any agents that haven't got the
                # last "done" obs yet. If there are, we have to create fake-last
                # observations for them. (the environment is not required to do so if
                # dones[__all__]=True).
                for agent_id in episode.get_agents():
                    # If the latest obs we got for this agent is done, or if its
                    # episode state is already done, nothing to do.
                    if agent_dones.get(agent_id, False) or episode.is_done(agent_id):
                        continue

                    policy_id: PolicyID = episode.policy_for(agent_id)
                    policy = self._worker.policy_map[policy_id]

                    # Create a fake (all-0s) observation.
                    obs_space = policy.observation_space
                    obs_space = getattr(obs_space, "original_space", obs_space)
                    values_dict = {
                        SampleBatch.T: episode.length - 1,
                        SampleBatch.ENV_ID: env_id,
                        SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
                        SampleBatch.REWARDS: 0.0,
                        SampleBatch.DONES: True,
                        SampleBatch.INFOS: {},
                        SampleBatch.NEXT_OBS: tree.map_structure(
                            np.zeros_like, obs_space.sample()
                        ),
                    }

                    # Queue these fake obs for connector preprocessing too.
                    sample_batches_by_policy[policy_id].append((agent_id, values_dict))

            # Run agent connectors. Remember which policy produced each output,
            # so that eval data below is routed to the correct policy (rather
            # than to whichever policy happened to be processed last).
            processed: List[Tuple[PolicyID, AgentConnectorDataType]] = []
            for policy_id, batches in sample_batches_by_policy.items():
                policy: Policy = self._worker.policy_map[policy_id]
                # Collected full MultiAgentDicts for this environment.
                # Run agent connectors.
                assert (
                    policy.agent_connectors
                ), "EnvRunnerV2 requires agent connectors to work."

                acd_list: List[AgentConnectorDataType] = [
                    AgentConnectorDataType(env_id, agent_id, data)
                    for agent_id, data in batches
                ]
                for d in policy.agent_connectors(acd_list):
                    processed.append((policy_id, d))

            for policy_id, d in processed:
                # Record transition info if applicable.
                if not episode.has_init_obs(d.agent_id):
                    episode.add_init_obs(
                        d.agent_id,
                        d.data.raw_dict[SampleBatch.T],
                        d.data.raw_dict[SampleBatch.NEXT_OBS],
                    )
                else:
                    episode.add_action_reward_done_next_obs(
                        d.agent_id, d.data.raw_dict
                    )

                # Short-circuit matters: fake-obs agents (only created when
                # all_agents_done) have no entry in `agent_dones`.
                if not all_agents_done and not agent_dones[d.agent_id]:
                    # Add to eval set if env is not done and this particular agent
                    # is also not done.
                    item = _PolicyEvalData(d.env_id, d.agent_id, d.data.sample_batch)
                    to_eval[policy_id].append(item)

            # Finished advancing episode by 1 step, mark it so.
            episode.step()

            # Exception: The very first env.poll() call causes the env to get reset
            # (no step taken yet, just a single starting observation logged).
            # We need to skip this callback in this case.
            if episode.length > 0:
                # Invoke the `on_episode_step` callback after the step is logged
                # to the episode.
                self._callbacks.on_episode_step(
                    worker=self._worker,
                    base_env=self._base_env,
                    policies=self._worker.policy_map,
                    episode=episode,
                    env_index=env_id,
                )

            # Episode is done for all agents (dones[__all__] == True)
            # or we hit the horizon.
            if all_agents_done:
                is_done = dones[env_id]["__all__"]
                # _handle_done_episode will build a MultiAgentBatch for all
                # the agents that are done during this step of rollout in
                # the case of _multiple_episodes_in_batch=False.
                self._handle_done_episode(
                    env_id, env_obs, is_done, hit_horizon, to_eval, outputs
                )

            # Try to build something.
            if self._multiple_episodes_in_batch:
                sample_batch = self._try_build_truncated_episode_multi_agent_batch(
                    self._batch_builders[env_id], episode
                )
                if sample_batch:
                    outputs.append(sample_batch)

                    # SampleBatch built from data collected by batch_builder.
                    # Clean up and delete the batch_builder.
                    del self._batch_builders[env_id]

        return to_eval, outputs

    def _handle_done_episode(
        self,
        env_id: EnvID,
        env_obs: MultiAgentDict,
        is_done: bool,
        hit_horizon: bool,
        to_eval: Dict[PolicyID, List[_PolicyEvalData]],
        outputs: List[SampleBatchType],
    ) -> None:
        """Handle an all-finished episode.

        Add collected SampleBatch to batch builder. Reset corresponding env, etc.

        Args:
            env_id: Environment ID.
            env_obs: Last per-environment observation.
            is_done: If all agents are done.
            hit_horizon: Whether the episode ended because it hit horizon.
            to_eval: Output container for policy eval data.
            outputs: Output container for collected sample batches.
        """
        check_dones = is_done and not self._no_done_at_end

        episode: EpisodeV2 = self._active_episodes[env_id]
        batch_builder = self._batch_builders[env_id]
        episode.postprocess_episode(
            batch_builder=batch_builder,
            is_done=is_done or (hit_horizon and not self._soft_horizon),
            check_dones=check_dones,
        )

        # If, we are not allowed to pack the next episode into the same
        # SampleBatch (batch_mode=complete_episodes) -> Build the
        # MultiAgentBatch from a single episode and add it to "outputs".
        # Otherwise, just postprocess and continue collecting across
        # episodes.
        if not self._multiple_episodes_in_batch:
            ma_sample_batch = _build_multi_agent_batch(
                episode.episode_id,
                batch_builder,
                self._large_batch_threshold,
                self._multiple_episodes_in_batch,
            )
            if ma_sample_batch:
                outputs.append(ma_sample_batch)

            # SampleBatch built from data collected by batch_builder.
            # Clean up and delete the batch_builder.
            del self._batch_builders[env_id]

        # Call each (in-memory) policy's Exploration.on_episode_end
        # method.
        # Note: This may break the exploration (e.g. ParameterNoise) of
        # policies in the `policy_map` that have not been recently used
        # (and are therefore stashed to disk). However, we certainly do not
        # want to loop through all (even stashed) policies here as that
        # would counter the purpose of the LRU policy caching.
        for p in self._worker.policy_map.cache.values():
            if getattr(p, "exploration", None) is not None:
                p.exploration.on_episode_end(
                    policy=p,
                    environment=self._base_env,
                    episode=episode,
                    tf_sess=p.get_session(),
                )

        # Clean up and delete the post-processed episode now that we have
        # collected its data. Note: `end_episode()` itself invokes the custom
        # `on_episode_end` callback, so we must not call it separately here
        # (that would fire the callback twice per episode).
        self.end_episode(env_id, episode)

        # Horizon hit and we have a soft horizon (no hard env reset).
        if hit_horizon and self._soft_horizon:
            resetted_obs: Dict[EnvID, Dict[AgentID, EnvObsType]] = {env_id: env_obs}
            # Do not reset connector state if this is a soft reset.
            # Basically carry RNN and other buffered state to the
            # next episode from the same env.
        else:
            # TODO(jungong) : This will allow a single faulty env to
            #  take out the entire RolloutWorker indefinitely. Revisit.
            while True:
                resetted_obs: Dict[
                    EnvID, Dict[AgentID, EnvObsType]
                ] = self._base_env.try_reset(env_id)
                # A failed reset is reported as an Exception, either directly or
                # (the usual case) keyed by `env_id` inside the returned dict.
                reset_failed = isinstance(resetted_obs, Exception) or (
                    isinstance(resetted_obs, dict)
                    and isinstance(resetted_obs.get(env_id), Exception)
                )
                if not reset_failed:
                    break
                # Report a faulty episode.
                outputs.append(RolloutMetrics(episode_faulty=True))

            # Reset connector state if this is a hard reset.
            for p in self._worker.policy_map.cache.values():
                p.agent_connectors.reset(env_id)

        # Reset not supported, drop this env from the ready list.
        if resetted_obs is None:
            if self._horizon != float("inf"):
                raise ValueError(
                    "Setting episode horizon requires reset() support "
                    "from the environment."
                )
        # Creates a new episode if this is not async return.
        # If reset is async, we will get its result in some future poll.
        elif resetted_obs != ASYNC_RESET_RETURN:
            new_episode: EpisodeV2 = self._active_episodes[env_id]
            per_policy_resetted_obs: Dict[PolicyID, List] = defaultdict(list)
            # types: AgentID, EnvObsType
            for agent_id, raw_obs in resetted_obs[env_id].items():
                policy_id: PolicyID = new_episode.policy_for(agent_id)
                per_policy_resetted_obs[policy_id].append((agent_id, raw_obs))

            # Keep track of the policy that produced each connector output so
            # the initial eval data is routed to the correct policy.
            processed: List[Tuple[PolicyID, AgentConnectorDataType]] = []
            for policy_id, agents_obs in per_policy_resetted_obs.items():
                policy = self._worker.policy_map[policy_id]
                acd_list: List[AgentConnectorDataType] = [
                    AgentConnectorDataType(
                        env_id,
                        agent_id,
                        {
                            SampleBatch.T: new_episode.length - 1,
                            SampleBatch.NEXT_OBS: obs,
                        },
                    )
                    for agent_id, obs in agents_obs
                ]
                # Call agent connectors on these initial obs.
                for d in policy.agent_connectors(acd_list):
                    processed.append((policy_id, d))

            for policy_id, d in processed:
                # Add initial obs to buffer.
                new_episode.add_init_obs(
                    d.agent_id,
                    d.data.raw_dict[SampleBatch.T],
                    d.data.raw_dict[SampleBatch.NEXT_OBS],
                )
                item = _PolicyEvalData(d.env_id, d.agent_id, d.data.sample_batch)
                to_eval[policy_id].append(item)

            # Step after adding initial obs. This will give us 0 env and agent step.
            new_episode.step()

    def end_episode(
        self, env_id: EnvID, episode_or_exception: Union[EpisodeV2, Exception]
    ):
        """Clean up an episode that has finished.

        Invokes the `on_episode_end` callback and removes the episode for
        `env_id` from the active episodes.

        Args:
            env_id: Env ID.
            episode_or_exception: Instance of an episode if it finished successfully.
                Otherwise, the exception that was thrown,

        Raises:
            ValueError: If a successfully finished episode shows no agent steps.
        """
        # Signal the end of an episode, either successfully with an Episode or
        # unsuccessfully with an Exception.
        self._callbacks.on_episode_end(
            worker=self._worker,
            base_env=self._base_env,
            policies=self._worker.policy_map,
            episode=episode_or_exception,
            env_index=env_id,
        )

        if isinstance(episode_or_exception, EpisodeV2):
            episode = episode_or_exception
            if episode.total_agent_steps == 0:
                # if the key does not exist it means that throughout the episode all
                # observations were empty (i.e. there was no agent in the env)
                msg = (
                    f"Data from episode {episode.episode_id} does not show any agent "
                    f"interactions. Hint: Make sure for at least one timestep in the "
                    f"episode, env.step() returns non-empty values."
                )
                raise ValueError(msg)

        # Clean up the episode for this env id. `pop()` rather than `del`: a
        # sub-environment may fail before any episode was created for it.
        self._active_episodes.pop(env_id, None)

    def _try_build_truncated_episode_multi_agent_batch(
        self, batch_builder: _PolicyCollectorGroup, episode: EpisodeV2
    ) -> Union[None, SampleBatch, MultiAgentBatch]:
        """Builds a MultiAgentBatch once `rollout_fragment_length` is reached.

        Args:
            batch_builder: The collector group holding already-postprocessed data.
            episode: The (possibly still ongoing) episode of the same env.

        Returns:
            A MultiAgentBatch if the fragment length was reached and some agent
            data was collected, otherwise None.
        """
        # Measure batch size in env-steps.
        if self._count_steps_by == "env_steps":
            built_steps = batch_builder.env_steps
            ongoing_steps = episode.active_env_steps
        # Measure batch-size in agent-steps.
        else:
            built_steps = batch_builder.agent_steps
            ongoing_steps = episode.active_agent_steps

        # Reached the fragment-len -> We should build an MA-Batch.
        if built_steps + ongoing_steps >= self._rollout_fragment_length:
            if self._count_steps_by != "agent_steps":
                assert built_steps + ongoing_steps == self._rollout_fragment_length, (
                    f"built_steps ({built_steps}) + ongoing_steps ({ongoing_steps}) != "
                    f"rollout_fragment_length ({self._rollout_fragment_length})."
                )

            # If we reached the fragment-len only because of `episode_id`
            # (still ongoing) -> postprocess `episode_id` first.
            if built_steps < self._rollout_fragment_length:
                episode.postprocess_episode(batch_builder=batch_builder, is_done=False)

            # If builder has collected some data,
            # build the MA-batch and add to return values.
            if batch_builder.agent_steps > 0:
                return _build_multi_agent_batch(
                    episode.episode_id,
                    batch_builder,
                    self._large_batch_threshold,
                    self._multiple_episodes_in_batch,
                )
            # No batch-builder:
            # We have reached the rollout-fragment length w/o any agent
            # steps! Warn that the environment may never request any
            # actions from any agents.
            elif log_once("no_agent_steps"):
                logger.warning(
                    "Your environment seems to be stepping w/o ever "
                    "emitting agent observations (agents are never "
                    "requested to act)!"
                )

        return None

    def _do_policy_eval(
        self,
        to_eval: Dict[PolicyID, List[_PolicyEvalData]],
    ) -> Dict[PolicyID, PolicyOutputType]:
        """Call compute_actions on collected episode data to get next action.

        Args:
            to_eval: Mapping of policy IDs to lists of _PolicyEvalData objects
                (items in these lists will be the batch's items for the model
                forward pass).

        Returns:
            Dict mapping PolicyIDs to compute_actions_from_input_dict() outputs.
        """
        policies = self._worker.policy_map

        # In case policy map has changed, try to find the new policy that
        # should handle all these per-agent eval data.
        # Throws exception if these agents are mapped to multiple different
        # policies now.
        def _try_find_policy_again(eval_data: List[_PolicyEvalData]):
            policy_id = None
            for d in eval_data:
                episode = self._active_episodes[d.env_id]
                # Force refresh policy mapping on the episode.
                pid = episode.policy_for(d.agent_id, refresh=True)
                if policy_id is not None and pid != policy_id:
                    raise ValueError(
                        "Policy map changed. The list of eval data that was handled "
                        f"by a same policy is now handled by policy {pid} "
                        f"and {policy_id}. "
                        "Please don't do this in the middle of an episode."
                    )
                policy_id = pid
            return _get_or_raise(self._worker.policy_map, policy_id)

        eval_results: Dict[PolicyID, TensorStructType] = {}
        for policy_id, eval_data in to_eval.items():
            # In case the policyID has been removed from this worker, we need to
            # re-assign policy_id and re-lookup the Policy object to use.
            try:
                policy: Policy = _get_or_raise(policies, policy_id)
            except ValueError:
                # policy_mapping_fn from the worker may have already been
                # changed (mapping fn not staying constant within one episode).
                policy: Policy = _try_find_policy_again(eval_data)

            input_dict = _batch_inference_sample_batches(
                [d.sample_batch for d in eval_data]
            )

            eval_results[policy_id] = policy.compute_actions_from_input_dict(
                input_dict,
                timestep=policy.global_timestep,
                episodes=[self._active_episodes[t.env_id] for t in eval_data],
            )

        return eval_results

    def _process_policy_eval_results(
        self,
        to_eval: Dict[PolicyID, List[_PolicyEvalData]],
        eval_results: Dict[PolicyID, PolicyOutputType],
        off_policy_actions: MultiEnvDict,
    ):
        """Process the output of policy neural network evaluation.

        Records policy evaluation results into agent connectors and
        returns replies to send back to agents in the env.

        Args:
            to_eval: Mapping of policy IDs to lists of _PolicyEvalData objects.
            eval_results: Mapping of policy IDs to list of
                actions, rnn-out states, extra-action-fetches dicts.
            off_policy_actions: Doubly keyed dict of env-ids -> agent ids ->
                off-policy-action, returned by a `BaseEnv.poll()` call.

        Returns:
            Nested dict of env id -> agent id -> actions to be sent to
            Env (np.ndarrays).
        """
        actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = defaultdict(dict)
        for eval_data in to_eval.values():
            for d in eval_data:
                actions_to_send[d.env_id] = {}  # at minimum send empty dict

        # types: PolicyID, List[_PolicyEvalData]
        for policy_id, eval_data in to_eval.items():
            actions: TensorStructType = eval_results[policy_id][0]
            actions = convert_to_numpy(actions)

            rnn_out: StateBatches = eval_results[policy_id][1]
            extra_action_out: dict = eval_results[policy_id][2]

            # In case actions is a list (representing the 0th dim of a batch of
            # primitive actions), try converting it first.
            if isinstance(actions, list):
                actions = np.array(actions)
            # Split action-component batches into single action rows.
            actions: List[EnvActionType] = unbatch(actions)

            policy: Policy = _get_or_raise(self._worker.policy_map, policy_id)
            assert (
                policy.agent_connectors and policy.action_connectors
            ), "EnvRunnerV2 requires action connectors to work."

            # types: int, EnvActionType
            for i, action in enumerate(actions):
                env_id: EnvID = eval_data[i].env_id
                agent_id: AgentID = eval_data[i].agent_id

                rnn_states: List[StateBatches] = [c[i] for c in rnn_out]
                fetches: Dict = {k: v[i] for k, v in extra_action_out.items()}

                # Post-process policy output by running them through action connectors.
                ac_data = ActionConnectorDataType(
                    env_id, agent_id, (action, rnn_states, fetches)
                )
                action_to_send, rnn_states, fetches = policy.action_connectors(
                    ac_data
                ).output

                action_to_buffer = (
                    action_to_send
                    if env_id not in off_policy_actions
                    or agent_id not in off_policy_actions[env_id]
                    else off_policy_actions[env_id][agent_id]
                )

                # Notify agent connectors with this new policy output.
                # Necessary for state buffering agent connectors, for example.
                ac_data: ActionConnectorDataType = ActionConnectorDataType(
                    env_id, agent_id, (action_to_buffer, rnn_states, fetches)
                )
                policy.agent_connectors.on_policy_output(ac_data)

                assert agent_id not in actions_to_send[env_id]
                actions_to_send[env_id][agent_id] = action_to_send

        return actions_to_send

    def _maybe_render(self):
        """Visualize environment."""
        # Check if we should render.
        if not self._render or not self._simple_image_viewer:
            return

        t5 = time.time()

        # Render can either return an RGB image (uint8 [w x h x 3] numpy
        # array) or take care of rendering itself (returning True).
        rendered = self._base_env.try_render()
        # Rendering returned an image -> Display it in a SimpleImageViewer.
        if isinstance(rendered, np.ndarray) and len(rendered.shape) == 3:
            self._simple_image_viewer.imshow(rendered)
        elif rendered not in [True, False, None]:
            raise ValueError(
                f"The env's ({self._base_env}) `try_render()` method returned an"
                " unsupported value! Make sure you either return a "
                "uint8/w x h x 3 (RGB) image or handle rendering in a "
                "window and then return `True`."
            )

        self._perf_stats.incr("env_render_time", time.time() - t5)


def _fetch_atari_metrics(base_env: BaseEnv) -> Optional[List[RolloutMetrics]]:
    """Atari games have multiple logical episodes, one per life.

    However, for metrics reporting we count full episodes, all lives included.

    Returns:
        RolloutMetrics gathered from each sub-env's MonitorEnv wrapper, or None
        if there are no sub-environments or any sub-env lacks that wrapper.
    """
    sub_environments = base_env.get_sub_environments()
    if not sub_environments:
        return None
    atari_out = []
    for sub_env in sub_environments:
        monitor = get_wrapper_by_cls(sub_env, MonitorEnv)
        if not monitor:
            return None
        for eps_rew, eps_len in monitor.next_episode_results():
            atari_out.append(RolloutMetrics(eps_len, eps_rew))
    return atari_out


def _get_or_raise(
    mapping: Dict[PolicyID, Union[Policy, Preprocessor, Filter]], policy_id: PolicyID
) -> Union[Policy, Preprocessor, Filter]:
    """Returns an object under key `policy_id` in `mapping`.

    Args:
        mapping (Dict[PolicyID, Union[Policy, Preprocessor, Filter]]): The
            mapping dict from policy id (str) to actual object (Policy,
            Preprocessor, etc.).
        policy_id: The policy ID to lookup.

    Returns:
        Union[Policy, Preprocessor, Filter]: The found object.

    Raises:
        ValueError: If `policy_id` cannot be found in `mapping`.
    """
    if policy_id not in mapping:
        raise ValueError(
            "Could not find policy for agent: PolicyID `{}` not found "
            "in policy map, whose keys are `{}`.".format(policy_id, mapping.keys())
        )
    return mapping[policy_id]