# ray/rllib/evaluation/multi_agent_sample_collector.py

import logging
from typing import Dict, Optional, TYPE_CHECKING

from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.per_policy_sample_collector import \
    _PerPolicySampleCollector
from ray.rllib.evaluation.sample_collector import _SampleCollector
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import override
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, \
    TensorType
from ray.util.debug import log_once

if TYPE_CHECKING:
    from ray.rllib.agents.callbacks import DefaultCallbacks

logger = logging.getLogger(__name__)


class _MultiAgentSampleCollector(_SampleCollector):
    """Builds SampleBatches for each policy (and agent) in a multi-agent env.

    Note: This is an experimental class, only used when
    `config._use_trajectory_view_api` = True.
    Once `_use_trajectory_view_api` becomes the default in configs,
    this class will deprecate the `SampleBatchBuilder` class.

    Input data is collected in central per-policy buffers, which
    efficiently pre-allocate memory (over n timesteps) and re-use the same
    memory even for succeeding agents and episodes.
    Input dicts for action computations, SampleBatches for postprocessing,
    and train-batch dicts are - if possible - created from the central
    per-policy buffers via views to avoid copying of data.
    """

    def __init__(
            self,
            policy_map: Dict[PolicyID, Policy],
            callbacks: "DefaultCallbacks",
            # TODO: (sven) make `num_agents` flexibly grow in size.
            num_agents: int = 100,
            num_timesteps: Optional[int] = None,
            time_major: Optional[bool] = False):
        """Initializes a _MultiAgentSampleCollector object.

        Args:
            policy_map (Dict[PolicyID, Policy]): Maps policy ids to policy
                instances.
            callbacks (DefaultCallbacks): RLlib callbacks (configured in the
                Trainer config dict). Used for the trajectory-postprocessing
                event.
            num_agents (int): The max number of agent slots to pre-allocate
                in the buffer.
            num_timesteps (Optional[int]): The max number of timesteps to
                pre-allocate in the buffer.
            time_major (Optional[bool]): Whether to pre-allocate buffers and
                collect samples in time-major fashion (TxBx...).
        """
        self.policy_map = policy_map
        self.callbacks = callbacks
        if num_agents == float("inf") or num_agents is None:
            num_agents = 1000
        self.num_agents = int(num_agents)

        # Collect SampleBatches per-policy in _PerPolicySampleCollectors.
        self.policy_sample_collectors = {}
        for pid, policy in policy_map.items():
            # Figure out max-shifts (before and after).
            view_reqs = policy.training_view_requirements
            max_shift_before = 0
            max_shift_after = 0
            for vr in view_reqs.values():
                shift = force_list(vr.shift)
                if max_shift_before > shift[0]:
                    max_shift_before = shift[0]
                if max_shift_after < shift[-1]:
                    max_shift_after = shift[-1]

            # Figure out this policy's num_timesteps and time-major settings.
            kwargs = {"time_major": time_major}
            if policy.is_recurrent():
                kwargs["num_timesteps"] = \
                    policy.config["model"]["max_seq_len"]
                kwargs["time_major"] = True
            elif num_timesteps is not None:
                kwargs["num_timesteps"] = num_timesteps

            self.policy_sample_collectors[pid] = _PerPolicySampleCollector(
                num_agents=self.num_agents,
                shift_before=-max_shift_before,
                shift_after=max_shift_after,
                **kwargs)

        # Internal agent-to-policy map.
        self.agent_to_policy = {}

        # Number of "inference" steps taken in the environment,
        # regardless of the number of agents involved in each of these steps.
        self.count = 0
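
    # Example construction (a minimal sketch; `MyTFPolicy`, `obs_space`,
    # `action_space`, and `config` are placeholders for whatever policy class
    # and spaces the caller actually uses):
    #
    #   from ray.rllib.agents.callbacks import DefaultCallbacks
    #   policy_map = {"pol0": MyTFPolicy(obs_space, action_space, config)}
    #   collector = _MultiAgentSampleCollector(
    #       policy_map, DefaultCallbacks(), num_agents=100, num_timesteps=200)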

    @override(_SampleCollector)
    def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID,
                     env_id: EnvID, policy_id: PolicyID,
                     obs: TensorType) -> None:
        # Make sure our mappings are up to date.
        if agent_id not in self.agent_to_policy:
            self.agent_to_policy[agent_id] = policy_id
        else:
            assert self.agent_to_policy[agent_id] == policy_id

        # Add initial obs to Trajectory.
        self.policy_sample_collectors[policy_id].add_init_obs(
            episode_id, agent_id, env_id, chunk_num=0, init_obs=obs)
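
    # Example call right after an env reset (a sketch; "agent0"/"pol0" and
    # `reset_obs` are hypothetical):
    #
    #   collector.add_init_obs(
    #       episode_id=episode.episode_id, agent_id="agent0", env_id=0,
    #       policy_id="pol0", obs=reset_obs["agent0"])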

    @override(_SampleCollector)
    def add_action_reward_next_obs(self, episode_id: EpisodeID,
                                   agent_id: AgentID, env_id: EnvID,
                                   policy_id: PolicyID, agent_done: bool,
                                   values: Dict[str, TensorType]) -> None:
        assert policy_id in self.policy_sample_collectors

        # Make sure our mappings are up to date.
        if agent_id not in self.agent_to_policy:
            self.agent_to_policy[agent_id] = policy_id
        else:
            assert self.agent_to_policy[agent_id] == policy_id

        # Include the current agent id for multi-agent algorithms.
        if agent_id != _DUMMY_AGENT_ID:
            values["agent_id"] = agent_id

        # Add action/reward/next-obs (and other data) to Trajectory.
        self.policy_sample_collectors[policy_id].add_action_reward_next_obs(
            episode_id, agent_id, env_id, agent_done, values)
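
    # Example call after one env step (a sketch; the exact keys required in
    # `values` depend on the policy's view requirements, but they typically
    # include the standard SampleBatch columns shown here):
    #
    #   collector.add_action_reward_next_obs(
    #       episode_id=episode.episode_id, agent_id="agent0", env_id=0,
    #       policy_id="pol0", agent_done=dones["agent0"],
    #       values=dict(t=timestep, actions=action, rewards=reward,
    #                   dones=dones["agent0"], new_obs=next_obs["agent0"]))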

    @override(_SampleCollector)
    def total_env_steps(self) -> int:
        return sum(a.timesteps_since_last_reset
                   for a in self.policy_sample_collectors.values())

    def total(self):
        # TODO: (sven) deprecate; use `self.total_env_steps`, instead.
        #  Sampler is currently still using `total()`.
        return self.total_env_steps()

    @override(_SampleCollector)
    def get_inference_input_dict(self, policy_id: PolicyID) -> \
            Dict[str, TensorType]:
        policy = self.policy_map[policy_id]
        view_reqs = policy.model.inference_view_requirements
        return self.policy_sample_collectors[
            policy_id].get_inference_input_dict(view_reqs)
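
    # Example (a sketch): the returned dict maps the model's inference view
    # requirement keys (e.g. "obs" and any RNN state columns) to batched data
    # covering all agents currently waiting for an action. How the dict is
    # fed into the policy (e.g. `compute_actions_from_input_dict` vs.
    # unpacking it into `compute_actions`) depends on the RLlib version in
    # use:
    #
    #   input_dict = collector.get_inference_input_dict("pol0")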

    @override(_SampleCollector)
    def has_non_postprocessed_data(self) -> bool:
        return self.total_env_steps() > 0

    @override(_SampleCollector)
    def postprocess_trajectories_so_far(
            self, episode: Optional[MultiAgentEpisode] = None) -> None:
        # Loop through each per-policy collector and create a view (for each
        # agent as a SampleBatch) from its buffers for post-processing.
        all_agent_batches = {}
        for pid, rc in self.policy_sample_collectors.items():
            policy = self.policy_map[pid]
            view_reqs = policy.training_view_requirements
            agent_batches = rc.get_postprocessing_sample_batches(
                episode, view_reqs)

            for agent_key, batch in agent_batches.items():
                other_batches = None
                if len(agent_batches) > 1:
                    other_batches = agent_batches.copy()
                    del other_batches[agent_key]

                agent_batches[agent_key] = policy.postprocess_trajectory(
                    batch, other_batches, episode)

                # Call the Policy's Exploration's postprocess method.
                if getattr(policy, "exploration", None) is not None:
                    agent_batches[agent_key] = \
                        policy.exploration.postprocess_trajectory(
                            policy, agent_batches[agent_key],
                            getattr(policy, "_sess", None))

                # Add new columns' data to buffers.
                for col in agent_batches[agent_key].new_columns:
                    data = agent_batches[agent_key].data[col]
                    rc._build_buffers({col: data[0]})
                    timesteps = data.shape[0]
                    rc.buffers[col][rc.shift_before:rc.shift_before +
                                    timesteps, rc.agent_key_to_slot[
                                        agent_key]] = data

            all_agent_batches.update(agent_batches)

        if log_once("after_post"):
            logger.info("Trajectory fragment after postprocess_trajectory():"
                        "\n\n{}\n".format(summarize(all_agent_batches)))

        # Call the `on_postprocess_trajectory` callback for each agent batch.
        from ray.rllib.evaluation.rollout_worker import get_global_worker
        for agent_key, batch in sorted(all_agent_batches.items()):
            self.callbacks.on_postprocess_trajectory(
                worker=get_global_worker(),
                episode=episode,
                agent_id=agent_key[0],
                policy_id=self.agent_to_policy[agent_key[0]],
                policies=self.policy_map,
                postprocessed_batch=batch,
                original_batches=None)  # TODO: (sven) do we really need this?
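
    # Timing note (illustrative): the sampler would typically call this, for
    # example, when an episode ends ("__all__" done) or a rollout fragment is
    # complete:
    #
    #   collector.postprocess_trajectories_so_far(episode)
    #
    # Any columns a policy adds in `postprocess_trajectory()` (e.g. an
    # advantages column in a PPO-style policy) are written back into the
    # per-policy buffers above, so later train batches can view them without
    # extra copies.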

    @override(_SampleCollector)
    def check_missing_dones(self, episode_id: EpisodeID) -> None:
        for pid, rc in self.policy_sample_collectors.items():
            for agent_key in rc.agent_key_to_slot.keys():
                # Only check for the given episode and only for the last
                # chunk (all previous chunks for that agent in the episode
                # are non-terminal).
                if (agent_key[1] == episode_id
                        and rc.agent_key_to_chunk_num[agent_key[:2]] ==
                        agent_key[2]):
                    t = rc.agent_key_to_timestep[agent_key] - 1
                    b = rc.agent_key_to_slot[agent_key]
                    if not rc.buffers["dones"][t][b]:
                        raise ValueError(
                            "Episode {} terminated for all agents, but we "
                            "still don't have a last observation for "
                            "agent {} (policy {}). ".format(
                                episode_id, agent_key[0], pid) +
                            "Please ensure that you include the last "
                            "observations of all live agents when setting "
                            "'__all__' done to True. Alternatively, set "
                            "no_done_at_end=True to allow this.")

    @override(_SampleCollector)
    def get_multi_agent_batch_and_reset(self):
        self.postprocess_trajectories_so_far()
        policy_batches = {}
        for pid, rc in self.policy_sample_collectors.items():
            policy = self.policy_map[pid]
            view_reqs = policy.training_view_requirements
            policy_batches[pid] = rc.get_train_sample_batch_and_reset(
                view_reqs)

        ma_batch = MultiAgentBatch.wrap_as_needed(policy_batches, self.count)
        # Reset our across-all-agents env step count.
        self.count = 0
        return ma_batch
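

# End-to-end usage sketch (commented out and purely illustrative; `env`,
# `policy_map`, `callbacks`, the single "pol0" policy id, and the helper
# `compute_actions_for(...)` are all hypothetical):
#
#   collector = _MultiAgentSampleCollector(policy_map, callbacks)
#   obs = env.reset()
#   for agent_id, o in obs.items():
#       collector.add_init_obs(0, agent_id, 0, "pol0", o)
#   done = False
#   while not done:
#       input_dict = collector.get_inference_input_dict("pol0")
#       actions = compute_actions_for(input_dict)
#       obs, rewards, dones, infos = env.step(actions)
#       for agent_id in actions:
#           collector.add_action_reward_next_obs(
#               0, agent_id, 0, "pol0", dones[agent_id],
#               dict(actions=actions[agent_id], rewards=rewards[agent_id],
#                    dones=dones[agent_id], new_obs=obs[agent_id]))
#       collector.count += 1  # The sampler counts env (inference) steps.
#       done = dones["__all__"]
#   collector.check_missing_dones(episode_id=0)
#   ma_batch = collector.get_multi_agent_batch_and_reset()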