ray/rllib/evaluation/collectors/agent_collector.py


from copy import deepcopy
from gym.spaces import Space
import math
import numpy as np
import tree # pip install dm_tree
from typing import Any, Dict, List, Optional
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.spaces.space_utils import get_dummy_batch_for_space
from ray.rllib.utils.typing import (
EpisodeID,
EnvID,
TensorType,
ViewRequirementsDict,
)
from ray.util.annotations import PublicAPI
_, tf, _ = try_import_tf()
torch, _ = try_import_torch()
def _to_float_np_array(v: List[Any]) -> np.ndarray:
if torch and torch.is_tensor(v[0]):
raise ValueError
arr = np.array(v)
if arr.dtype == np.float64:
return arr.astype(np.float32) # save some memory
return arr
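# Illustrative example for _to_float_np_array(): passing [0.1, 0.2] returns a
# float32 array (float64 inputs are downcast to save memory); torch tensors
# must be converted to numpy before being buffered.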
@PublicAPI
class AgentCollector:
"""Collects samples for one agent in one trajectory (episode).
The agent may be part of a multi-agent environment. Samples are stored in
lists including some possible automatic "shift" buffer at the beginning to
be able to save memory when storing things like NEXT_OBS, PREV_REWARDS,
etc., which are specified using the trajectory view API.
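Rough usage sketch (illustrative only; `policy`, `obs` and the per-step
`values` dict are assumed to be defined elsewhere):
    >>> collector = AgentCollector(policy.view_requirements)
    >>> collector.add_init_obs(
    ...     episode_id=0, agent_index=0, env_id=0, t=-1, init_obs=obs)
    >>> collector.add_action_reward_next_obs(values)  # one row per env step
    >>> train_batch = collector.build_for_training(policy.view_requirements)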
"""
_next_unroll_id = 0 # disambiguates unrolls within a single episode
# TODO: @kourosh add different types of padding. e.g. zeros vs. same
def __init__(
self,
view_reqs: ViewRequirementsDict,
*,
max_seq_len: int = 1,
disable_action_flattening: bool = True,
intial_states: Optional[List[TensorType]] = None,
is_policy_recurrent: bool = False,
is_training: bool = True,
):
"""Initialize an AgentCollector.
Args:
view_reqs: A dict of view requirements for the agent.
max_seq_len: The maximum sequence length to store.
disable_action_flattening: If True, don't flatten the action.
intial_states: The initial states as returned by policy.get_initial_states().
is_policy_recurrent: If True, the policy is recurrent.
is_training: Sets the is_training flag for the buffers. If True, all
timesteps are stored in the buffers until build_for_training() is
explicitly called. If False, only the content required for the last
time step is stored in the buffers, which saves memory during inference.
You can change the behavior at runtime by calling is_training(mode).
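Example (rough sketch): a collector used only for action computation can
be created via `AgentCollector(view_reqs, is_training=False)` and later
switched back with `collector.is_training(True)`.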
"""
self.max_seq_len = max_seq_len
self.disable_action_flattening = disable_action_flattening
self.view_requirements = view_reqs
self.intial_states = intial_states or []
self.is_policy_recurrent = is_policy_recurrent
self._is_training = is_training
# Determine the size of the buffer we need for data before the actual
# episode starts. This is used for 0-buffering of e.g. prev-actions,
# or internal state inputs.
view_req_shifts = [
min(vr.shift_arr) - int((vr.data_col or k) == SampleBatch.OBS)
for k, vr in view_reqs.items()
]
self.shift_before = -min(view_req_shifts)
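# Illustrative example (assuming only these two view requirements): an
# OBS requirement with shift_arr=[0] gives 0 - 1 = -1, and a
# "prev_rewards" requirement (data_col=REWARDS) with shift_arr=[-1]
# gives -1 - 0 = -1, so shift_before = -min([-1, -1]) = 1, i.e. one
# padding slot at the front of each buffer.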
# The actual data buffers. Keys are column names, values are lists
# that contain the sub-components (e.g. for complex obs spaces) with
# each sub-component holding a list of per-timestep tensors.
# E.g.: obs-space = Dict(a=Discrete(2), b=Box((2,)))
# buffers["obs"] = [
# [0, 1], # <- 1st sub-component of observation
# [np.array([.2, .3]), np.array([.0, -.2])] # <- 2nd sub-component
# ]
# NOTE: infos and state_out_... are not flattened due to them often
# using custom dict values whose structure may vary from timestep to
# timestep.
self.buffers: Dict[str, List[List[TensorType]]] = {}
# Maps column names to an example data item, which may be deeply
# nested. These are used such that we'll know how to unflatten
# the flattened data inside self.buffers when building the
# SampleBatch.
self.buffer_structs: Dict[str, Any] = {}
# The episode ID for the agent for which we collect data.
self.episode_id = None
# The unroll ID, unique across all rollouts (within a RolloutWorker).
self.unroll_id = None
# The simple timestep count for this agent. Gets increased by one
# each time a (non-initial!) observation is added.
self.agent_steps = 0
@property
def training(self) -> bool:
return self._is_training
def is_training(self, is_training: bool) -> None:
self._is_training = is_training
def is_empty(self) -> bool:
"""Returns True if this collector has no data."""
return not self.buffers or all(len(item) == 0 for item in self.buffers.values())
def add_init_obs(
self,
episode_id: EpisodeID,
agent_index: int,
env_id: EnvID,
t: int,
init_obs: TensorType,
) -> None:
"""Adds an initial observation (after reset) to the Agent's trajectory.
Args:
episode_id: Unique ID for the episode we are adding the
initial observation for.
agent_index: Unique int index (starting from 0) for the agent
within its episode. Not to be confused with AGENT_ID (Any).
env_id: The environment index (in a vectorized setup).
t: The time step (episode length - 1). The initial obs has
ts=-1(!), then an action/reward/next-obs at t=0, etc.
init_obs: The initial observation tensor (after
`env.reset()`).
"""
# Store episode ID + unroll ID, which will be constant throughout this
# AgentCollector's lifecycle.
self.episode_id = episode_id
if self.unroll_id is None:
self.unroll_id = AgentCollector._next_unroll_id
AgentCollector._next_unroll_id += 1
if SampleBatch.OBS not in self.buffers:
self._build_buffers(
single_row={
SampleBatch.OBS: init_obs,
SampleBatch.AGENT_INDEX: agent_index,
SampleBatch.ENV_ID: env_id,
SampleBatch.T: t,
SampleBatch.EPS_ID: self.episode_id,
SampleBatch.UNROLL_ID: self.unroll_id,
}
)
# Append data to existing buffers.
flattened = tree.flatten(init_obs)
for i, sub_obs in enumerate(flattened):
self.buffers[SampleBatch.OBS][i].append(sub_obs)
self.buffers[SampleBatch.AGENT_INDEX][0].append(agent_index)
self.buffers[SampleBatch.ENV_ID][0].append(env_id)
self.buffers[SampleBatch.T][0].append(t)
self.buffers[SampleBatch.EPS_ID][0].append(self.episode_id)
self.buffers[SampleBatch.UNROLL_ID][0].append(self.unroll_id)
def add_action_reward_next_obs(self, input_values: Dict[str, TensorType]) -> None:
"""Adds the given dictionary (row) of values to the Agent's trajectory.
Args:
input_values: Data dict (interpreted as a single row) to be added to the buffers.
Must contain keys:
SampleBatch.ACTIONS, REWARDS, DONES, and NEXT_OBS.
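Example (rough sketch with placeholder values; `next_obs` is assumed):
    >>> collector.add_action_reward_next_obs({
    ...     SampleBatch.ACTIONS: 1,
    ...     SampleBatch.REWARDS: 0.5,
    ...     SampleBatch.DONES: False,
    ...     SampleBatch.NEXT_OBS: next_obs,
    ... })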
"""
if self.unroll_id is None:
self.unroll_id = AgentCollector._next_unroll_id
AgentCollector._next_unroll_id += 1
# Next obs -> obs.
# TODO @kourosh: remove the in-place operations and get rid of this deepcopy.
values = deepcopy(input_values)
assert SampleBatch.OBS not in values
values[SampleBatch.OBS] = values[SampleBatch.NEXT_OBS]
del values[SampleBatch.NEXT_OBS]
# Make sure EPS_ID/UNROLL_ID stay the same for this agent.
if SampleBatch.EPS_ID in values:
assert values[SampleBatch.EPS_ID] == self.episode_id
del values[SampleBatch.EPS_ID]
self.buffers[SampleBatch.EPS_ID][0].append(self.episode_id)
if SampleBatch.UNROLL_ID in values:
assert values[SampleBatch.UNROLL_ID] == self.unroll_id
del values[SampleBatch.UNROLL_ID]
self.buffers[SampleBatch.UNROLL_ID][0].append(self.unroll_id)
for k, v in values.items():
if k not in self.buffers:
self._build_buffers(single_row=values)
# Do not flatten infos, state_out_ and (if configured) actions.
# Infos/state-outs may be structs that change from timestep to
# timestep.
if (
k == SampleBatch.INFOS
or k.startswith("state_out_")
or (k == SampleBatch.ACTIONS and not self.disable_action_flattening)
):
self.buffers[k][0].append(v)
# Flatten all other columns.
else:
flattened = tree.flatten(v)
for i, sub_list in enumerate(self.buffers[k]):
sub_list.append(flattened[i])
# In inference mode, we don't need to keep the entire trajectory in
# memory; we only need the steps that are still required. We can pop
# from the beginning to create room for new data.
if not self.training:
for k in self.buffers:
for sub_list in self.buffers[k]:
if sub_list:
sub_list.pop(0)
self.agent_steps += 1
def build_for_inference(self) -> SampleBatch:
"""During inference, we will build a SampleBatch with a batch size of 1 that
can then be used to run the forward pass of a policy. This data will only
include the environment context for running the policy at the last timestep.
Returns:
A SampleBatch with a batch size of 1.
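Example (rough sketch; `policy` is assumed to be the acting Policy):
    >>> input_batch = collector.build_for_inference()
    >>> actions, states, info = policy.compute_actions_from_input_dict(
    ...     input_batch)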
"""
batch_data = {}
np_data = {}
for view_col, view_req in self.view_requirements.items():
# Create the batch of data from the different buffers.
data_col = view_req.data_col or view_col
# if this view is not for inference, skip it.
if not view_req.used_for_compute_actions:
continue
if np.any(view_req.shift_arr > 0):
raise ValueError(
f"During inference the agent can only use past observations to "
f"respect causality. However, view_col = {view_col} seems to "
f"depend on future indices {view_req.shift_arr}, while the "
f"used_for_compute_actions flag is set to True. Please fix the "
f"discrepancy. Hint: If you are using a custom model make sure "
f"the view_requirements are initialized properly and is point "
f"only refering to past timesteps during inference."
)
# Some columns don't exist yet
# (get created during postprocessing or depend on state_out).
if data_col not in self.buffers:
self._fill_buffer_with_initial_values(
data_col, view_req, build_for_inference=True
)
# Keep an np-array cache so we don't have to regenerate the
# np-array for different view_cols mapped to the same data_col.
self._cache_in_np(np_data, data_col)
data = []
for d in np_data[data_col]:
# If shift_arr == [0], the data will be just the last time step
# (index len(d) - 1); if shift_arr == [-1], it will be the time step
# before the last one (len(d) - 2), and so on.
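# Illustrative example: with len(d) == 5 and shift_arr == [-1, 0], the
# selected indices are [-1, 0] + 5 - 1 = [3, 4], i.e. the last two
# time steps in the buffer.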
element_at_t = d[view_req.shift_arr + len(d) - 1]
if element_at_t.shape[0] == 1:
# squeeze to remove the T dimension if it is 1.
element_at_t = element_at_t.squeeze(0)
# add the batch dimension with [None]
data.append(element_at_t[None])
if data:
batch_data[view_col] = self._unflatten_as_buffer_struct(data, data_col)
batch = self._get_sample_batch(batch_data)
return batch
# TODO: @kourosh we don't really need view_requirements anymore since it's
# already an attribute of the class.
def build_for_training(
self, view_requirements: ViewRequirementsDict
) -> SampleBatch:
"""Builds a SampleBatch from the thus-far collected agent data.
If the episode/trajectory has no DONE=True at the end, will copy
the necessary n timesteps at the end of the trajectory back to the
beginning of the buffers and wait for new samples coming in.
SampleBatches created by this method will be ready for postprocessing
by a Policy.
Args:
view_requirements: The view requirements dict needed to build the
SampleBatch from the raw buffers (which may have data shifts as well as
mappings from view-col to data-col in them).
Returns:
SampleBatch: The built SampleBatch for this agent, ready to go into
postprocessing.
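Example (rough sketch; `policy` is assumed):
    >>> train_batch = collector.build_for_training(policy.view_requirements)
    >>> train_batch = policy.postprocess_trajectory(train_batch)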
"""
batch_data = {}
np_data = {}
for view_col, view_req in view_requirements.items():
# Create the batch of data from the different buffers.
data_col = view_req.data_col or view_col
if data_col not in self.buffers:
is_state = self._fill_buffer_with_initial_values(
data_col, view_req, build_for_inference=False
)
# We need to skip this view_col if it does not exist in the buffers and
# is not an RNN state, because it could be one of the special keys that
# get added by the policy's postprocessing function for training.
if not is_state:
continue
# OBS are already shifted by -1 (the initial obs starts one ts
# before all other data columns).
obs_shift = -1 if data_col == SampleBatch.OBS else 0
# Keep an np-array cache so we don't have to regenerate the
# np-array for different view_cols mapped to the same data_col.
self._cache_in_np(np_data, data_col)
# Go through each time step in the buffer and construct the view
# accordingly.
data = []
for d in np_data[data_col]:
shifted_data = []
# batch_repeat_value determines how many time steps we should skip
# before we repeat indexing the data.
# Example: batch_repeat_value=10, shift_arr = [-3, -2, -1],
# shift_before = 3
# buffer = [-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
# resulting_data = [[-3, -2, -1], [7, 8, 9]]
# explanation: For t=0, we output [-3, -2, -1]. We then skip 10 time
# steps ahead and get to t=10. For t=10, we output [7, 8, 9]. We skip
# 10 more time steps and get to t=20. But since t=20 is out of bounds, we
# stop.
# count computes the number of time steps that we need to consider.
# if batch_repeat_value = 1, this number should be the length of
# episode so far, which is len(buffer) - shift_before.
count = int(
math.ceil(
(len(d) - self.shift_before) / view_req.batch_repeat_value
)
)
for i in range(count):
# the indices for time step t
inds = (
self.shift_before
+ obs_shift
+ view_req.shift_arr
+ (i * view_req.batch_repeat_value)
)
# Handle the case where the inds run past the end of the buffer: if any
# of the indices are out of bounds, we pad at the end to fill in the
# missing entries.
element_at_t = []
for index in inds:
if index < len(d):
element_at_t.append(d[index])
else:
# zero pad similar to the last element.
element_at_t.append(
tree.map_structure(np.zeros_like, d[-1])
)
element_at_t = np.stack(element_at_t)
if element_at_t.shape[0] == 1:
# squeeze to remove the T dimension if it is 1.
element_at_t = element_at_t.squeeze(0)
shifted_data.append(element_at_t)
# in some multi-agent cases shifted_data may be an empty list.
# In this case we should just create an empty array and return it.
if shifted_data:
shifted_data_np = np.stack(shifted_data, 0)
else:
shifted_data_np = np.array(shifted_data)
data.append(shifted_data_np)
if data:
batch_data[view_col] = self._unflatten_as_buffer_struct(data, data_col)
batch = self._get_sample_batch(batch_data)
# This trajectory is continuing -> Copy data at the end (in the size of
# self.shift_before) to the beginning of buffers and erase everything
# else.
if (
SampleBatch.DONES in self.buffers
and not self.buffers[SampleBatch.DONES][0][-1]
):
# Copy data to beginning of buffer and cut lists.
if self.shift_before > 0:
for k, data in self.buffers.items():
# Loop through each sub-component buffer of this column.
for i in range(len(data)):
self.buffers[k][i] = data[i][-self.shift_before :]
self.agent_steps = 0
# Reset our unroll_id.
self.unroll_id = None
return batch
def _build_buffers(self, single_row: Dict[str, TensorType]) -> None:
"""Builds the buffers for sample collection, given an example data row.
Args:
single_row (Dict[str, TensorType]): A single row (keys=column
names) of data to base the buffers on.
"""
for col, data in single_row.items():
if col in self.buffers:
continue
shift = self.shift_before - (
1
if col
in [
SampleBatch.OBS,
SampleBatch.EPS_ID,
SampleBatch.AGENT_INDEX,
SampleBatch.ENV_ID,
SampleBatch.T,
SampleBatch.UNROLL_ID,
]
else 0
)
# Store all data as flattened lists, except INFOS and state-out
# lists. These are monolithic items (infos is a dict that
# should not be further split, same for state-out items, which
# could be custom dicts as well).
if (
col == SampleBatch.INFOS
or col.startswith("state_out_")
or (col == SampleBatch.ACTIONS and not self.disable_action_flattening)
):
self.buffers[col] = [[data for _ in range(shift)]]
else:
self.buffers[col] = [
[v for _ in range(shift)] for v in tree.flatten(data)
]
# Store an example data struct so we know, how to unflatten
# each data col.
self.buffer_structs[col] = data
def _get_sample_batch(self, batch_data: Dict[str, TensorType]) -> SampleBatch:
"""Returns a SampleBatch from the given data dictionary. Also updates the
sequence information based on the max_seq_len."""
# Due to possible batch-repeats > 1, columns in the resulting batch
# may not all have the same batch size.
batch = SampleBatch(batch_data, is_training=self.training)
# Adjust the seq-lens array depending on the incoming agent sequences.
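# Illustrative example: with max_seq_len=20 and batch.count=45, the
# resulting seq_lens array is [20, 20, 5].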
if self.is_policy_recurrent:
seq_lens = []
max_seq_len = self.max_seq_len
count = batch.count
while count > 0:
seq_lens.append(min(count, max_seq_len))
count -= max_seq_len
batch["seq_lens"] = np.array(seq_lens)
batch.max_seq_len = max_seq_len
return batch
def _cache_in_np(self, cache_dict: Dict[str, List[np.ndarray]], key: str) -> None:
"""Caches the numpy version of the key in the buffer dict."""
if key not in cache_dict:
cache_dict[key] = [_to_float_np_array(d) for d in self.buffers[key]]
def _unflatten_as_buffer_struct(
self, data: List[np.ndarray], key: str
) -> np.ndarray:
"""Unflattens the given to match the buffer struct format for that key."""
if key not in self.buffer_structs:
return data[0]
return tree.unflatten_as(self.buffer_structs[key], data)
def _fill_buffer_with_initial_values(
self,
data_col: str,
view_requirement: ViewRequirement,
build_for_inference: bool = False,
) -> bool:
"""Fills the buffer with the initial values for the given data column.
For data_col starting with `state_out`, use the initial states of the policy;
for other data columns, create a dummy value based on the view requirement's
space.
Args:
data_col: The data column to fill the buffer with.
view_requirement: The view requirement for the view_col. Normally the view
requirement for the data column is used and if it does not exist for
some reason the view requirement for view column is used instead.
build_for_inference: Whether this is getting called for inference or not.
Returns:
is_state: True if the data_col is an RNN state, False otherwise.
"""
try:
space = self.view_requirements[data_col].space
except KeyError:
space = view_requirement.space
# special treatment for state_out_<i>
# add them to the buffer in case they don't exist yet
is_state = True
if data_col.startswith("state_out_"):
if not self.is_policy_recurrent:
raise ValueError(
f"{data_col} is not available, because the given policy is"
f"not recurrent according to the input model_inital_states."
f"Have you forgotten to return non-empty lists in"
f"policy.get_initial_states()?"
)
state_ind = int(data_col.split("_")[-1])
self._build_buffers({data_col: self.intial_states[state_ind]})
else:
is_state = False
# only create dummy data during inference
if build_for_inference:
if isinstance(space, Space):
fill_value = get_dummy_batch_for_space(
space,
batch_size=1,
)
else:
fill_value = space
self._build_buffers({data_col: fill_value})
return is_state