ray/rllib/evaluation/episode.py


from collections import defaultdict
import numpy as np
import random
import tree # pip install dm_tree
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.utils.annotations import Deprecated, DeveloperAPI
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray
from ray.rllib.utils.typing import (
SampleBatchType,
AgentID,
PolicyID,
EnvActionType,
EnvID,
EnvInfoDict,
EnvObsType,
)
from ray.util import log_once
if TYPE_CHECKING:
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.sample_batch_builder import MultiAgentSampleBatchBuilder
@DeveloperAPI
class Episode:
"""Tracks the current state of a (possibly multi-agent) episode.
Attributes:
        new_batch_builder: Callable that creates and returns a new
            MultiAgentSampleBatchBuilder.
        add_extra_batch: Callable that returns a built MultiAgentBatch
            to the sampler.
batch_builder: Batch builder for the current episode.
total_reward: Summed reward across all agents in this episode.
length: Length of this episode.
episode_id: Unique id identifying this trajectory.
        agent_rewards: Summed rewards broken down by agent and policy
            (keyed by (agent_id, policy_id) tuples).
        custom_metrics: Dict where you can add custom metrics.
        user_data: Dict that you can use for temporary storage, e.g.
            to pass data between two custom callbacks referring to the same
            episode.
hist_data: Dict mapping str keys to List[float] for storage of
per-timestep float data throughout the episode.
Use case 1: Model-based rollouts in multi-agent:
A custom compute_actions() function in a policy can inspect the
current episode state and perform a number of rollouts based on the
policies and state of other agents in the environment.
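        For example, a hedged sketch (the policy class and all variable
        names below are purely illustrative, not part of this module):
        >>> class ModelBasedPolicy(Policy):  # doctest: +SKIP
        ...     def compute_actions(self, *args, episodes=None, **kwargs):
        ...         episode = episodes[0]
        ...         # Look up another agent's bound policy and last raw obs.
        ...         other_id = episode.get_agents()[0]
        ...         other_pol = episode.policy_map[episode.policy_for(other_id)]
        ...         other_obs = episode.last_raw_obs_for(other_id)
        ...         # ... roll out a learned model starting from `other_obs` ...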
Use case 2: Returning extra rollouts data.
        The model rollouts can be returned back to the sampler by calling:
        >>> batch = episode.new_batch_builder()
        >>> for transition in rollout:  # pseudocode: iterate over transitions
        ...     batch.add_values(...)  # see sampler for usage
        >>> episode.add_extra_batch(batch.build_and_reset())
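        Example: Accessing the episode from a custom callback. A hedged
        sketch (the callback class and the "step_rewards" key are
        hypothetical; assumes RLlib's DefaultCallbacks API):
        >>> class MyCallbacks(DefaultCallbacks):  # doctest: +SKIP
        ...     def on_episode_step(self, *, episode, **kwargs):
        ...         episode.user_data.setdefault("step_rewards", []).append(
        ...             episode.last_reward_for())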
"""
def __init__(
self,
policies: PolicyMap,
policy_mapping_fn: Callable[[AgentID, "Episode", "RolloutWorker"], PolicyID],
batch_builder_factory: Callable[[], "MultiAgentSampleBatchBuilder"],
extra_batch_callback: Callable[[SampleBatchType], None],
env_id: EnvID,
*,
worker: Optional["RolloutWorker"] = None,
):
"""Initializes an Episode instance.
Args:
policies: The PolicyMap object (mapping PolicyIDs to Policy
                objects) to use for determining which policy is used for
which agent.
policy_mapping_fn: The mapping function mapping AgentIDs to
PolicyIDs.
            batch_builder_factory: A callable returning a new
                MultiAgentSampleBatchBuilder for collecting this episode's
                samples.
            extra_batch_callback: A callable that takes a built
                MultiAgentBatch (e.g. from extra model-based rollouts) and
                hands it back to the sampler.
env_id: The environment's ID in which this episode runs.
worker: The RolloutWorker instance, in which this episode runs.
"""
self.new_batch_builder: Callable[
[], "MultiAgentSampleBatchBuilder"
] = batch_builder_factory
self.add_extra_batch: Callable[[SampleBatchType], None] = extra_batch_callback
self.batch_builder: "MultiAgentSampleBatchBuilder" = batch_builder_factory()
self.total_reward: float = 0.0
self.length: int = 0
        self.episode_id: int = random.randrange(int(2e9))
self.env_id = env_id
self.worker = worker
self.agent_rewards: Dict[AgentID, float] = defaultdict(float)
self.custom_metrics: Dict[str, float] = {}
self.user_data: Dict[str, Any] = {}
self.hist_data: Dict[str, List[float]] = {}
self.media: Dict[str, Any] = {}
self.policy_map: PolicyMap = policies
self._policies = self.policy_map # backward compatibility
self.policy_mapping_fn: Callable[
[AgentID, "Episode", "RolloutWorker"], PolicyID
] = policy_mapping_fn
self._next_agent_index: int = 0
self._agent_to_index: Dict[AgentID, int] = {}
self._agent_to_policy: Dict[AgentID, PolicyID] = {}
self._agent_to_rnn_state: Dict[AgentID, List[Any]] = {}
self._agent_to_last_obs: Dict[AgentID, EnvObsType] = {}
self._agent_to_last_raw_obs: Dict[AgentID, EnvObsType] = {}
self._agent_to_last_done: Dict[AgentID, bool] = {}
self._agent_to_last_info: Dict[AgentID, EnvInfoDict] = {}
self._agent_to_last_action: Dict[AgentID, EnvActionType] = {}
self._agent_to_last_extra_action_outs: Dict[AgentID, dict] = {}
self._agent_to_prev_action: Dict[AgentID, EnvActionType] = {}
        self._agent_reward_history: Dict[AgentID, List[float]] = defaultdict(list)
@DeveloperAPI
def soft_reset(self) -> None:
"""Clears rewards and metrics, but retains RNN and other state.
This is used to carry state across multiple logical episodes in the
same env (i.e., if `soft_horizon` is set).
"""
self.length = 0
        self.episode_id = random.randrange(int(2e9))
self.total_reward = 0.0
self.agent_rewards = defaultdict(float)
self._agent_reward_history = defaultdict(list)
@DeveloperAPI
def policy_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> PolicyID:
"""Returns and stores the policy ID for the specified agent.
If the agent is new, the policy mapping fn will be called to bind the
agent to a policy for the duration of the entire episode (even if the
policy_mapping_fn is changed in the meantime!).
Args:
agent_id: The agent ID to lookup the policy ID for.
Returns:
The policy ID for the specified agent.
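        Example (illustrative only; the agent ID used here is hypothetical):
        >>> pid = episode.policy_for("agent_0")  # doctest: +SKIP
        >>> # Subsequent calls return the same, already bound PolicyID, even
        >>> # if policy_mapping_fn has been changed in the meantime.
        >>> assert episode.policy_for("agent_0") == pid  # doctest: +SKIP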
"""
# Perform a new policy_mapping_fn lookup and bind AgentID for the
# duration of this episode to the returned PolicyID.
if agent_id not in self._agent_to_policy:
# Try new API: pass in agent_id and episode as named args.
# New signature should be: (agent_id, episode, worker, **kwargs)
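            # For illustration only (a hypothetical mapping fn using the new
            # signature; not defined anywhere in this module):
            #
            #     def policy_mapping_fn(agent_id, episode, worker, **kwargs):
            #         return "main" if agent_id == "agent_0" else "opponent"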
try:
policy_id = self._agent_to_policy[agent_id] = self.policy_mapping_fn(
agent_id, self, worker=self.worker
)
except TypeError as e:
if (
"positional argument" in e.args[0]
or "unexpected keyword argument" in e.args[0]
):
if log_once("policy_mapping_new_signature"):
deprecation_warning(
old="policy_mapping_fn(agent_id)",
new="policy_mapping_fn(agent_id, episode, "
"worker, **kwargs)",
)
policy_id = self._agent_to_policy[
agent_id
] = self.policy_mapping_fn(agent_id)
else:
raise e
# Use already determined PolicyID.
else:
policy_id = self._agent_to_policy[agent_id]
# PolicyID not found in policy map -> Error.
if policy_id not in self.policy_map:
raise KeyError(
"policy_mapping_fn returned invalid policy id " f"'{policy_id}'!"
)
return policy_id
@DeveloperAPI
def last_observation_for(
self, agent_id: AgentID = _DUMMY_AGENT_ID
) -> Optional[EnvObsType]:
"""Returns the last observation for the specified AgentID.
Args:
agent_id: The agent's ID to get the last observation for.
Returns:
Last observation the specified AgentID has seen. None in case
the agent has never made any observations in the episode.
"""
return self._agent_to_last_obs.get(agent_id)
@DeveloperAPI
def last_raw_obs_for(
self, agent_id: AgentID = _DUMMY_AGENT_ID
) -> Optional[EnvObsType]:
"""Returns the last un-preprocessed obs for the specified AgentID.
Args:
agent_id: The agent's ID to get the last un-preprocessed
observation for.
Returns:
Last un-preprocessed observation the specified AgentID has seen.
None in case the agent has never made any observations in the
episode.
"""
return self._agent_to_last_raw_obs.get(agent_id)
@DeveloperAPI
def last_info_for(
self, agent_id: AgentID = _DUMMY_AGENT_ID
) -> Optional[EnvInfoDict]:
"""Returns the last info for the specified AgentID.
Args:
agent_id: The agent's ID to get the last info for.
Returns:
Last info dict the specified AgentID has seen.
None in case the agent has never made any observations in the
episode.
"""
return self._agent_to_last_info.get(agent_id)
@DeveloperAPI
def last_action_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvActionType:
"""Returns the last action for the specified AgentID, or zeros.
The "last" action is the most recent one taken by the agent.
Args:
agent_id: The agent's ID to get the last action for.
Returns:
Last action the specified AgentID has executed.
Zeros in case the agent has never performed any actions in the
episode.
"""
policy_id = self.policy_for(agent_id)
policy = self.policy_map[policy_id]
# Agent has already taken at least one action in the episode.
if agent_id in self._agent_to_last_action:
if policy.config.get("_disable_action_flattening"):
return self._agent_to_last_action[agent_id]
else:
return flatten_to_single_ndarray(self._agent_to_last_action[agent_id])
# Agent has not acted yet, return all zeros.
else:
if policy.config.get("_disable_action_flattening"):
return tree.map_structure(
lambda s: np.zeros_like(s.sample(), s.dtype)
if hasattr(s, "dtype")
else np.zeros_like(s.sample()),
policy.action_space_struct,
)
else:
flat = flatten_to_single_ndarray(policy.action_space.sample())
if hasattr(policy.action_space, "dtype"):
return np.zeros_like(flat, dtype=policy.action_space.dtype)
return np.zeros_like(flat)
@DeveloperAPI
def prev_action_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvActionType:
"""Returns the previous action for the specified agent, or zeros.
The "previous" action is the one taken one timestep before the
most recent action taken by the agent.
Args:
agent_id: The agent's ID to get the previous action for.
Returns:
Previous action the specified AgentID has executed.
Zero in case the agent has never performed any actions (or only
one) in the episode.
"""
policy_id = self.policy_for(agent_id)
policy = self.policy_map[policy_id]
# We are at t > 1 -> There has been a previous action by this agent.
if agent_id in self._agent_to_prev_action:
if policy.config.get("_disable_action_flattening"):
return self._agent_to_prev_action[agent_id]
else:
return flatten_to_single_ndarray(self._agent_to_prev_action[agent_id])
# We're at t <= 1, so return all zeros.
else:
if policy.config.get("_disable_action_flattening"):
return tree.map_structure(
lambda a: np.zeros_like(a, a.dtype)
if hasattr(a, "dtype") # noqa
else np.zeros_like(a), # noqa
self.last_action_for(agent_id),
)
else:
return np.zeros_like(self.last_action_for(agent_id))
@DeveloperAPI
def last_reward_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> float:
"""Returns the last reward for the specified agent, or zero.
The "last" reward is the one received most recently by the agent.
Args:
agent_id: The agent's ID to get the last reward for.
Returns:
            Last reward for the specified AgentID.
Zero in case the agent has never performed any actions
(and thus received rewards) in the episode.
"""
history = self._agent_reward_history[agent_id]
# We are at t > 0 -> Return previously received reward.
if len(history) >= 1:
return history[-1]
# We're at t=0, so there is no previous reward, just return zero.
else:
return 0.0
@DeveloperAPI
def prev_reward_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> float:
"""Returns the previous reward for the specified agent, or zero.
The "previous" reward is the one received one timestep before the
most recently received reward of the agent.
Args:
agent_id: The agent's ID to get the previous reward for.
Returns:
            Previous reward for the specified AgentID.
Zero in case the agent has never performed any actions (or only
one) in the episode.
"""
history = self._agent_reward_history[agent_id]
# We are at t > 1 -> Return reward prior to most recent (last) one.
if len(history) >= 2:
return history[-2]
# We're at t <= 1, so there is no previous reward, just return zero.
else:
return 0.0
@DeveloperAPI
def rnn_state_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> List[Any]:
"""Returns the last RNN state for the specified agent.
Args:
agent_id: The agent's ID to get the most recent RNN state for.
Returns:
            Most recent RNN state of the specified AgentID.
"""
if agent_id not in self._agent_to_rnn_state:
policy_id = self.policy_for(agent_id)
policy = self.policy_map[policy_id]
self._agent_to_rnn_state[agent_id] = policy.get_initial_state()
return self._agent_to_rnn_state[agent_id]
@DeveloperAPI
def last_done_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> bool:
"""Returns the last done flag for the specified AgentID.
Args:
agent_id: The agent's ID to get the last done flag for.
Returns:
Last done flag for the specified AgentID.
"""
if agent_id not in self._agent_to_last_done:
self._agent_to_last_done[agent_id] = False
return self._agent_to_last_done[agent_id]
@DeveloperAPI
def last_extra_action_outs_for(
self,
agent_id: AgentID = _DUMMY_AGENT_ID,
) -> dict:
"""Returns the last extra-action outputs for the specified agent.
This data is returned by a call to
`Policy.compute_actions_from_input_dict` as the 3rd return value
(1st return value = action; 2nd return value = RNN state outs).
Args:
agent_id: The agent's ID to get the last extra-action outs for.
Returns:
The last extra-action outs for the specified AgentID.
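        Example (illustrative; which keys appear depends on the policy,
            e.g. "action_logp" or "vf_preds" may or may not be present):
        >>> outs = episode.last_extra_action_outs_for("agent_0")  # doctest: +SKIP
        >>> logp = outs.get("action_logp")  # doctest: +SKIP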
"""
return self._agent_to_last_extra_action_outs[agent_id]
@DeveloperAPI
def get_agents(self) -> List[AgentID]:
"""Returns list of agent IDs that have appeared in this episode.
Returns:
The list of all agent IDs that have appeared so far in this
episode.
"""
return list(self._agent_to_index.keys())
def _add_agent_rewards(self, reward_dict: Dict[AgentID, float]) -> None:
for agent_id, reward in reward_dict.items():
if reward is not None:
self.agent_rewards[agent_id, self.policy_for(agent_id)] += reward
self.total_reward += reward
self._agent_reward_history[agent_id].append(reward)
def _set_rnn_state(self, agent_id, rnn_state):
self._agent_to_rnn_state[agent_id] = rnn_state
def _set_last_observation(self, agent_id, obs):
self._agent_to_last_obs[agent_id] = obs
def _set_last_raw_obs(self, agent_id, obs):
self._agent_to_last_raw_obs[agent_id] = obs
def _set_last_done(self, agent_id, done):
self._agent_to_last_done[agent_id] = done
def _set_last_info(self, agent_id, info):
self._agent_to_last_info[agent_id] = info
def _set_last_action(self, agent_id, action):
if agent_id in self._agent_to_last_action:
self._agent_to_prev_action[agent_id] = self._agent_to_last_action[agent_id]
self._agent_to_last_action[agent_id] = action
def _set_last_extra_action_outs(self, agent_id, pi_info):
self._agent_to_last_extra_action_outs[agent_id] = pi_info
def _agent_index(self, agent_id):
if agent_id not in self._agent_to_index:
self._agent_to_index[agent_id] = self._next_agent_index
self._next_agent_index += 1
return self._agent_to_index[agent_id]
@property
def _policy_mapping_fn(self):
deprecation_warning(
old="Episode._policy_mapping_fn",
new="Episode.policy_mapping_fn",
error=False,
)
return self.policy_mapping_fn
@Deprecated(new="Episode.last_extra_action_outs_for", error=False)
def last_pi_info_for(self, *args, **kwargs):
return self.last_extra_action_outs_for(*args, **kwargs)
# Backward compatibility: the old name `MultiAgentEpisode` is kept as an
# alias of `Episode`, which handles both single- and multi-agent episodes.
@Deprecated(new="ray.rllib.evaluation.episode.Episode", error=False)
class MultiAgentEpisode(Episode):
pass