import logging
from typing import Dict, Optional, TYPE_CHECKING

from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.per_policy_sample_collector import \
    _PerPolicySampleCollector
from ray.rllib.evaluation.sample_collector import _SampleCollector
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import override
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, \
    TensorType
from ray.util.debug import log_once

if TYPE_CHECKING:
    from ray.rllib.agents.callbacks import DefaultCallbacks

logger = logging.getLogger(__name__)


class _MultiAgentSampleCollector(_SampleCollector):
    """Builds SampleBatches for each policy (and agent) in a multi-agent env.

    Note: This is an experimental class, only used when
    `config._use_trajectory_view_api` = True.
    Once `_use_trajectory_view_api` becomes the default in configs, this
    class will replace the `SampleBatchBuilder` class.

    Input data is collected in central per-policy buffers, which
    efficiently pre-allocate memory (over n timesteps) and re-use the same
    memory even for succeeding agents and episodes.
    Input dicts for action computations, SampleBatches for postprocessing,
    and train_batch dicts are - if possible - created from the central
    per-policy buffers via views to avoid copying data.
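
    Example (rough usage sketch; in practice, RLlib's sampler drives this
    collector, and `my_policy`, `my_callbacks`, `obs`, `a`, `r` and the id
    variables below are placeholders):
        >>> collector = _MultiAgentSampleCollector(
        ...     policy_map={"pol0": my_policy}, callbacks=my_callbacks)
        >>> # Episode start: register each agent's initial observation.
        >>> collector.add_init_obs(episode_id, "agent0", env_id, "pol0", obs)
        >>> # Build the batched input dict for computing the next actions.
        >>> input_dict = collector.get_inference_input_dict("pol0")
        >>> # After env.step(): add this agent's action/reward/next-obs.
        >>> collector.add_action_reward_next_obs(
        ...     episode_id, "agent0", env_id, "pol0", False,
        ...     {"actions": a, "rewards": r, "new_obs": obs,
        ...      "dones": False})
        >>> # Build one train batch over all policies and reset all buffers.
        >>> ma_batch = collector.get_multi_agent_batch_and_reset()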
    """

    def __init__(
            self,
            policy_map: Dict[PolicyID, Policy],
            callbacks: "DefaultCallbacks",
            # TODO: (sven) make `num_agents` flexibly grow in size.
            num_agents: int = 100,
            num_timesteps=None,
            time_major: Optional[bool] = False):
        """Initializes a _MultiAgentSampleCollector object.

        Args:
            policy_map (Dict[PolicyID, Policy]): Maps policy ids to policy
                instances.
            callbacks (DefaultCallbacks): RLlib callbacks (configured in the
                Trainer config dict). Used for the trajectory postprocessing
                event.
            num_agents (int): The max number of agent slots to pre-allocate
                in the buffer.
            num_timesteps (int): The max number of timesteps to pre-allocate
                in the buffer.
            time_major (Optional[bool]): Whether to preallocate buffers and
                collect samples in time-major fashion (TxBx...).
        """

        self.policy_map = policy_map
        self.callbacks = callbacks
        if num_agents == float("inf") or num_agents is None:
            num_agents = 1000
        self.num_agents = int(num_agents)

        # Collect SampleBatches per-policy in _PerPolicySampleCollectors.
        self.policy_sample_collectors = {}
        for pid, policy in policy_map.items():
            # Figure out max-shifts (before and after).
            view_reqs = policy.training_view_requirements
            max_shift_before = 0
            max_shift_after = 0
            for vr in view_reqs.values():
                shift = force_list(vr.shift)
                if max_shift_before > shift[0]:
                    max_shift_before = shift[0]
                if max_shift_after < shift[-1]:
                    max_shift_after = shift[-1]
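            # Example: a view requirement with shift=-1 (e.g. previous
            # actions) makes max_shift_before=-1, one with shift=1 (e.g. the
            # next observation) makes max_shift_after=1, so the per-policy
            # buffers below get one extra slot of padding on each side.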
            # Figure out num_timesteps and num_agents.
            kwargs = {"time_major": time_major}
            if policy.is_recurrent():
                kwargs["num_timesteps"] = \
                    policy.config["model"]["max_seq_len"]
                kwargs["time_major"] = True
            elif num_timesteps is not None:
                kwargs["num_timesteps"] = num_timesteps

            self.policy_sample_collectors[pid] = _PerPolicySampleCollector(
                num_agents=self.num_agents,
                shift_before=-max_shift_before,
                shift_after=max_shift_after,
                **kwargs)

        # Internal agent-to-policy map.
        self.agent_to_policy = {}
        # Number of "inference" steps taken in the environment, regardless
        # of the number of agents involved in each of these steps.
        self.count = 0

    @override(_SampleCollector)
    def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID,
                     env_id: EnvID, policy_id: PolicyID,
                     obs: TensorType) -> None:
        # Make sure our mappings are up to date.
        if agent_id not in self.agent_to_policy:
            self.agent_to_policy[agent_id] = policy_id
        else:
            assert self.agent_to_policy[agent_id] == policy_id

        # Add initial obs to Trajectory.
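        # chunk_num=0: start a fresh trajectory chunk for this agent in this
        # episode.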
        self.policy_sample_collectors[policy_id].add_init_obs(
            episode_id, agent_id, env_id, chunk_num=0, init_obs=obs)

    @override(_SampleCollector)
    def add_action_reward_next_obs(self, episode_id: EpisodeID,
                                   agent_id: AgentID, env_id: EnvID,
                                   policy_id: PolicyID, agent_done: bool,
                                   values: Dict[str, TensorType]) -> None:
        assert policy_id in self.policy_sample_collectors
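
        # Note: `values` maps column names (e.g. "actions", "rewards",
        # "new_obs", "dones") to this agent's data for the current step;
        # the exact set of keys depends on the policy's view requirements.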

        # Make sure our mappings are up to date.
        if agent_id not in self.agent_to_policy:
            self.agent_to_policy[agent_id] = policy_id
        else:
            assert self.agent_to_policy[agent_id] == policy_id

        # Include the current agent id for multi-agent algorithms.
        if agent_id != _DUMMY_AGENT_ID:
            values["agent_id"] = agent_id

        # Add action/reward/next-obs (and other data) to Trajectory.
        self.policy_sample_collectors[policy_id].add_action_reward_next_obs(
            episode_id, agent_id, env_id, agent_done, values)

    @override(_SampleCollector)
    def total_env_steps(self) -> int:
        return sum(a.timesteps_since_last_reset
                   for a in self.policy_sample_collectors.values())

    def total(self):
        # TODO: (sven) deprecate; use `self.total_env_steps`, instead.
        # Sampler is currently still using `total()`.
        return self.total_env_steps()

    @override(_SampleCollector)
    def get_inference_input_dict(self, policy_id: PolicyID) -> \
            Dict[str, TensorType]:
        policy = self.policy_map[policy_id]
        view_reqs = policy.model.inference_view_requirements
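        # The returned dict maps view-requirement columns (e.g. "obs" or
        # "prev_rewards"; the exact keys depend on the model) to batched
        # data for all agents that currently need an action, built as views
        # into the per-policy buffers where possible.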
        return self.policy_sample_collectors[
            policy_id].get_inference_input_dict(view_reqs)

    @override(_SampleCollector)
    def has_non_postprocessed_data(self) -> bool:
        return self.total_env_steps() > 0

    @override(_SampleCollector)
    def postprocess_trajectories_so_far(
            self, episode: Optional[MultiAgentEpisode] = None) -> None:
        # Loop through each per-policy collector and create a view (one
        # SampleBatch per agent) from its buffers for post-processing.
        all_agent_batches = {}
        for pid, rc in self.policy_sample_collectors.items():
            policy = self.policy_map[pid]
            view_reqs = policy.training_view_requirements
            agent_batches = rc.get_postprocessing_sample_batches(
                episode, view_reqs)

            for agent_key, batch in agent_batches.items():
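                # Collect the batches of all *other* agents (if any), so
                # that postprocessing (e.g. by a centralized critic) can
                # access their data as well.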
                other_batches = None
                if len(agent_batches) > 1:
                    other_batches = agent_batches.copy()
                    del other_batches[agent_key]

                agent_batches[agent_key] = policy.postprocess_trajectory(
                    batch, other_batches, episode)
                # Call the Policy's Exploration's postprocess method.
                if getattr(policy, "exploration", None) is not None:
                    agent_batches[
                        agent_key] = policy.exploration.postprocess_trajectory(
                            policy, agent_batches[agent_key],
                            getattr(policy, "_sess", None))

                # Add new columns' data to buffers.
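                # (New columns are those created by postprocessing, e.g.
                # "advantages" computed by GAE.)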
                for col in agent_batches[agent_key].new_columns:
                    data = agent_batches[agent_key].data[col]
                    rc._build_buffers({col: data[0]})
                    timesteps = data.shape[0]
                    rc.buffers[col][rc.shift_before:rc.shift_before +
                                    timesteps, rc.agent_key_to_slot[
                                        agent_key]] = data

            all_agent_batches.update(agent_batches)

        if log_once("after_post"):
            logger.info("Trajectory fragment after postprocess_trajectory():"
                        "\n\n{}\n".format(summarize(all_agent_batches)))

        # Call the `on_postprocess_trajectory` callback for each agent batch.
        from ray.rllib.evaluation.rollout_worker import get_global_worker
        for agent_key, batch in sorted(all_agent_batches.items()):
            self.callbacks.on_postprocess_trajectory(
                worker=get_global_worker(),
                episode=episode,
                agent_id=agent_key[0],
                policy_id=self.agent_to_policy[agent_key[0]],
                policies=self.policy_map,
                postprocessed_batch=batch,
                original_batches=None)  # TODO: (sven) do we really need this?

    @override(_SampleCollector)
    def check_missing_dones(self, episode_id: EpisodeID) -> None:
        for pid, rc in self.policy_sample_collectors.items():
            for agent_key in rc.agent_key_to_slot.keys():
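                # `agent_key` is a (agent_id, episode_id, chunk_num) tuple.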
                # Only check for the given episode and only for the last
                # chunk (all previous chunks for that agent in the episode
                # are non-terminal).
                if (agent_key[1] == episode_id
                        and rc.agent_key_to_chunk_num[agent_key[:2]] ==
                        agent_key[2]):
                    t = rc.agent_key_to_timestep[agent_key] - 1
                    b = rc.agent_key_to_slot[agent_key]
                    if not rc.buffers["dones"][t][b]:
                        raise ValueError(
                            "Episode {} terminated for all agents, but we "
                            "still don't have a last observation for "
                            "agent {} (policy {}). ".format(
                                episode_id, agent_key[0], pid)
                            + "Please ensure that you include the last "
                            "observations of all live agents when setting "
                            "'__all__' done to True. Alternatively, set "
                            "no_done_at_end=True to allow this.")

    @override(_SampleCollector)
    def get_multi_agent_batch_and_reset(self):
        self.postprocess_trajectories_so_far()
        policy_batches = {}
        for pid, rc in self.policy_sample_collectors.items():
            policy = self.policy_map[pid]
            view_reqs = policy.training_view_requirements
            policy_batches[pid] = rc.get_train_sample_batch_and_reset(
                view_reqs)
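
        # Note: `MultiAgentBatch.wrap_as_needed` returns a plain SampleBatch
        # if only the default policy produced data, otherwise a
        # MultiAgentBatch keyed by policy id.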

        ma_batch = MultiAgentBatch.wrap_as_needed(policy_batches, self.count)
        # Reset our across-all-agents env step count.
        self.count = 0
        return ma_batch