from abc import abstractmethod, ABCMeta
import logging
from typing import Dict, List, Optional, TYPE_CHECKING, Union

from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, \
    TensorType

if TYPE_CHECKING:
    from ray.rllib.agents.callbacks import DefaultCallbacks

logger = logging.getLogger(__name__)


# yapf: disable
# __sphinx_doc_begin__
class SampleCollector(metaclass=ABCMeta):
    """Collects samples for all policies and agents from a multi-agent env.

    This API is controlled by RolloutWorker objects to store all data
    generated by Environments and Policies/Models during rollout and
    postprocessing. Its purposes are to a) make data collection and
    SampleBatch/input_dict generation from this data faster, b) to unify
    the way we collect samples from environments and models (outputs),
    thereby allowing for possible user customizations, and c) to allow for
    more complex inputs fed into different policies (e.g. the multi-agent
    case with an inter-agent communication channel).
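
    Examples:
        A rough end-to-end sketch of the collection flow (assuming
        `collector` is an instance of a concrete SampleCollector subclass,
        `policy` is the Policy behind policy id "pol0", and a single agent
        with id 0 acts in `env`):
        >>> obs = env.reset()
        >>> collector.add_init_obs(episode, 0, "pol0", -1, obs)
        >>> while True:
        ...     input_dict = collector.get_inference_input_dict("pol0")
        ...     action = policy.compute_actions_from_input_dict(input_dict)
        ...     obs, r, done, info = env.step(action)
        ...     collector.add_action_reward_next_obs(
        ...         episode.episode_id, 0, 0, "pol0", done,
        ...         {"action": action, "obs": obs, "reward": r, "done": done})
        ...     collector.episode_step(episode)
        ...     if done:
        ...         break
        >>> batch = collector.postprocess_episode(episode, is_done=True,
        ...                                       build=True)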
"""
|
|
|
|
    def __init__(self,
                 policy_map: Dict[PolicyID, Policy],
                 clip_rewards: Union[bool, float],
                 callbacks: "DefaultCallbacks",
                 multiple_episodes_in_batch: bool = True,
                 rollout_fragment_length: int = 200,
                 count_steps_by: str = "env_steps"):
        """Initializes a SampleCollector instance.

        Args:
            policy_map (Dict[PolicyID, Policy]): Maps policy ids to policy
                instances.
            clip_rewards (Union[bool, float]): Whether to clip rewards before
                postprocessing (at +/-1.0) or the actual value to +/- clip.
            callbacks (DefaultCallbacks): RLlib callbacks.
            multiple_episodes_in_batch (bool): Whether it's allowed to pack
                multiple episodes into the same built batch.
            rollout_fragment_length (int): The (minimum) number of steps to
                collect before `try_build_truncated_episode_multi_agent_batch`
                returns a non-empty list of batches.
            count_steps_by (str): Either "env_steps" or "agent_steps".
                Determines whether `rollout_fragment_length` is counted in
                env-steps or in (individual) agent-steps.
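
        Examples:
            A rough construction sketch; `MySampleCollector` is a stand-in
            for any concrete subclass (this base class is abstract and
            cannot be instantiated directly):
            >>> collector = MySampleCollector(
            ...     policy_map={"pol0": policy},
            ...     clip_rewards=False,
            ...     callbacks=DefaultCallbacks(),
            ...     rollout_fragment_length=200,
            ...     count_steps_by="env_steps")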
"""
|
|
|
|
self.policy_map = policy_map
|
|
self.clip_rewards = clip_rewards
|
|
self.callbacks = callbacks
|
|
self.multiple_episodes_in_batch = multiple_episodes_in_batch
|
|
self.rollout_fragment_length = rollout_fragment_length
|
|
self.count_steps_by = count_steps_by
|
|
|
|
    @abstractmethod
    def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
                     policy_id: PolicyID, t: int,
                     init_obs: TensorType) -> None:
        """Adds an initial obs (after reset) to this collector.

        Since the very first observation in an environment is collected
        without additional data (no actions, no rewards) right after
        `env.reset()` is called, this method initializes a new trajectory
        for a given agent.
        `add_init_obs()` has to be called first for each agent/episode-ID
        combination. After this, only `add_action_reward_next_obs()` must be
        called for that same agent/episode pair.

        Args:
            episode (MultiAgentEpisode): The MultiAgentEpisode, for which we
                are adding an Agent's initial observation.
            agent_id (AgentID): Unique id for the agent we are adding
                values for.
            policy_id (PolicyID): Unique id for the policy controlling the
                agent.
            t (int): The time step (episode length - 1). The initial obs has
                t=-1(!), then an action/reward/next-obs at t=0, etc.
            init_obs (TensorType): Initial observation (after `env.reset()`).

        Examples:
            >>> obs = env.reset()
            >>> collector.add_init_obs(my_episode, 0, "pol0", -1, obs)
            >>> obs, r, done, info = env.step(action)
            >>> collector.add_action_reward_next_obs(
            ...     12345, 0, 0, "pol0", False,
            ...     {"action": action, "obs": obs, "reward": r, "done": done})
        """
        raise NotImplementedError

    @abstractmethod
    def add_action_reward_next_obs(self, episode_id: EpisodeID,
                                   agent_id: AgentID, env_id: EnvID,
                                   policy_id: PolicyID, agent_done: bool,
                                   values: Dict[str, TensorType]) -> None:
        """Adds the given dictionary (row) of values to this collector.

        The incoming data (`values`) must include action, reward, done, and
        next_obs information and may include any other information.
        For the initial observation (after `env.reset()`) of the given agent/
        episode-ID combination, `add_init_obs()` must be called instead.

        Args:
            episode_id (EpisodeID): Unique id for the episode we are adding
                values for.
            agent_id (AgentID): Unique id for the agent we are adding
                values for.
            env_id (EnvID): The environment index (in a vectorized setup).
            policy_id (PolicyID): Unique id for the policy controlling the
                agent.
            agent_done (bool): Whether the given agent is done with its
                trajectory (the multi-agent episode may still be ongoing).
            values (Dict[str, TensorType]): Row of values to add for this
                agent. This row must contain the keys SampleBatch.ACTIONS,
                REWARDS, NEXT_OBS, and DONES.

        Examples:
            >>> obs = env.reset()
            >>> collector.add_init_obs(my_episode, 0, "pol0", -1, obs)
            >>> obs, r, done, info = env.step(action)
            >>> collector.add_action_reward_next_obs(
            ...     12345, 0, 0, "pol0", False,
            ...     {"action": action, "obs": obs, "reward": r, "done": done})
        """
        raise NotImplementedError

    @abstractmethod
    def episode_step(self, episode: MultiAgentEpisode) -> None:
        """Increases the episode step counter (across all agents) by one.

        Args:
            episode (MultiAgentEpisode): Episode we are stepping through.
                Useful for handling the step count because this method is
                called only once per step across all agents inside this
                episode.
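
        Examples:
            A rough usage sketch (assuming `collector` is an instance of a
            concrete SampleCollector subclass and `episode` is the ongoing
            MultiAgentEpisode); call this once per environment step:
            >>> collector.episode_step(episode)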
"""
|
|
raise NotImplementedError
|
|
|
|
    @abstractmethod
    def total_env_steps(self) -> int:
        """Returns the total number of env-steps taken so far.

        A single step in an N-agent multi-agent environment counts as only 1
        for this metric. The returned count covers everything that has not
        been built yet (and returned as MultiAgentBatches by the
        `try_build_truncated_episode_multi_agent_batch` or
        `postprocess_episode(build=True)` methods). After such a build, this
        counter is reset to 0.

        Returns:
            int: The number of env-steps taken in total in the environment(s)
                so far.
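
        Examples:
            A rough sketch (the concrete counts are illustrative only),
            showing the reset-to-0 behavior after a successful build:
            >>> collector.total_env_steps()
            200
            >>> batches = collector.try_build_truncated_episode_multi_agent_batch()
            >>> collector.total_env_steps()
            0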
"""
|
|
raise NotImplementedError
|
|
|
|
    @abstractmethod
    def total_agent_steps(self) -> int:
        """Returns the total number of (individual) agent-steps taken so far.

        A single step in an N-agent multi-agent environment counts as N for
        this metric. If fewer than N agents have stepped (because some agents
        were not required to send actions), the count will be increased by
        less than N. The returned count covers everything that has not been
        built yet (and returned as MultiAgentBatches by the
        `try_build_truncated_episode_multi_agent_batch` or
        `postprocess_episode(build=True)` methods). After such a build, this
        counter is reset to 0.

        Returns:
            int: The number of (individual) agent-steps taken in total in the
                environment(s) so far.
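
        Examples:
            A rough sketch of the difference to `total_env_steps()`
            (illustrative counts; assumes a 2-agent env in which both
            agents have acted on every step so far):
            >>> collector.total_env_steps()
            10
            >>> collector.total_agent_steps()
            20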
"""
|
|
raise NotImplementedError
|
|
|
|
    @abstractmethod
    def get_inference_input_dict(self, policy_id: PolicyID) -> \
            Dict[str, TensorType]:
        """Returns an input_dict for an (inference) forward pass from our data.

        The input_dict can then be used for action computations inside a
        Policy via `Policy.compute_actions_from_input_dict()`.

        Args:
            policy_id (PolicyID): The Policy ID to get the input dict for.

        Returns:
            Dict[str, TensorType]: The input_dict to be passed into the ModelV2
                for inference/training.

        Examples:
            >>> obs, r, done, info = env.step(action)
            >>> collector.add_action_reward_next_obs(
            ...     12345, 0, 0, "pol0", False,
            ...     {"action": action, "obs": obs, "reward": r, "done": done})
            >>> input_dict = collector.get_inference_input_dict("pol0")
            >>> action = policy.compute_actions_from_input_dict(input_dict)
            >>> # repeat
        """
        raise NotImplementedError

    @abstractmethod
    def postprocess_episode(self,
                            episode: MultiAgentEpisode,
                            is_done: bool = False,
                            check_dones: bool = False,
                            build: bool = False) -> Optional[MultiAgentBatch]:
        """Postprocesses all agents' trajectories in a given episode.

        Generates (single-trajectory) SampleBatches for all Policies/Agents
        and calls `Policy.postprocess_trajectory` on each of these.
        Postprocessing may happen in-place, meaning any changes to the viewed
        data columns are directly reflected inside this collector's buffers.
        Also makes sure that additional (newly created) data columns are
        correctly added to the buffers.

        Args:
            episode (MultiAgentEpisode): The Episode object for which
                to post-process data.
            is_done (bool): Whether the given episode is actually terminated
                (all agents are done OR we hit a hard horizon). If True, the
                episode will no longer be used/continued and we may need to
                recycle/erase it internally. If a soft horizon is hit, the
                episode will continue to be used and `is_done` should be set
                to False here.
            check_dones (bool): Whether we need to check that all agents'
                trajectories have dones=True at the end.
            build (bool): Whether to build a MultiAgentBatch from the given
                episode (and only that episode!) and return that
                MultiAgentBatch. Used for batch_mode=`complete_episodes`.

        Returns:
            Optional[MultiAgentBatch]: If `build` is True, the
                SampleBatch or MultiAgentBatch built from `episode` (either
                just from that episode or from the `_PolicyCollectorGroup`
                in the `episode.batch_builder` property).
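
        Examples:
            A rough usage sketch (assuming `collector` is an instance of a
            concrete SampleCollector subclass and `episode` has just
            terminated), e.g. with batch_mode=`complete_episodes`:
            >>> ma_batch = collector.postprocess_episode(
            ...     episode, is_done=True, check_dones=True, build=True)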
"""
|
|
raise NotImplementedError
|
|
|
|
    @abstractmethod
    def try_build_truncated_episode_multi_agent_batch(self) -> \
            List[Union[MultiAgentBatch, SampleBatch]]:
        """Tries to build an MA-batch, if `rollout_fragment_length` is reached.

        Any unprocessed data will first be postprocessed with a policy
        postprocessor.
        This is usually called to collect samples for policy training.
        If not enough data has been collected yet (`rollout_fragment_length`),
        returns an empty list.

        Returns:
            List[Union[MultiAgentBatch, SampleBatch]]: Returns a (possibly
                empty) list of MultiAgentBatches (containing the accumulated
                SampleBatches for each policy, or a simple SampleBatch if only
                one policy). The list will be empty if
                `self.rollout_fragment_length` has not been reached yet.
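
        Examples:
            A rough usage sketch (assuming `collector` is an instance of a
            concrete SampleCollector subclass); an empty list means
            `rollout_fragment_length` has not been reached yet:
            >>> batches = collector.try_build_truncated_episode_multi_agent_batch()
            >>> for ma_batch in batches:
            ...     pass  # e.g. send each batch on for training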
"""
|
|
raise NotImplementedError
|
|
# __sphinx_doc_end__
|