# ray/rllib/evaluation/env_runner_v2.py


import logging
import time
from collections import defaultdict, namedtuple
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Union
import numpy as np
import tree # pip install dm_tree
from ray.rllib.env.base_env import ASYNC_RESET_RETURN, BaseEnv
from ray.rllib.env.wrappers.atari_wrappers import MonitorEnv, get_wrapper_by_cls
from ray.rllib.evaluation.collectors.simple_list_collector import _PolicyCollectorGroup
from ray.rllib.evaluation.episode_v2 import EpisodeV2
from ray.rllib.evaluation.metrics import RolloutMetrics
from ray.rllib.models.preprocessors import Preprocessor
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch, concat_samples
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.filter import Filter
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.spaces.space_utils import unbatch
from ray.rllib.utils.typing import (
ActionConnectorDataType,
AgentConnectorDataType,
AgentID,
EnvActionType,
EnvID,
EnvObsType,
MultiAgentDict,
MultiEnvDict,
PolicyID,
PolicyOutputType,
SampleBatchType,
StateBatches,
TensorStructType,
)
from ray.util.debug import log_once
if TYPE_CHECKING:
from gym.envs.classic_control.rendering import SimpleImageViewer
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.evaluation.rollout_worker import RolloutWorker
logger = logging.getLogger(__name__)
MIN_LARGE_BATCH_THRESHOLD = 1000
DEFAULT_LARGE_BATCH_THRESHOLD = 5000
_PolicyEvalData = namedtuple("_PolicyEvalData", ["env_id", "agent_id", "sample_batch"])
class _PerfStats:
"""Sampler perf stats that will be included in rollout metrics."""
def __init__(self):
self.iters = 0
self.raw_obs_processing_time = 0.0
self.inference_time = 0.0
self.action_processing_time = 0.0
self.env_wait_time = 0.0
self.env_render_time = 0.0
def get(self):
        # Mean factor (converts accumulated seconds into mean milliseconds per iter).
factor = 1000 / self.iters
return {
# Raw observation preprocessing.
"mean_raw_obs_processing_ms": self.raw_obs_processing_time * factor,
# Computing actions through policy.
"mean_inference_ms": self.inference_time * factor,
# Processing actions (to be sent to env, e.g. clipping).
"mean_action_processing_ms": self.action_processing_time * factor,
# Waiting for environment (during poll).
"mean_env_wait_ms": self.env_wait_time * factor,
# Environment rendering (False by default).
"mean_env_render_ms": self.env_render_time * factor,
}
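# Worked example of the conversion in `get()` above (illustrative numbers only):
# if `env_wait_time` accumulated 0.25 seconds over `iters=50` poll cycles, then
# `mean_env_wait_ms = 0.25 * (1000 / 50) = 5.0`, i.e. 5 ms per iteration on average.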
class _NewDefaultDict(defaultdict):
def __missing__(self, env_id):
ret = self[env_id] = self.default_factory(env_id)
return ret
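# Unlike a plain `collections.defaultdict`, `_NewDefaultDict` passes the missing
# key to its factory. A minimal sketch of the pattern (hypothetical factory, for
# illustration only):
#
#     episodes = _NewDefaultDict(lambda env_id: EpisodeV2(env_id, ...))
#     ep = episodes[3]  # factory is called with env_id=3; result is cached under 3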
def _build_multi_agent_batch(
episode_id: int,
batch_builder: _PolicyCollectorGroup,
large_batch_threshold: int,
multiple_episodes_in_batch: bool,
) -> MultiAgentBatch:
"""Build MultiAgentBatch from a dict of _PolicyCollectors.
Args:
env_steps: total env steps.
policy_collectors: collected training SampleBatchs by policy.
Returns:
Always returns a sample batch in MultiAgentBatch format.
"""
ma_batch = {}
for pid, collector in batch_builder.policy_collectors.items():
if collector.agent_steps <= 0:
continue
if batch_builder.agent_steps > large_batch_threshold and log_once(
"large_batch_warning"
):
logger.warning(
"More than {} observations in {} env steps for "
"episode {} ".format(
batch_builder.agent_steps, batch_builder.env_steps, episode_id
)
+ "are buffered in the sampler. If this is more than you "
"expected, check that that you set a horizon on your "
"environment correctly and that it terminates at some "
"point. Note: In multi-agent environments, "
"`rollout_fragment_length` sets the batch size based on "
"(across-agents) environment steps, not the steps of "
"individual agents, which can result in unexpectedly "
"large batches."
+ (
"Also, you may be waiting for your Env to "
"terminate (batch_mode=`complete_episodes`). Make sure "
"it does at some point."
if not multiple_episodes_in_batch
else ""
)
)
ma_batch[pid] = collector.build()
# Create the multi agent batch.
return MultiAgentBatch(policy_batches=ma_batch, env_steps=batch_builder.env_steps)
def _batch_inference_sample_batches(eval_data: List[SampleBatch]) -> SampleBatch:
"""Batch a list of input SampleBatches into a single SampleBatch.
Args:
eval_data: list of SampleBatches.
Returns:
single batched SampleBatch.
"""
inference_batch = concat_samples(eval_data)
if "state_in_0" in inference_batch:
batch_size = len(eval_data)
inference_batch[SampleBatch.SEQ_LENS] = np.ones(batch_size, dtype=np.int32)
return inference_batch
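# Sketch of the batching above (illustrative shapes, not part of the module's
# public API): three per-agent, single-timestep SampleBatches become one
# forward-pass batch. SEQ_LENS is only added when "state_in_0" is present,
# marking each row as its own length-1 sequence for recurrent policies.
#
#     batches = [SampleBatch({"obs": np.array([[0.1, 0.2]])}) for _ in range(3)]
#     infer = _batch_inference_sample_batches(batches)
#     # infer["obs"].shape == (3, 2)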
@DeveloperAPI
class EnvRunnerV2:
"""Collect experiences from user environment using Connectors."""
def __init__(
self,
worker: "RolloutWorker",
base_env: BaseEnv,
horizon: Optional[int],
multiple_episodes_in_batch: bool,
callbacks: "DefaultCallbacks",
perf_stats: _PerfStats,
soft_horizon: bool,
no_done_at_end: bool,
rollout_fragment_length: int = 200,
count_steps_by: str = "env_steps",
        render: Optional[bool] = None,
):
"""
Args:
worker: Reference to the current rollout worker.
base_env: Env implementing BaseEnv.
horizon: Horizon of the episode.
multiple_episodes_in_batch: Whether to pack multiple
episodes into each batch. This guarantees batches will be exactly
`rollout_fragment_length` in size.
callbacks: User callbacks to run on episode events.
perf_stats: Record perf stats into this object.
soft_horizon: Calculate rewards but don't reset the
environment when the horizon is hit.
no_done_at_end: Ignore the done=True at the end of the episode
and instead record done=False.
rollout_fragment_length: The length of a fragment to collect
before building a SampleBatch from the data and resetting
the SampleBatchBuilder object.
count_steps_by: One of "env_steps" (default) or "agent_steps".
Use "agent_steps", if you want rollout lengths to be counted
by individual agent steps. In a multi-agent env,
a single env_step contains one or more agent_steps, depending
on how many agents are present at any given time in the
ongoing episode.
render: Whether to try to render the environment after each
step.
"""
self._worker = worker
self._base_env = base_env
self._multiple_episodes_in_batch = multiple_episodes_in_batch
self._callbacks = callbacks
self._perf_stats = perf_stats
self._soft_horizon = soft_horizon
self._no_done_at_end = no_done_at_end
self._rollout_fragment_length = rollout_fragment_length
self._count_steps_by = count_steps_by
self._render = render
self._horizon = self._get_horizon(horizon)
# May be populated for image rendering.
self._simple_image_viewer: Optional[
"SimpleImageViewer"
] = self._get_simple_image_viewer()
# Keeps track of active episodes.
self._active_episodes: Dict[EnvID, EpisodeV2] = _NewDefaultDict(
self._new_episode
)
self._batch_builders: Dict[EnvID, _PolicyCollectorGroup] = _NewDefaultDict(
self._new_batch_builder
)
self._large_batch_threshold: int = (
max(MIN_LARGE_BATCH_THRESHOLD, self._rollout_fragment_length * 10)
if self._rollout_fragment_length != float("inf")
else DEFAULT_LARGE_BATCH_THRESHOLD
)
def _get_horizon(self, horizon: Optional[int]):
"""Try figuring out the proper horizon to use for rollout.
Args:
base_env: Env implementing BaseEnv.
horizon: Horizon of the episode.
"""
# Try to get Env's `max_episode_steps` prop. If it doesn't exist, ignore
# error and continue with max_episode_steps=None.
max_episode_steps = None
try:
max_episode_steps = self._base_env.get_sub_environments()[
0
].spec.max_episode_steps
except Exception:
pass
# Trainer has a given `horizon` setting.
if horizon:
# `horizon` is larger than env's limit.
if max_episode_steps and horizon > max_episode_steps:
# Try to override the env's own max-step setting with our horizon.
# If this won't work, throw an error.
try:
self._base_env.get_sub_environments()[
0
].spec.max_episode_steps = horizon
self._base_env.get_sub_environments()[
0
]._max_episode_steps = horizon
except Exception:
raise ValueError(
"Your `horizon` setting ({}) is larger than the Env's own "
"timestep limit ({}), which seems to be unsettable! Try "
"to increase the Env's built-in limit to be at least as "
"large as your wanted `horizon`.".format(
horizon, max_episode_steps
)
)
# Otherwise, set Trainer's horizon to env's max-steps.
elif max_episode_steps:
horizon = max_episode_steps
logger.debug(
"No episode horizon specified, setting it to Env's limit ({}).".format(
max_episode_steps
)
)
# No horizon/max_episode_steps -> Episodes may be infinitely long.
else:
horizon = float("inf")
logger.debug("No episode horizon specified, assuming inf.")
return horizon
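    # Summary of the horizon resolution above (a sketch of the branches):
    #   * `horizon` given: use it; if it exceeds the env's own
    #     `max_episode_steps`, try to raise the env's limit to `horizon` and
    #     raise a ValueError if that limit cannot be overridden.
    #   * no `horizon`, but env defines `max_episode_steps`: use the env's limit.
    #   * neither available: horizon = float("inf"), episodes may run forever.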
def _get_simple_image_viewer(self):
"""Maybe construct a SimpleImageViewer instance for episode rendering."""
# Try to render the env, if required.
if not self._render:
return None
try:
from gym.envs.classic_control.rendering import SimpleImageViewer
return SimpleImageViewer()
except (ImportError, ModuleNotFoundError):
self._render = False # disable rendering
logger.warning(
"Could not import gym.envs.classic_control."
"rendering! Try `pip install gym[all]`."
)
return None
def _new_episode(self, env_id) -> EpisodeV2:
"""Create a new episode."""
episode = EpisodeV2(
env_id,
self._worker.policy_map,
self._worker.policy_mapping_fn,
worker=self._worker,
callbacks=self._callbacks,
)
# Call each policy's Exploration.on_episode_start method.
# Note: This may break the exploration (e.g. ParameterNoise) of
# policies in the `policy_map` that have not been recently used
# (and are therefore stashed to disk). However, we certainly do not
# want to loop through all (even stashed) policies here as that
# would counter the purpose of the LRU policy caching.
for p in self._worker.policy_map.cache.values():
if getattr(p, "exploration", None) is not None:
p.exploration.on_episode_start(
policy=p,
environment=self._base_env,
episode=episode,
tf_sess=p.get_session(),
)
# Call on_episode_start callbacks.
self._callbacks.on_episode_start(
worker=self._worker,
base_env=self._base_env,
policies=self._worker.policy_map,
episode=episode,
env_index=env_id,
)
return episode
def _new_batch_builder(self, _) -> _PolicyCollectorGroup:
"""Create a new batch builder.
We create a _PolicyCollectorGroup based on the full policy_map
as the batch builder.
"""
return _PolicyCollectorGroup(self._worker.policy_map)
def run(self) -> Iterator[SampleBatchType]:
"""Samples and yields training episodes continuously.
Yields:
Object containing state, action, reward, terminal condition,
and other fields as dictated by `policy`.
"""
while True:
self._perf_stats.iters += 1
t0 = time.time()
# Get observations from all ready agents.
# types: MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, ...
(
unfiltered_obs,
rewards,
dones,
infos,
off_policy_actions,
) = self._base_env.poll()
self._perf_stats.env_wait_time += time.time() - t0
# Process observations and prepare for policy evaluation.
t1 = time.time()
# types: Set[EnvID], Dict[PolicyID, List[_PolicyEvalData]],
# List[Union[RolloutMetrics, SampleBatchType]]
to_eval, outputs = self._process_observations(
unfiltered_obs=unfiltered_obs,
rewards=rewards,
dones=dones,
infos=infos,
)
self._perf_stats.raw_obs_processing_time += time.time() - t1
for o in outputs:
yield o
            # Do batched policy eval (across vectorized envs).
t2 = time.time()
# types: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]
eval_results = self._do_policy_eval(to_eval=to_eval)
self._perf_stats.inference_time += time.time() - t2
# Process results and update episode state.
t3 = time.time()
actions_to_send: Dict[
EnvID, Dict[AgentID, EnvActionType]
] = self._process_policy_eval_results(
to_eval=to_eval,
eval_results=eval_results,
off_policy_actions=off_policy_actions,
)
self._perf_stats.action_processing_time += time.time() - t3
# Return computed actions to ready envs. We also send to envs that have
# taken off-policy actions; those envs are free to ignore the action.
t4 = time.time()
self._base_env.send_actions(actions_to_send)
self._perf_stats.env_wait_time += time.time() - t4
self._maybe_render()
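    # Rough usage sketch (hedged; the real driver is the sampler that owns this
    # EnvRunnerV2 instance; names like `metrics_queue` and `train_batches` below
    # are purely illustrative):
    #
    #     runner = EnvRunnerV2(worker, base_env, horizon=None, ...)
    #     for item in runner.run():
    #         if isinstance(item, RolloutMetrics):
    #             metrics_queue.put(item)
    #         else:
    #             train_batches.append(item)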
def _get_rollout_metrics(self, episode: EpisodeV2) -> List[RolloutMetrics]:
"""Get rollout metrics from completed episode."""
# TODO(jungong) : why do we need to handle atari metrics differently?
# Can we unify atari and normal env metrics?
atari_metrics: List[RolloutMetrics] = _fetch_atari_metrics(self._base_env)
        if atari_metrics is not None:
            # `_replace` returns a new namedtuple, so keep the updated copies
            # (the originals would otherwise be returned unchanged).
            return [
                m._replace(custom_metrics=episode.custom_metrics)
                for m in atari_metrics
            ]
# Otherwise, return RolloutMetrics for the episode.
return [
RolloutMetrics(
episode.length,
episode.total_reward,
dict(episode.agent_rewards),
episode.custom_metrics,
{},
episode.hist_data,
episode.media,
)
]
def _process_observations(
self,
unfiltered_obs: MultiEnvDict,
rewards: MultiEnvDict,
dones: MultiEnvDict,
infos: MultiEnvDict,
) -> Tuple[
Dict[PolicyID, List[_PolicyEvalData]],
List[Union[RolloutMetrics, SampleBatchType]],
]:
"""Process raw obs from env.
Group data for active agents by policy. Reset environments that are done.
Args:
unfiltered_obs: obs
rewards: rewards
dones: dones
infos: infos
Returns:
A tuple of:
_PolicyEvalData for active agents for policy evaluation.
SampleBatches and RolloutMetrics for completed agents for output.
"""
# Output objects.
to_eval: Dict[PolicyID, List[_PolicyEvalData]] = defaultdict(list)
outputs: List[Union[RolloutMetrics, SampleBatchType]] = []
# For each (vectorized) sub-environment.
# types: EnvID, Dict[AgentID, EnvObsType]
for env_id, env_obs in unfiltered_obs.items():
# Check for env_id having returned an error instead of a multi-agent
# obs dict. This is how our BaseEnv can tell the caller to `poll()` that
# one of its sub-environments is faulty and should be restarted (and the
# ongoing episode should not be used for training).
if isinstance(env_obs, Exception):
assert dones[env_id]["__all__"] is True, (
f"ERROR: When a sub-environment (env-id {env_id}) returns an error "
"as observation, the dones[__all__] flag must also be set to True!"
)
# all_agents_obs is an Exception here.
# Drop this episode and skip to next.
self.end_episode(env_id, env_obs)
continue
episode: EpisodeV2 = self._active_episodes[env_id]
# Episode length after this step.
next_episode_length = episode.length + 1
# Check episode termination conditions.
if dones[env_id]["__all__"] or next_episode_length >= self._horizon:
hit_horizon = (
next_episode_length >= self._horizon
and not dones[env_id]["__all__"]
)
all_agents_done = True
# Add rollout metrics.
outputs.extend(self._get_rollout_metrics(episode))
else:
hit_horizon = False
all_agents_done = False
# Special handling of common info dict.
episode.set_last_info("__common__", infos[env_id].get("__common__", {}))
# Agent sample batches grouped by policy. Each set of sample batches will
# go through agent connectors together.
sample_batches_by_policy = defaultdict(list)
# Whether an agent is done, regardless of no_done_at_end or soft_horizon.
agent_dones = {}
for agent_id, obs in env_obs.items():
assert agent_id != "__all__"
policy_id: PolicyID = episode.policy_for(agent_id)
agent_done = bool(all_agents_done or dones[env_id].get(agent_id))
agent_dones[agent_id] = agent_done
# A completely new agent is already done -> Skip entirely.
if not episode.has_init_obs(agent_id) and agent_done:
continue
values_dict = {
SampleBatch.T: episode.length - 1,
SampleBatch.ENV_ID: env_id,
SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
# Last action (SampleBatch.ACTIONS) column will be populated by
# StateBufferConnector.
# Reward received after taking action at timestep t.
SampleBatch.REWARDS: rewards[env_id].get(agent_id, 0.0),
# After taking action=a, did we reach terminal?
SampleBatch.DONES: (
False
if (
self._no_done_at_end or (hit_horizon and self._soft_horizon)
)
else agent_done
),
SampleBatch.INFOS: infos[env_id].get(agent_id, {}),
SampleBatch.NEXT_OBS: obs,
}
# Queue this obs sample for connector preprocessing.
sample_batches_by_policy[policy_id].append((agent_id, values_dict))
# The entire episode is done.
if all_agents_done:
# Let's check to see if there are any agents that haven't got the
# last "done" obs yet. If there are, we have to create fake-last
# observations for them. (the environment is not required to do so if
# dones[__all__]=True).
for agent_id in episode.get_agents():
# If the latest obs we got for this agent is done, or if its
# episode state is already done, nothing to do.
if agent_dones.get(agent_id, False) or episode.is_done(agent_id):
continue
policy_id: PolicyID = episode.policy_for(agent_id)
policy = self._worker.policy_map[policy_id]
# Create a fake (all-0s) observation.
obs_space = policy.observation_space
obs_space = getattr(obs_space, "original_space", obs_space)
values_dict = {
SampleBatch.T: episode.length - 1,
SampleBatch.ENV_ID: env_id,
SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
SampleBatch.REWARDS: 0.0,
SampleBatch.DONES: True,
SampleBatch.INFOS: {},
SampleBatch.NEXT_OBS: tree.map_structure(
np.zeros_like, obs_space.sample()
),
}
# Queue these fake obs for connector preprocessing too.
sample_batches_by_policy[policy_id].append((agent_id, values_dict))
# Run agent connectors.
processed = []
for policy_id, batches in sample_batches_by_policy.items():
policy: Policy = self._worker.policy_map[policy_id]
# Collected full MultiAgentDicts for this environment.
# Run agent connectors.
assert (
policy.agent_connectors
), "EnvRunnerV2 requires agent connectors to work."
acd_list: List[AgentConnectorDataType] = [
AgentConnectorDataType(env_id, agent_id, data)
for agent_id, data in batches
]
processed.extend(policy.agent_connectors(acd_list))
for d in processed:
# Record transition info if applicable.
if not episode.has_init_obs(d.agent_id):
episode.add_init_obs(
d.agent_id,
d.data.for_training[SampleBatch.T],
d.data.for_training[SampleBatch.NEXT_OBS],
)
else:
episode.add_action_reward_done_next_obs(
d.agent_id, d.data.for_training
)
if not all_agents_done and not agent_dones[d.agent_id]:
# Add to eval set if env is not done and this particular agent
# is also not done.
item = _PolicyEvalData(d.env_id, d.agent_id, d.data.for_action)
to_eval[policy_id].append(item)
# Finished advancing episode by 1 step, mark it so.
episode.step()
# Exception: The very first env.poll() call causes the env to get reset
# (no step taken yet, just a single starting observation logged).
# We need to skip this callback in this case.
if episode.length > 0:
# Invoke the `on_episode_step` callback after the step is logged
# to the episode.
self._callbacks.on_episode_step(
worker=self._worker,
base_env=self._base_env,
policies=self._worker.policy_map,
episode=episode,
env_index=env_id,
)
# Episode is done for all agents (dones[__all__] == True)
# or we hit the horizon.
if all_agents_done:
is_done = dones[env_id]["__all__"]
# _handle_done_episode will build a MultiAgentBatch for all
# the agents that are done during this step of rollout in
# the case of _multiple_episodes_in_batch=False.
self._handle_done_episode(
env_id, env_obs, is_done, hit_horizon, to_eval, outputs
)
# Try to build something.
if self._multiple_episodes_in_batch:
sample_batch = self._try_build_truncated_episode_multi_agent_batch(
self._batch_builders[env_id], episode
)
if sample_batch:
outputs.append(sample_batch)
# SampleBatch built from data collected by batch_builder.
# Clean up and delete the batch_builder.
del self._batch_builders[env_id]
return to_eval, outputs
def _handle_done_episode(
self,
env_id: EnvID,
env_obs: MultiAgentDict,
is_done: bool,
hit_horizon: bool,
to_eval: Dict[PolicyID, List[_PolicyEvalData]],
outputs: List[SampleBatchType],
) -> None:
"""Handle an all-finished episode.
Add collected SampleBatch to batch builder. Reset corresponding env, etc.
Args:
env_id: Environment ID.
            env_obs: Last multi-agent observation dict for this env.
is_done: If all agents are done.
hit_horizon: Whether the episode ended because it hit horizon.
to_eval: Output container for policy eval data.
outputs: Output container for collected sample batches.
"""
check_dones = is_done and not self._no_done_at_end
episode: EpisodeV2 = self._active_episodes[env_id]
batch_builder = self._batch_builders[env_id]
episode.postprocess_episode(
batch_builder=batch_builder,
is_done=is_done or (hit_horizon and not self._soft_horizon),
check_dones=check_dones,
)
        # If we are not allowed to pack the next episode into the same
# SampleBatch (batch_mode=complete_episodes) -> Build the
# MultiAgentBatch from a single episode and add it to "outputs".
# Otherwise, just postprocess and continue collecting across
# episodes.
if not self._multiple_episodes_in_batch:
ma_sample_batch = _build_multi_agent_batch(
episode.episode_id,
batch_builder,
self._large_batch_threshold,
self._multiple_episodes_in_batch,
)
if ma_sample_batch:
outputs.append(ma_sample_batch)
# SampleBatch built from data collected by batch_builder.
# Clean up and delete the batch_builder.
del self._batch_builders[env_id]
# Call each (in-memory) policy's Exploration.on_episode_end
# method.
# Note: This may break the exploration (e.g. ParameterNoise) of
# policies in the `policy_map` that have not been recently used
# (and are therefore stashed to disk). However, we certainly do not
# want to loop through all (even stashed) policies here as that
# would counter the purpose of the LRU policy caching.
for p in self._worker.policy_map.cache.values():
if getattr(p, "exploration", None) is not None:
p.exploration.on_episode_end(
policy=p,
environment=self._base_env,
episode=episode,
tf_sess=p.get_session(),
)
# Call custom on_episode_end callback.
self._callbacks.on_episode_end(
worker=self._worker,
base_env=self._base_env,
policies=self._worker.policy_map,
episode=episode,
env_index=env_id,
)
        # Clean up and delete the post-processed episode now that we have collected
# its data.
self.end_episode(env_id, episode)
# Horizon hit and we have a soft horizon (no hard env reset).
if hit_horizon and self._soft_horizon:
resetted_obs: Dict[EnvID, Dict[AgentID, EnvObsType]] = {env_id: env_obs}
# Do not reset connector state if this is a soft reset.
# Basically carry RNN and other buffered state to the
# next episode from the same env.
else:
resetted_obs: Dict[
EnvID, Dict[AgentID, EnvObsType]
] = self._base_env.try_reset(env_id)
# Reset connector state if this is a hard reset.
for p in self._worker.policy_map.cache.values():
p.agent_connectors.reset(env_id)
# Reset not supported, drop this env from the ready list.
if resetted_obs is None:
if self._horizon != float("inf"):
raise ValueError(
"Setting episode horizon requires reset() support "
"from the environment."
)
# Creates a new episode if this is not async return.
# If reset is async, we will get its result in some future poll.
elif resetted_obs != ASYNC_RESET_RETURN:
new_episode: EpisodeV2 = self._active_episodes[env_id]
per_policy_resetted_obs: Dict[PolicyID, List] = defaultdict(list)
# types: AgentID, EnvObsType
for agent_id, raw_obs in resetted_obs[env_id].items():
policy_id: PolicyID = new_episode.policy_for(agent_id)
per_policy_resetted_obs[policy_id].append((agent_id, raw_obs))
processed = []
for policy_id, agents_obs in per_policy_resetted_obs.items():
policy = self._worker.policy_map[policy_id]
acd_list: List[AgentConnectorDataType] = [
AgentConnectorDataType(
env_id,
agent_id,
{
SampleBatch.T: new_episode.length - 1,
SampleBatch.NEXT_OBS: obs,
},
)
for agent_id, obs in agents_obs
]
# Call agent connectors on these initial obs.
processed.extend(policy.agent_connectors(acd_list))
for d in processed:
# Add initial obs to buffer.
new_episode.add_init_obs(
d.agent_id,
d.data.for_training[SampleBatch.T],
d.data.for_training[SampleBatch.NEXT_OBS],
)
item = _PolicyEvalData(d.env_id, d.agent_id, d.data.for_action)
to_eval[policy_id].append(item)
            # Step after adding initial obs. This sets both env and agent step counts to 0.
new_episode.step()
def end_episode(
self, env_id: EnvID, episode_or_exception: Union[EpisodeV2, Exception]
):
"""Clena up an episode that has finished.
Args:
env_id: Env ID.
episode_or_exception: Instance of an episode if it finished successfully.
Otherwise, the exception that was thrown,
"""
# Signal the end of an episode, either successfully with an Episode or
# unsuccessfully with an Exception.
self._callbacks.on_episode_end(
worker=self._worker,
base_env=self._base_env,
policies=self._worker.policy_map,
episode=episode_or_exception,
env_index=env_id,
)
if isinstance(episode_or_exception, EpisodeV2):
episode = episode_or_exception
if episode.total_agent_steps == 0:
                # Zero agent steps means that all observations were empty
                # throughout the episode (i.e., there was never an agent in the env).
msg = (
f"Data from episode {episode.episode_id} does not show any agent "
f"interactions. Hint: Make sure for at least one timestep in the "
f"episode, env.step() returns non-empty values."
)
raise ValueError(msg)
# Clean up the episode and batch_builder for this env id.
del self._active_episodes[env_id]
def _try_build_truncated_episode_multi_agent_batch(
self, batch_builder: _PolicyCollectorGroup, episode: EpisodeV2
) -> Union[None, SampleBatch, MultiAgentBatch]:
# Measure batch size in env-steps.
if self._count_steps_by == "env_steps":
built_steps = batch_builder.env_steps
ongoing_steps = episode.active_env_steps
# Measure batch-size in agent-steps.
else:
built_steps = batch_builder.agent_steps
ongoing_steps = episode.active_agent_steps
# Reached the fragment-len -> We should build an MA-Batch.
if built_steps + ongoing_steps >= self._rollout_fragment_length:
if self._count_steps_by != "agent_steps":
assert built_steps + ongoing_steps == self._rollout_fragment_length, (
f"built_steps ({built_steps}) + ongoing_steps ({ongoing_steps}) != "
f"rollout_fragment_length ({self._rollout_fragment_length})."
)
            # If we reached the fragment-len only because of the still-ongoing
            # episode -> postprocess that episode first.
if built_steps < self._rollout_fragment_length:
episode.postprocess_episode(batch_builder=batch_builder, is_done=False)
# If builder has collected some data,
# build the MA-batch and add to return values.
if batch_builder.agent_steps > 0:
return _build_multi_agent_batch(
episode.episode_id,
batch_builder,
self._large_batch_threshold,
self._multiple_episodes_in_batch,
)
            # Batch builder contains no agent steps:
            # We have reached the rollout-fragment length w/o any agent
            # steps! Warn that the environment may never request any
            # actions from any agents.
elif log_once("no_agent_steps"):
logger.warning(
"Your environment seems to be stepping w/o ever "
"emitting agent observations (agents are never "
"requested to act)!"
)
return None
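    # Worked example of the step counting above (illustrative numbers): with
    # `rollout_fragment_length=200` and `count_steps_by="env_steps"`, a 2-agent
    # env adds 1 to the env-step counts per `poll()` cycle, so a batch is built
    # after 200 env steps (roughly 400 agent rows). With
    # `count_steps_by="agent_steps"`, the same env adds ~2 per cycle and reaches
    # the 200 threshold after roughly 100 env steps.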
def _do_policy_eval(
self,
to_eval: Dict[PolicyID, List[_PolicyEvalData]],
) -> Dict[PolicyID, PolicyOutputType]:
"""Call compute_actions on collected episode data to get next action.
Args:
to_eval: Mapping of policy IDs to lists of _PolicyEvalData objects
(items in these lists will be the batch's items for the model
forward pass).
Returns:
Dict mapping PolicyIDs to compute_actions_from_input_dict() outputs.
"""
policies = self._worker.policy_map
# In case policy map has changed, try to find the new policy that
# should handle all these per-agent eval data.
# Throws exception if these agents are mapped to multiple different
# policies now.
        def _try_find_policy_again(eval_data: List[_PolicyEvalData]):
policy_id = None
for d in eval_data:
episode = self._active_episodes[d.env_id]
# Force refresh policy mapping on the episode.
pid = episode.policy_for(d.agent_id, refresh=True)
if policy_id is not None and pid != policy_id:
raise ValueError(
"Policy map changed. The list of eval data that was handled "
f"by a same policy is now handled by policy {pid} "
"and {policy_id}. "
"Please don't do this in the middle of an episode."
)
policy_id = pid
return _get_or_raise(self._worker.policy_map, policy_id)
eval_results: Dict[PolicyID, TensorStructType] = {}
for policy_id, eval_data in to_eval.items():
# In case the policyID has been removed from this worker, we need to
# re-assign policy_id and re-lookup the Policy object to use.
try:
policy: Policy = _get_or_raise(policies, policy_id)
except ValueError:
# policy_mapping_fn from the worker may have already been
# changed (mapping fn not staying constant within one episode).
policy: Policy = _try_find_policy_again(eval_data)
input_dict = _batch_inference_sample_batches(
[d.sample_batch for d in eval_data]
)
eval_results[policy_id] = policy.compute_actions_from_input_dict(
input_dict,
timestep=policy.global_timestep,
episodes=[self._active_episodes[t.env_id] for t in eval_data],
)
return eval_results
def _process_policy_eval_results(
self,
to_eval: Dict[PolicyID, List[_PolicyEvalData]],
eval_results: Dict[PolicyID, PolicyOutputType],
off_policy_actions: MultiEnvDict,
):
"""Process the output of policy neural network evaluation.
Records policy evaluation results into agent connectors and
returns replies to send back to agents in the env.
Args:
to_eval: Mapping of policy IDs to lists of _PolicyEvalData objects.
            eval_results: Mapping of policy IDs to policy output tuples of
                (actions, rnn-out states, extra-action-fetches dict).
off_policy_actions: Doubly keyed dict of env-ids -> agent ids ->
off-policy-action, returned by a `BaseEnv.poll()` call.
Returns:
Nested dict of env id -> agent id -> actions to be sent to
Env (np.ndarrays).
"""
actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = defaultdict(dict)
for eval_data in to_eval.values():
for d in eval_data:
actions_to_send[d.env_id] = {} # at minimum send empty dict
# types: PolicyID, List[_PolicyEvalData]
for policy_id, eval_data in to_eval.items():
actions: TensorStructType = eval_results[policy_id][0]
actions = convert_to_numpy(actions)
rnn_out: StateBatches = eval_results[policy_id][1]
extra_action_out: dict = eval_results[policy_id][2]
# In case actions is a list (representing the 0th dim of a batch of
# primitive actions), try converting it first.
if isinstance(actions, list):
actions = np.array(actions)
# Split action-component batches into single action rows.
actions: List[EnvActionType] = unbatch(actions)
policy: Policy = _get_or_raise(self._worker.policy_map, policy_id)
assert (
policy.agent_connectors and policy.action_connectors
), "EnvRunnerV2 requires action connectors to work."
# types: int, EnvActionType
for i, action in enumerate(actions):
env_id: int = eval_data[i].env_id
agent_id: AgentID = eval_data[i].agent_id
rnn_states: List[StateBatches] = [c[i] for c in rnn_out]
fetches: Dict = {k: v[i] for k, v in extra_action_out.items()}
                # Post-process the policy output by running it through action connectors.
ac_data = ActionConnectorDataType(
env_id, agent_id, (action, rnn_states, fetches)
)
action_to_send, rnn_states, fetches = policy.action_connectors(
ac_data
).output
action_to_buffer = (
action_to_send
if env_id not in off_policy_actions
or agent_id not in off_policy_actions[env_id]
else off_policy_actions[env_id][agent_id]
)
# Notify agent connectors with this new policy output.
# Necessary for state buffering agent connectors, for example.
ac_data: AgentConnectorDataType = ActionConnectorDataType(
env_id, agent_id, (action_to_buffer, rnn_states, fetches)
)
policy.agent_connectors.on_policy_output(ac_data)
assert agent_id not in actions_to_send[env_id]
actions_to_send[env_id][agent_id] = action_to_send
return actions_to_send
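    # Shape sketch for a single agent row `i` in the loop above (assumed, for
    # illustration): `eval_results[policy_id]` is an (actions, rnn_out,
    # extra_action_out) tuple, and row `i` is re-packed as
    #     ActionConnectorDataType(env_id, agent_id,
    #                             (action, [c[i] for c in rnn_out],
    #                              {k: v[i] for k, v in extra_action_out.items()}))
    # The action connectors' output is what is sent back to the env, while the
    # copy buffered via `on_policy_output()` may instead be the env-provided
    # off-policy action.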
def _maybe_render(self):
"""Visualize environment."""
# Check if we should render.
if not self._render or not self._simple_image_viewer:
return
t5 = time.time()
# Render can either return an RGB image (uint8 [w x h x 3] numpy
# array) or take care of rendering itself (returning True).
rendered = self._base_env.try_render()
# Rendering returned an image -> Display it in a SimpleImageViewer.
if isinstance(rendered, np.ndarray) and len(rendered.shape) == 3:
self._simple_image_viewer.imshow(rendered)
elif rendered not in [True, False, None]:
raise ValueError(
f"The env's ({self._base_env}) `try_render()` method returned an"
" unsupported value! Make sure you either return a "
"uint8/w x h x 3 (RGB) image or handle rendering in a "
"window and then return `True`."
)
self._perf_stats.env_render_time += time.time() - t5
def _fetch_atari_metrics(base_env: BaseEnv) -> List[RolloutMetrics]:
"""Atari games have multiple logical episodes, one per life.
However, for metrics reporting we count full episodes, all lives included.
"""
sub_environments = base_env.get_sub_environments()
if not sub_environments:
return None
atari_out = []
for sub_env in sub_environments:
monitor = get_wrapper_by_cls(sub_env, MonitorEnv)
if not monitor:
return None
for eps_rew, eps_len in monitor.next_episode_results():
atari_out.append(RolloutMetrics(eps_len, eps_rew))
return atari_out
def _get_or_raise(
mapping: Dict[PolicyID, Union[Policy, Preprocessor, Filter]], policy_id: PolicyID
) -> Union[Policy, Preprocessor, Filter]:
"""Returns an object under key `policy_id` in `mapping`.
Args:
mapping (Dict[PolicyID, Union[Policy, Preprocessor, Filter]]): The
mapping dict from policy id (str) to actual object (Policy,
Preprocessor, etc.).
policy_id: The policy ID to lookup.
Returns:
Union[Policy, Preprocessor, Filter]: The found object.
Raises:
ValueError: If `policy_id` cannot be found in `mapping`.
"""
if policy_id not in mapping:
raise ValueError(
"Could not find policy for agent: PolicyID `{}` not found "
"in policy map, whose keys are `{}`.".format(policy_id, mapping.keys())
)
return mapping[policy_id]