ray/rllib/evaluation/episode_v2.py

347 lines
13 KiB
Python

import random
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.evaluation.collectors.simple_list_collector import (
_PolicyCollector,
_PolicyCollectorGroup,
)
from ray.rllib.evaluation.collectors.agent_collector import AgentCollector
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.typing import AgentID, EnvID, PolicyID, TensorType
if TYPE_CHECKING:
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.evaluation.rollout_worker import RolloutWorker
@DeveloperAPI
class EpisodeV2:
"""Tracks the current state of a (possibly multi-agent) episode."""
def __init__(
self,
env_id: EnvID,
policies: PolicyMap,
policy_mapping_fn: Callable[[AgentID, "EpisodeV2", "RolloutWorker"], PolicyID],
*,
worker: Optional["RolloutWorker"] = None,
callbacks: Optional["DefaultCallbacks"] = None,
):
"""Initializes an Episode instance.
Args:
env_id: The environment's ID in which this episode runs.
policies: The PolicyMap object (mapping PolicyIDs to Policy
objects) to use for determining, which policy is used for
which agent.
policy_mapping_fn: The mapping function mapping AgentIDs to
PolicyIDs.
worker: The RolloutWorker instance, in which this episode runs.
"""
# Unique id identifying this trajectory.
self.episode_id: int = random.randrange(2e9)
# ID of the environment this episode is tracking.
self.env_id = env_id
# Summed reward across all agents in this episode.
self.total_reward: float = 0.0
# Active (uncollected) # of env steps taken by this episode.
# Start from -1. After add_init_obs(), we will be at 0 step.
self.active_env_steps: int = -1
# Total # of env steps taken by this episode.
# Start from -1, After add_init_obs(), we will be at 0 step.
self.total_env_steps: int = -1
# Active (uncollected) agent steps.
self.active_agent_steps: int = 0
# Total # of steps take by all agents in this env.
self.total_agent_steps: int = 0
# Dict for user to add custom metrics.
# TODO(jungong) : we should probably unify custom_metrics, user_data,
# and hist_data into a single data container for user to track per-step
# metrics and states.
self.custom_metrics: Dict[str, float] = {}
# Temporary storage. E.g. storing data in between two custom
# callbacks referring to the same episode.
self.user_data: Dict[str, Any] = {}
# Dict mapping str keys to List[float] for storage of
# per-timestep float data throughout the episode.
self.hist_data: Dict[str, List[float]] = {}
self.media: Dict[str, Any] = {}
self.worker = worker
self.callbacks = callbacks
self.policy_map: PolicyMap = policies
self.policy_mapping_fn: Callable[
[AgentID, "EpisodeV2", "RolloutWorker"], PolicyID
] = policy_mapping_fn
# Per-agent data collectors.
self._agent_to_policy: Dict[AgentID, PolicyID] = {}
self._agent_collectors: Dict[AgentID, AgentCollector] = {}
self._next_agent_index: int = 0
self._agent_to_index: Dict[AgentID, int] = {}
# Summed rewards broken down by agent.
self.agent_rewards: Dict[Tuple[AgentID, PolicyID], float] = defaultdict(float)
self._agent_reward_history: Dict[AgentID, List[int]] = defaultdict(list)
self._has_init_obs: Dict[AgentID, bool] = {}
self._last_dones: Dict[AgentID, bool] = {}
# Keep last info dict around, in case an environment tries to signal
# us something.
self._last_infos: Dict[AgentID, Dict] = {}
@DeveloperAPI
def policy_for(
self, agent_id: AgentID = _DUMMY_AGENT_ID, refresh: bool = False
) -> PolicyID:
"""Returns and stores the policy ID for the specified agent.
If the agent is new, the policy mapping fn will be called to bind the
agent to a policy for the duration of the entire episode (even if the
policy_mapping_fn is changed in the meantime!).
Args:
agent_id: The agent ID to lookup the policy ID for.
Returns:
The policy ID for the specified agent.
"""
# Perform a new policy_mapping_fn lookup and bind AgentID for the
# duration of this episode to the returned PolicyID.
if agent_id not in self._agent_to_policy or refresh:
policy_id = self._agent_to_policy[agent_id] = self.policy_mapping_fn(
agent_id, self, worker=self.worker
)
# Use already determined PolicyID.
else:
policy_id = self._agent_to_policy[agent_id]
# PolicyID not found in policy map -> Error.
if policy_id not in self.policy_map:
raise KeyError(
"policy_mapping_fn returned invalid policy id " f"'{policy_id}'!"
)
return policy_id
@DeveloperAPI
def get_agents(self) -> List[AgentID]:
"""Returns list of agent IDs that have appeared in this episode.
Returns:
The list of all agent IDs that have appeared so far in this
episode.
"""
return list(self._agent_to_index.keys())
def agent_index(self, agent_id: AgentID) -> int:
"""Get the index of an agent among its environment.
A new index will be created if an agent is seen for the first time.
Args:
agent_id: ID of an agent.
Returns:
The index of this agent.
"""
if agent_id not in self._agent_to_index:
self._agent_to_index[agent_id] = self._next_agent_index
self._next_agent_index += 1
return self._agent_to_index[agent_id]
def step(self) -> None:
"""Advance the episode forward by one step."""
self.active_env_steps += 1
self.total_env_steps += 1
def add_init_obs(
self,
agent_id: AgentID,
t: int,
init_obs: TensorType,
) -> None:
"""Add initial env obs at the start of a new episode
Args:
agent_id: Agent ID.
t: timestamp.
init_obs: Initial observations.
"""
policy = self.policy_map[self.policy_for(agent_id)]
# Add initial obs to Trajectory.
assert agent_id not in self._agent_collectors
self._agent_collectors[agent_id] = AgentCollector(
policy.view_requirements,
max_seq_len=policy.config["model"]["max_seq_len"],
disable_action_flattening=policy.config.get(
"_disable_action_flattening", False
),
is_policy_recurrent=policy.is_recurrent(),
)
self._agent_collectors[agent_id].add_init_obs(
episode_id=self.episode_id,
agent_index=self.agent_index(agent_id),
env_id=self.env_id,
t=t,
init_obs=init_obs,
)
self._has_init_obs[agent_id] = True
def add_action_reward_done_next_obs(
self,
agent_id: AgentID,
values: Dict[str, TensorType],
) -> None:
"""Add action, reward, info, and next_obs as a new step.
Args:
agent_id: Agent ID.
values: Dict of action, reward, info, and next_obs.
"""
# Make sure, agent already has some (at least init) data.
assert agent_id in self._agent_collectors
self.active_agent_steps += 1
self.total_agent_steps += 1
# Include the current agent id for multi-agent algorithms.
if agent_id != _DUMMY_AGENT_ID:
values["agent_id"] = agent_id
# Add action/reward/next-obs (and other data) to Trajectory.
self._agent_collectors[agent_id].add_action_reward_next_obs(values)
# Keep track of agent reward history.
reward = values[SampleBatch.REWARDS]
self.total_reward += reward
self.agent_rewards[(agent_id, self.policy_for(agent_id))] += reward
self._agent_reward_history[agent_id].append(reward)
# Keep track of last done info for agent.
if SampleBatch.DONES in values:
self._last_dones[agent_id] = values[SampleBatch.DONES]
# Keep track of last info dict if available.
if SampleBatch.INFOS in values:
self.set_last_info(agent_id, values[SampleBatch.INFOS])
def postprocess_episode(
self,
batch_builder: _PolicyCollectorGroup,
is_done: bool = False,
check_dones: bool = False,
) -> None:
"""Build and return currently collected training samples by policies.
Clear agent collector states if this episode is done.
Args:
batch_builder: _PolicyCollectorGroup for saving the collected per-agent
sample batches.
is_done: If this episode is done.
check_dones: Whether to make sure per-agent trajectories are actually done.
"""
# TODO: (sven) Once we implement multi-agent communication channels,
# we have to resolve the restriction of only sending other agent
# batches from the same policy to the postprocess methods.
# Build SampleBatches for the given episode.
pre_batches = {}
for agent_id, collector in self._agent_collectors.items():
# Build only if there is data and agent is part of given episode.
if collector.agent_steps == 0:
continue
pid = self.policy_for(agent_id)
policy = self.policy_map[pid]
pre_batch = collector.build_for_training(policy.view_requirements)
pre_batches[agent_id] = (pid, policy, pre_batch)
for agent_id, (pid, policy, pre_batch) in pre_batches.items():
# Entire episode is said to be done.
# Error if no DONE at end of this agent's trajectory.
if is_done and check_dones and not pre_batch[SampleBatch.DONES][-1]:
raise ValueError(
"Episode {} terminated for all agents, but we still "
"don't have a last observation for agent {} (policy "
"{}). ".format(self.episode_id, agent_id, self.policy_for(agent_id))
+ "Please ensure that you include the last observations "
"of all live agents when setting done[__all__] to "
"True. Alternatively, set no_done_at_end=True to "
"allow this."
)
# Skip a trajectory's postprocessing (and thus using it for training),
# if its agent's info exists and contains the training_enabled=False
# setting (used by our PolicyClients).
if not self._last_infos.get(agent_id, {}).get("training_enabled", True):
continue
if (
any(pre_batch[SampleBatch.DONES][:-1])
or len(set(pre_batch[SampleBatch.EPS_ID])) > 1
):
raise ValueError(
"Batches sent to postprocessing must only contain steps "
"from a single trajectory.",
pre_batch,
)
if len(pre_batches) > 1:
other_batches = pre_batches.copy()
del other_batches[agent_id]
else:
other_batches = {}
# Call the Policy's Exploration's postprocess method.
post_batch = pre_batch
if getattr(policy, "exploration", None) is not None:
policy.exploration.postprocess_trajectory(
policy, post_batch, policy.get_session()
)
post_batch.set_get_interceptor(None)
post_batch = policy.postprocess_trajectory(post_batch, other_batches, self)
from ray.rllib.evaluation.rollout_worker import get_global_worker
self.callbacks.on_postprocess_trajectory(
worker=get_global_worker(),
episode=self,
agent_id=agent_id,
policy_id=pid,
policies=self.policy_map,
postprocessed_batch=post_batch,
original_batches=pre_batches,
)
# Append post_batch for return.
if pid not in batch_builder.policy_collectors:
batch_builder.policy_collectors[pid] = _PolicyCollector(policy)
batch_builder.policy_collectors[pid].add_postprocessed_batch_for_training(
post_batch, policy.view_requirements
)
batch_builder.agent_steps += self.active_agent_steps
batch_builder.env_steps += self.active_env_steps
# AgentCollector cleared.
self.active_agent_steps = 0
self.active_env_steps = 0
def has_init_obs(self, agent_id: AgentID) -> bool:
return agent_id in self._has_init_obs and self._has_init_obs[agent_id]
def is_done(self, agent_id: AgentID) -> bool:
return self._last_dones.get(agent_id, False)
def set_last_info(self, agent_id: AgentID, info: Dict):
self._last_infos[agent_id] = info
@property
def length(self):
return self.total_env_steps