mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
347 lines
13 KiB
Python
347 lines
13 KiB
Python
import random
|
|
from collections import defaultdict
|
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
|
|
|
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
|
|
from ray.rllib.evaluation.collectors.simple_list_collector import (
|
|
_PolicyCollector,
|
|
_PolicyCollectorGroup,
|
|
)
|
|
from ray.rllib.evaluation.collectors.agent_collector import AgentCollector
|
|
from ray.rllib.policy.policy_map import PolicyMap
|
|
from ray.rllib.policy.sample_batch import SampleBatch
|
|
from ray.rllib.utils.annotations import DeveloperAPI
|
|
from ray.rllib.utils.typing import AgentID, EnvID, PolicyID, TensorType
|
|
|
|
if TYPE_CHECKING:
|
|
from ray.rllib.algorithms.callbacks import DefaultCallbacks
|
|
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
|
|
|
|
|
@DeveloperAPI
|
|
class EpisodeV2:
|
|
"""Tracks the current state of a (possibly multi-agent) episode."""
|
|
|
|
def __init__(
|
|
self,
|
|
env_id: EnvID,
|
|
policies: PolicyMap,
|
|
policy_mapping_fn: Callable[[AgentID, "EpisodeV2", "RolloutWorker"], PolicyID],
|
|
*,
|
|
worker: Optional["RolloutWorker"] = None,
|
|
callbacks: Optional["DefaultCallbacks"] = None,
|
|
):
|
|
"""Initializes an Episode instance.
|
|
|
|
Args:
|
|
env_id: The environment's ID in which this episode runs.
|
|
policies: The PolicyMap object (mapping PolicyIDs to Policy
|
|
objects) to use for determining, which policy is used for
|
|
which agent.
|
|
policy_mapping_fn: The mapping function mapping AgentIDs to
|
|
PolicyIDs.
|
|
worker: The RolloutWorker instance, in which this episode runs.
|
|
"""
|
|
# Unique id identifying this trajectory.
|
|
self.episode_id: int = random.randrange(2e9)
|
|
# ID of the environment this episode is tracking.
|
|
self.env_id = env_id
|
|
# Summed reward across all agents in this episode.
|
|
self.total_reward: float = 0.0
|
|
# Active (uncollected) # of env steps taken by this episode.
|
|
# Start from -1. After add_init_obs(), we will be at 0 step.
|
|
self.active_env_steps: int = -1
|
|
# Total # of env steps taken by this episode.
|
|
# Start from -1, After add_init_obs(), we will be at 0 step.
|
|
self.total_env_steps: int = -1
|
|
# Active (uncollected) agent steps.
|
|
self.active_agent_steps: int = 0
|
|
# Total # of steps take by all agents in this env.
|
|
self.total_agent_steps: int = 0
|
|
# Dict for user to add custom metrics.
|
|
# TODO(jungong) : we should probably unify custom_metrics, user_data,
|
|
# and hist_data into a single data container for user to track per-step
|
|
# metrics and states.
|
|
self.custom_metrics: Dict[str, float] = {}
|
|
# Temporary storage. E.g. storing data in between two custom
|
|
# callbacks referring to the same episode.
|
|
self.user_data: Dict[str, Any] = {}
|
|
# Dict mapping str keys to List[float] for storage of
|
|
# per-timestep float data throughout the episode.
|
|
self.hist_data: Dict[str, List[float]] = {}
|
|
self.media: Dict[str, Any] = {}
|
|
|
|
self.worker = worker
|
|
self.callbacks = callbacks
|
|
|
|
self.policy_map: PolicyMap = policies
|
|
self.policy_mapping_fn: Callable[
|
|
[AgentID, "EpisodeV2", "RolloutWorker"], PolicyID
|
|
] = policy_mapping_fn
|
|
# Per-agent data collectors.
|
|
self._agent_to_policy: Dict[AgentID, PolicyID] = {}
|
|
self._agent_collectors: Dict[AgentID, AgentCollector] = {}
|
|
|
|
self._next_agent_index: int = 0
|
|
self._agent_to_index: Dict[AgentID, int] = {}
|
|
|
|
# Summed rewards broken down by agent.
|
|
self.agent_rewards: Dict[Tuple[AgentID, PolicyID], float] = defaultdict(float)
|
|
self._agent_reward_history: Dict[AgentID, List[int]] = defaultdict(list)
|
|
|
|
self._has_init_obs: Dict[AgentID, bool] = {}
|
|
self._last_dones: Dict[AgentID, bool] = {}
|
|
# Keep last info dict around, in case an environment tries to signal
|
|
# us something.
|
|
self._last_infos: Dict[AgentID, Dict] = {}
|
|
|
|
@DeveloperAPI
|
|
def policy_for(
|
|
self, agent_id: AgentID = _DUMMY_AGENT_ID, refresh: bool = False
|
|
) -> PolicyID:
|
|
"""Returns and stores the policy ID for the specified agent.
|
|
|
|
If the agent is new, the policy mapping fn will be called to bind the
|
|
agent to a policy for the duration of the entire episode (even if the
|
|
policy_mapping_fn is changed in the meantime!).
|
|
|
|
Args:
|
|
agent_id: The agent ID to lookup the policy ID for.
|
|
|
|
Returns:
|
|
The policy ID for the specified agent.
|
|
"""
|
|
|
|
# Perform a new policy_mapping_fn lookup and bind AgentID for the
|
|
# duration of this episode to the returned PolicyID.
|
|
if agent_id not in self._agent_to_policy or refresh:
|
|
policy_id = self._agent_to_policy[agent_id] = self.policy_mapping_fn(
|
|
agent_id, self, worker=self.worker
|
|
)
|
|
# Use already determined PolicyID.
|
|
else:
|
|
policy_id = self._agent_to_policy[agent_id]
|
|
|
|
# PolicyID not found in policy map -> Error.
|
|
if policy_id not in self.policy_map:
|
|
raise KeyError(
|
|
"policy_mapping_fn returned invalid policy id " f"'{policy_id}'!"
|
|
)
|
|
return policy_id
|
|
|
|
@DeveloperAPI
|
|
def get_agents(self) -> List[AgentID]:
|
|
"""Returns list of agent IDs that have appeared in this episode.
|
|
|
|
Returns:
|
|
The list of all agent IDs that have appeared so far in this
|
|
episode.
|
|
"""
|
|
return list(self._agent_to_index.keys())
|
|
|
|
def agent_index(self, agent_id: AgentID) -> int:
|
|
"""Get the index of an agent among its environment.
|
|
|
|
A new index will be created if an agent is seen for the first time.
|
|
|
|
Args:
|
|
agent_id: ID of an agent.
|
|
|
|
Returns:
|
|
The index of this agent.
|
|
"""
|
|
if agent_id not in self._agent_to_index:
|
|
self._agent_to_index[agent_id] = self._next_agent_index
|
|
self._next_agent_index += 1
|
|
return self._agent_to_index[agent_id]
|
|
|
|
def step(self) -> None:
|
|
"""Advance the episode forward by one step."""
|
|
self.active_env_steps += 1
|
|
self.total_env_steps += 1
|
|
|
|
def add_init_obs(
|
|
self,
|
|
agent_id: AgentID,
|
|
t: int,
|
|
init_obs: TensorType,
|
|
) -> None:
|
|
"""Add initial env obs at the start of a new episode
|
|
|
|
Args:
|
|
agent_id: Agent ID.
|
|
t: timestamp.
|
|
init_obs: Initial observations.
|
|
"""
|
|
policy = self.policy_map[self.policy_for(agent_id)]
|
|
|
|
# Add initial obs to Trajectory.
|
|
assert agent_id not in self._agent_collectors
|
|
self._agent_collectors[agent_id] = AgentCollector(
|
|
policy.view_requirements,
|
|
max_seq_len=policy.config["model"]["max_seq_len"],
|
|
disable_action_flattening=policy.config.get(
|
|
"_disable_action_flattening", False
|
|
),
|
|
is_policy_recurrent=policy.is_recurrent(),
|
|
)
|
|
self._agent_collectors[agent_id].add_init_obs(
|
|
episode_id=self.episode_id,
|
|
agent_index=self.agent_index(agent_id),
|
|
env_id=self.env_id,
|
|
t=t,
|
|
init_obs=init_obs,
|
|
)
|
|
|
|
self._has_init_obs[agent_id] = True
|
|
|
|
def add_action_reward_done_next_obs(
|
|
self,
|
|
agent_id: AgentID,
|
|
values: Dict[str, TensorType],
|
|
) -> None:
|
|
"""Add action, reward, info, and next_obs as a new step.
|
|
|
|
Args:
|
|
agent_id: Agent ID.
|
|
values: Dict of action, reward, info, and next_obs.
|
|
"""
|
|
# Make sure, agent already has some (at least init) data.
|
|
assert agent_id in self._agent_collectors
|
|
|
|
self.active_agent_steps += 1
|
|
self.total_agent_steps += 1
|
|
|
|
# Include the current agent id for multi-agent algorithms.
|
|
if agent_id != _DUMMY_AGENT_ID:
|
|
values["agent_id"] = agent_id
|
|
|
|
# Add action/reward/next-obs (and other data) to Trajectory.
|
|
self._agent_collectors[agent_id].add_action_reward_next_obs(values)
|
|
|
|
# Keep track of agent reward history.
|
|
reward = values[SampleBatch.REWARDS]
|
|
self.total_reward += reward
|
|
self.agent_rewards[(agent_id, self.policy_for(agent_id))] += reward
|
|
self._agent_reward_history[agent_id].append(reward)
|
|
|
|
# Keep track of last done info for agent.
|
|
if SampleBatch.DONES in values:
|
|
self._last_dones[agent_id] = values[SampleBatch.DONES]
|
|
# Keep track of last info dict if available.
|
|
if SampleBatch.INFOS in values:
|
|
self.set_last_info(agent_id, values[SampleBatch.INFOS])
|
|
|
|
def postprocess_episode(
|
|
self,
|
|
batch_builder: _PolicyCollectorGroup,
|
|
is_done: bool = False,
|
|
check_dones: bool = False,
|
|
) -> None:
|
|
"""Build and return currently collected training samples by policies.
|
|
|
|
Clear agent collector states if this episode is done.
|
|
|
|
Args:
|
|
batch_builder: _PolicyCollectorGroup for saving the collected per-agent
|
|
sample batches.
|
|
is_done: If this episode is done.
|
|
check_dones: Whether to make sure per-agent trajectories are actually done.
|
|
"""
|
|
# TODO: (sven) Once we implement multi-agent communication channels,
|
|
# we have to resolve the restriction of only sending other agent
|
|
# batches from the same policy to the postprocess methods.
|
|
# Build SampleBatches for the given episode.
|
|
pre_batches = {}
|
|
for agent_id, collector in self._agent_collectors.items():
|
|
# Build only if there is data and agent is part of given episode.
|
|
if collector.agent_steps == 0:
|
|
continue
|
|
pid = self.policy_for(agent_id)
|
|
policy = self.policy_map[pid]
|
|
pre_batch = collector.build_for_training(policy.view_requirements)
|
|
pre_batches[agent_id] = (pid, policy, pre_batch)
|
|
|
|
for agent_id, (pid, policy, pre_batch) in pre_batches.items():
|
|
# Entire episode is said to be done.
|
|
# Error if no DONE at end of this agent's trajectory.
|
|
if is_done and check_dones and not pre_batch[SampleBatch.DONES][-1]:
|
|
raise ValueError(
|
|
"Episode {} terminated for all agents, but we still "
|
|
"don't have a last observation for agent {} (policy "
|
|
"{}). ".format(self.episode_id, agent_id, self.policy_for(agent_id))
|
|
+ "Please ensure that you include the last observations "
|
|
"of all live agents when setting done[__all__] to "
|
|
"True. Alternatively, set no_done_at_end=True to "
|
|
"allow this."
|
|
)
|
|
|
|
# Skip a trajectory's postprocessing (and thus using it for training),
|
|
# if its agent's info exists and contains the training_enabled=False
|
|
# setting (used by our PolicyClients).
|
|
if not self._last_infos.get(agent_id, {}).get("training_enabled", True):
|
|
continue
|
|
|
|
if (
|
|
any(pre_batch[SampleBatch.DONES][:-1])
|
|
or len(set(pre_batch[SampleBatch.EPS_ID])) > 1
|
|
):
|
|
raise ValueError(
|
|
"Batches sent to postprocessing must only contain steps "
|
|
"from a single trajectory.",
|
|
pre_batch,
|
|
)
|
|
|
|
if len(pre_batches) > 1:
|
|
other_batches = pre_batches.copy()
|
|
del other_batches[agent_id]
|
|
else:
|
|
other_batches = {}
|
|
|
|
# Call the Policy's Exploration's postprocess method.
|
|
post_batch = pre_batch
|
|
if getattr(policy, "exploration", None) is not None:
|
|
policy.exploration.postprocess_trajectory(
|
|
policy, post_batch, policy.get_session()
|
|
)
|
|
post_batch.set_get_interceptor(None)
|
|
post_batch = policy.postprocess_trajectory(post_batch, other_batches, self)
|
|
|
|
from ray.rllib.evaluation.rollout_worker import get_global_worker
|
|
|
|
self.callbacks.on_postprocess_trajectory(
|
|
worker=get_global_worker(),
|
|
episode=self,
|
|
agent_id=agent_id,
|
|
policy_id=pid,
|
|
policies=self.policy_map,
|
|
postprocessed_batch=post_batch,
|
|
original_batches=pre_batches,
|
|
)
|
|
|
|
# Append post_batch for return.
|
|
if pid not in batch_builder.policy_collectors:
|
|
batch_builder.policy_collectors[pid] = _PolicyCollector(policy)
|
|
batch_builder.policy_collectors[pid].add_postprocessed_batch_for_training(
|
|
post_batch, policy.view_requirements
|
|
)
|
|
|
|
batch_builder.agent_steps += self.active_agent_steps
|
|
batch_builder.env_steps += self.active_env_steps
|
|
|
|
# AgentCollector cleared.
|
|
self.active_agent_steps = 0
|
|
self.active_env_steps = 0
|
|
|
|
def has_init_obs(self, agent_id: AgentID) -> bool:
|
|
return agent_id in self._has_init_obs and self._has_init_obs[agent_id]
|
|
|
|
def is_done(self, agent_id: AgentID) -> bool:
|
|
return self._last_dones.get(agent_id, False)
|
|
|
|
def set_last_info(self, agent_id: AgentID, info: Dict):
|
|
self._last_infos[agent_id] = info
|
|
|
|
@property
|
|
def length(self):
|
|
return self.total_env_steps
|