import logging
import numpy as np
from typing import Dict, Optional, Tuple

from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, TensorType

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()

logger = logging.getLogger(__name__)


class _PerPolicySampleCollector:
    """A class for efficiently collecting samples for a single (fixed) policy.

    Can be used by a _MultiAgentSampleCollector for its different policies.
    """

    def __init__(self,
                 num_agents: Optional[int] = None,
                 num_timesteps: Optional[int] = None,
                 time_major: bool = True,
                 shift_before: int = 0,
                 shift_after: int = 0):
        """Initializes a _PerPolicySampleCollector object.

        Args:
            num_agents (Optional[int]): The max number of agent slots to
                pre-allocate in the buffer. Defaults to 100.
            num_timesteps (Optional[int]): The max number of timesteps to
                pre-allocate in the buffer.
            time_major (bool): Whether to pre-allocate buffers and collect
                samples in time-major fashion (TxBx...).
            shift_before (int): The additional number of time slots to
                pre-allocate at the beginning of a time window (for possible
                underlying data column shifts, e.g. PREV_ACTIONS).
            shift_after (int): The additional number of time slots to
                pre-allocate at the end of a time window (for possible
                underlying data column shifts, e.g. NEXT_OBS).
        """
        self.num_agents = num_agents or 100
        self.num_timesteps = num_timesteps
        self.time_major = time_major
        # `shift_before` must be at least 1 for the init-obs timestep.
        self.shift_before = max(shift_before, 1)
        self.shift_after = shift_after

        # The offset on the agent dim to start the next SampleBatch
        # build from.
        self.sample_batch_offset = 0

        # The actual underlying data-buffers.
        self.buffers = {}
        self.postprocessed_agents = [False] * self.num_agents

        # Next agent-slot to be used by a new agent/env combination.
        self.agent_slot_cursor = 0
        # Maps agent/episode ID/chunk-num to an agent slot.
        self.agent_key_to_slot = {}
        # Maps agent/episode ID to the last chunk-num.
        self.agent_key_to_chunk_num = {}
        # Maps agent slot number to agent keys.
        self.slot_to_agent_key = [None] * self.num_agents
        # Maps agent/episode ID/chunk-num to a time step cursor.
        self.agent_key_to_timestep = {}

        # Total timesteps taken in the env over all agents since last reset.
        self.timesteps_since_last_reset = 0

        # Indices (T, B) to pick from the buffers for the next forward pass.
        self.forward_pass_indices = [[], []]
        self.forward_pass_size = 0
        # Maps index from the forward pass batch to (agent_id, episode_id,
        # env_id) tuple.
        self.forward_pass_index_to_agent_info = {}
        self.agent_key_to_forward_pass_index = {}

    def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID,
                     env_id: EnvID, chunk_num: int,
                     init_obs: TensorType) -> None:
        """Adds a single initial observation (after env.reset()) to the buffer.

        Args:
            episode_id (EpisodeID): Unique ID for the episode we are adding
                the initial observation for.
            agent_id (AgentID): Unique ID for the agent we are adding the
                initial observation for.
            env_id (EnvID): The env ID to which `init_obs` belongs.
            chunk_num (int): The time-chunk number (0-based). Some episodes
                may last for longer than self.num_timesteps and therefore
                have to be chopped into chunks.
            init_obs (TensorType): Initial observation (after env.reset()).
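
        Examples:
            >>> # Hedged sketch; the IDs and obs values below are made up.
            >>> collector.add_init_obs(
            ...     episode_id=12345, agent_id=0, env_id=0, chunk_num=0,
            ...     init_obs=np.array([0.1, 0.2, 0.3], dtype=np.float32))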
""" agent_key = (agent_id, episode_id, chunk_num) agent_slot = self.agent_slot_cursor self.agent_key_to_slot[agent_key] = agent_slot self.agent_key_to_chunk_num[agent_key[:2]] = chunk_num self.slot_to_agent_key[agent_slot] = agent_key self._next_agent_slot() if SampleBatch.OBS not in self.buffers: self._build_buffers( single_row={ SampleBatch.OBS: init_obs, SampleBatch.EPS_ID: episode_id, SampleBatch.AGENT_INDEX: agent_id, "env_id": env_id, }) if self.time_major: self.buffers[SampleBatch.OBS][self.shift_before-1, agent_slot] = \ init_obs else: self.buffers[SampleBatch.OBS][agent_slot, self.shift_before-1] = \ init_obs self.agent_key_to_timestep[agent_key] = self.shift_before self._add_to_next_inference_call(agent_key, env_id, agent_slot, self.shift_before - 1) def add_action_reward_next_obs( self, episode_id: EpisodeID, agent_id: AgentID, env_id: EnvID, agent_done: bool, values: Dict[str, TensorType]) -> None: """Add the given dictionary (row) of values to this batch. Args: episode_id (EpisodeID): Unique ID for the episode we are adding the values for. agent_id (AgentID): Unique ID for the agent we are adding the values for. env_id (EnvID): The env ID to which the given data belongs. agent_done (bool): Whether next obs should not be used for an upcoming inference call. Default: False = next-obs should be used for upcoming inference. values (Dict[str, TensorType]): Data dict (interpreted as a single row) to be added to buffer. Must contain keys: SampleBatch.ACTIONS, REWARDS, DONES, and NEXT_OBS. """ assert (SampleBatch.ACTIONS in values and SampleBatch.REWARDS in values and SampleBatch.NEXT_OBS in values and SampleBatch.DONES in values) assert SampleBatch.OBS not in values values[SampleBatch.OBS] = values[SampleBatch.NEXT_OBS] del values[SampleBatch.NEXT_OBS] chunk_num = self.agent_key_to_chunk_num[(agent_id, episode_id)] agent_key = (agent_id, episode_id, chunk_num) agent_slot = self.agent_key_to_slot[agent_key] ts = self.agent_key_to_timestep[agent_key] for k, v in values.items(): if k not in self.buffers: self._build_buffers(single_row=values) if self.time_major: self.buffers[k][ts, agent_slot] = v else: self.buffers[k][agent_slot, ts] = v self.agent_key_to_timestep[agent_key] += 1 # Time-axis is "full" -> Cut-over to new chunk (only if not DONE). if self.agent_key_to_timestep[ agent_key] - self.shift_before == self.num_timesteps and \ not values[SampleBatch.DONES]: self._new_chunk_from(agent_slot, agent_key, self.agent_key_to_timestep[agent_key]) self.timesteps_since_last_reset += 1 if not agent_done: self._add_to_next_inference_call(agent_key, env_id, agent_slot, ts) def get_inference_input_dict(self, view_reqs: Dict[str, ViewRequirement] ) -> Dict[str, TensorType]: """Returns an input_dict for an (inference) forward pass. The input_dict can then be used for action computations inside a Policy via `Policy.compute_actions_from_input_dict()`. Args: view_reqs (Dict[str, ViewRequirement]): The view requirements dict to use. Returns: Dict[str, TensorType]: The input_dict to be passed into the ModelV2 for inference/training. Examples: >>> obs, r, done, info = env.step(action) >>> collector.add_action_reward_next_obs(12345, 0, "pol0", { ... "action": action, "obs": obs, "reward": r, "done": done ... }) >>> input_dict = collector.get_inference_input_dict(policy.model) >>> action = policy.compute_actions_from_input_dict(input_dict) >>> # repeat """ input_dict = {} for view_col, view_req in view_reqs.items(): # Create the batch of data from the different buffers. 
            data_col = view_req.data_col or view_col
            if data_col not in self.buffers:
                self._build_buffers({data_col: view_req.space.sample()})
            indices = self.forward_pass_indices
            if self.time_major:
                input_dict[view_col] = self.buffers[data_col][indices]
            else:
                if isinstance(view_req.shift, (list, tuple)):
                    time_indices = \
                        np.array(view_req.shift) + np.array(indices[0])
                    input_dict[view_col] = self.buffers[data_col][
                        indices[1], time_indices]
                else:
                    input_dict[view_col] = \
                        self.buffers[data_col][indices[1], indices[0]]

        self._reset_inference_call()

        return input_dict

    def get_postprocessing_sample_batches(
            self, episode: MultiAgentEpisode,
            view_reqs: Dict[str, ViewRequirement]
    ) -> Dict[Tuple[AgentID, EpisodeID, int], SampleBatch]:
        """Returns SampleBatch objects (one per agent key) for postprocessing.

        Args:
            episode (MultiAgentEpisode): The MultiAgentEpisode object to
                get the to-be-postprocessed SampleBatches for.
            view_reqs (Dict[str, ViewRequirement]): The view requirements
                dict to use for creating the SampleBatch from our buffers.

        Returns:
            Dict[Tuple[AgentID, EpisodeID, int], SampleBatch]: The sample
                batch objects (keyed by agent key) to be passed to
                `Policy.postprocess_trajectory()`.
        """
        # Loop through all agents and create a SampleBatch
        # (as "view"; no copying).

        # Construct the SampleBatch-dict.
        sample_batch_data = {}

        range_ = self.agent_slot_cursor - self.sample_batch_offset
        if range_ < 0:
            range_ = self.num_agents + range_
        for i in range(range_):
            agent_slot = self.sample_batch_offset + i
            if agent_slot >= self.num_agents:
                agent_slot = agent_slot % self.num_agents
            # Do not postprocess the same slot twice.
            if self.postprocessed_agents[agent_slot]:
                continue
            agent_key = self.slot_to_agent_key[agent_slot]
            # Skip other episodes (if an episode was provided).
            if episode and agent_key[1] != episode.episode_id:
                continue
            end = self.agent_key_to_timestep[agent_key]
            # Do not build any empty SampleBatches.
            if end == self.shift_before:
                continue
            self.postprocessed_agents[agent_slot] = True

            assert agent_key not in sample_batch_data
            sample_batch_data[agent_key] = {}
            batch = sample_batch_data[agent_key]

            for view_col, view_req in view_reqs.items():
                data_col = view_req.data_col or view_col
                # Skip columns that will only get added through
                # postprocessing (these may not even exist yet).
                if data_col not in self.buffers:
                    continue
                shift = view_req.shift
                if data_col == SampleBatch.OBS:
                    shift -= 1

                batch[view_col] = self.buffers[data_col][
                    self.shift_before + shift:end + shift, agent_slot]

        batches = {}
        for agent_key, data in sample_batch_data.items():
            batches[agent_key] = SampleBatch(data)
        return batches

    def get_train_sample_batch_and_reset(
            self, view_reqs: Dict[str, ViewRequirement]) -> SampleBatch:
        """Returns the accumulated sample batch for this policy.

        This is usually called to collect samples for policy training.

        Args:
            view_reqs (Dict[str, ViewRequirement]): The view requirements
                dict to use for building the train batch.

        Returns:
            SampleBatch: The accumulated sample batch for this policy.
        """
        seq_lens_w_0s = [
            self.agent_key_to_timestep[k] - self.shift_before
            for k in self.slot_to_agent_key if k is not None
        ]
        # We have an agent-axis buffer "rollover" (the new SampleBatch will
        # be built from the last n agent records plus the first m agent
        # records in the buffer).
        if self.agent_slot_cursor < self.sample_batch_offset:
            rollover = -(self.num_agents - self.sample_batch_offset)
            seq_lens_w_0s = seq_lens_w_0s[rollover:] + seq_lens_w_0s[:rollover]
        first_zero_len = len(seq_lens_w_0s)
        if seq_lens_w_0s[-1] == 0:
            first_zero_len = seq_lens_w_0s.index(0)
            # Assert that all zeros lie at the end of the seq_lens array.
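            # E.g. seq_lens_w_0s = [5, 3, 0, 0] -> first_zero_len = 2. The
            # trailing zeros belong to agent slots that only hold an initial
            # obs (or a rolled-over prefix) so far; they must form a
            # contiguous tail for the slicing below to be valid.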
            assert all(seq_lens_w_0s[i] == 0
                       for i in range(first_zero_len, len(seq_lens_w_0s)))

        t_start = self.shift_before
        t_end = t_start + self.num_timesteps

        # The agent-slot cursor that points to the newest agent slot that
        # actually already has at least 1 timestep of data (thus excluding
        # just-rolled-over chunks, which only have the initial obs in them).
        valid_agent_cursor = \
            (self.agent_slot_cursor -
             (len(seq_lens_w_0s) - first_zero_len)) % self.num_agents

        # Construct the view dict.
        view = {}
        for view_col, view_req in view_reqs.items():
            data_col = view_req.data_col or view_col
            assert data_col in self.buffers
            # For OBS, indices must be shifted by -1.
            shift = view_req.shift
            shift += 0 if data_col != SampleBatch.OBS else -1
            # If the agent-slot cursor has been rolled over to the beginning
            # of the buffer, we have to copy here.
            if valid_agent_cursor < self.sample_batch_offset:
                time_slice = self.buffers[data_col][t_start + shift:t_end +
                                                    shift]
                one_ = time_slice[:, self.sample_batch_offset:]
                two_ = time_slice[:, :valid_agent_cursor]
                if torch and isinstance(time_slice, torch.Tensor):
                    view[view_col] = torch.cat([one_, two_], dim=1)
                else:
                    view[view_col] = np.concatenate([one_, two_], axis=1)
            else:
                view[view_col] = \
                    self.buffers[data_col][
                        t_start + shift:t_end + shift,
                        self.sample_batch_offset:valid_agent_cursor]

        # Copy all still-ongoing trajectories to new agent slots (including
        # the ones that just started (have seq_len=0)).
        new_chunk_args = []
        for i, seq_len in enumerate(seq_lens_w_0s):
            if seq_len < self.num_timesteps:
                agent_slot = (self.sample_batch_offset + i) % self.num_agents
                if not self.buffers[SampleBatch.DONES][
                        seq_len - 1 + self.shift_before][agent_slot]:
                    agent_key = self.slot_to_agent_key[agent_slot]
                    new_chunk_args.append(
                        (agent_slot, agent_key,
                         self.agent_key_to_timestep[agent_key]))
        # Cut out all 0 seq-lens.
        seq_lens = seq_lens_w_0s[:first_zero_len]

        batch = SampleBatch(
            view, _seq_lens=np.array(seq_lens), _time_major=self.time_major)

        # Reset everything for new data.
        self.postprocessed_agents = [False] * self.num_agents
        self.agent_key_to_slot.clear()
        self.agent_key_to_chunk_num.clear()
        self.slot_to_agent_key = [None] * self.num_agents
        self.agent_key_to_timestep.clear()
        self.timesteps_since_last_reset = 0
        self.forward_pass_size = 0
        self.sample_batch_offset = self.agent_slot_cursor
        for args in new_chunk_args:
            self._new_chunk_from(*args)

        return batch

    def _build_buffers(self, single_row: Dict[str, TensorType]) -> None:
        """Builds the internal data buffers based on a single given row.

        This may be called several times in the lifetime of this instance
        to add new columns to the buffer. Columns in `single_row` that
        already exist in the buffer will be ignored.

        Args:
            single_row (Dict[str, TensorType]): A single data row with one or
                more columns (str as key, np.ndarray|tensor as data) to be
                used as a template to build the pre-allocated buffer.
        """
        time_size = self.num_timesteps + self.shift_before + self.shift_after
        for col, data in single_row.items():
            if col in self.buffers:
                continue
            base_shape = (time_size, self.num_agents) if self.time_major \
                else (self.num_agents, time_size)
            # Python primitive -> np.array.
            if isinstance(data, (int, float, bool)):
                t_ = type(data)
                dtype = np.float32 if t_ == float else \
                    np.int32 if t_ == int else np.bool_
                self.buffers[col] = np.zeros(shape=base_shape, dtype=dtype)
            # np.ndarray, torch.Tensor, or tf.Tensor.
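            # Shape example (with made-up sizes): a float32 obs of shape
            # (4,) with num_timesteps=100, shift_before=1, shift_after=0 and
            # time_major=True gives time_size=101 and thus a pre-allocated
            # buffer of shape (101, num_agents, 4) below.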
            else:
                shape = base_shape + data.shape
                dtype = data.dtype
                if torch and isinstance(data, torch.Tensor):
                    self.buffers[col] = torch.zeros(
                        *shape, dtype=dtype, device=data.device)
                elif tf and isinstance(data, tf.Tensor):
                    self.buffers[col] = tf.zeros(shape=shape, dtype=dtype)
                else:
                    self.buffers[col] = np.zeros(shape=shape, dtype=dtype)

    def _next_agent_slot(self):
        """Starts a new agent slot at the end of the agent axis.

        Also makes sure the new slot is not taken yet.
        """
        self.agent_slot_cursor += 1
        if self.agent_slot_cursor >= self.num_agents:
            self.agent_slot_cursor = 0
        # Just make sure there is space in our buffer.
        assert self.slot_to_agent_key[self.agent_slot_cursor] is None

    def _new_chunk_from(self, agent_slot, agent_key, timestep):
        """Creates a new time-window (chunk) given an agent.

        The agent may already have an unfinished episode going on (in a
        previous chunk). The end of that previous chunk will be copied to the
        beginning of the new one for proper data-shift handling (e.g.
        PREV_ACTIONS/REWARDS).

        Args:
            agent_slot (int): The agent to start a new chunk for (from an
                ongoing episode (chunk)).
            agent_key (Tuple[AgentID, EpisodeID, int]): The internal key to
                identify an active agent in some episode.
            timestep (int): The timestep in the old chunk being continued.
        """
        new_agent_slot = self.agent_slot_cursor
        # Increase chunk num by 1.
        new_agent_key = agent_key[:2] + (agent_key[2] + 1, )
        # Copy the relevant timesteps at the end of the old chunk into the
        # new one.
        if self.time_major:
            for k in self.buffers.keys():
                self.buffers[k][0:self.shift_before, new_agent_slot] = \
                    self.buffers[k][
                        timestep - self.shift_before:timestep, agent_slot]
        else:
            for k in self.buffers.keys():
                self.buffers[k][new_agent_slot, 0:self.shift_before] = \
                    self.buffers[k][
                        agent_slot, timestep - self.shift_before:timestep]

        self.agent_key_to_slot[new_agent_key] = new_agent_slot
        self.agent_key_to_chunk_num[new_agent_key[:2]] = new_agent_key[2]
        self.slot_to_agent_key[new_agent_slot] = new_agent_key
        self._next_agent_slot()
        self.agent_key_to_timestep[new_agent_key] = self.shift_before

    def _add_to_next_inference_call(self, agent_key, env_id, agent_slot,
                                    timestep):
        """Registers given T and B (agent_slot) for get_inference_input_dict.

        Calling `get_inference_input_dict` will produce an input_dict (for
        Policy.compute_actions_from_input_dict) with all registered
        agent/time indices and then automatically reset the registry.

        Args:
            agent_key (Tuple[AgentID, EpisodeID, int]): The internal key to
                identify an active agent in some episode.
            env_id (EnvID): The env ID of the given agent.
            agent_slot (int): The agent_slot to register (B axis).
            timestep (int): The timestep to register (T axis).
        """
        idx = self.forward_pass_size
        self.forward_pass_index_to_agent_info[idx] = (agent_key[0],
                                                      agent_key[1], env_id)
        self.agent_key_to_forward_pass_index[agent_key[:2]] = idx
        if self.forward_pass_size == 0:
            self.forward_pass_indices[0].clear()
            self.forward_pass_indices[1].clear()
        self.forward_pass_indices[0].append(timestep)
        self.forward_pass_indices[1].append(agent_slot)
        self.forward_pass_size += 1

    def _reset_inference_call(self):
        """Resets indices for the next inference call.

        After calling this, new calls to `add_init_obs()` and
        `add_action_reward_next_obs()` will count for the next input_dict
        returned by `get_inference_input_dict()`.
        """
        self.forward_pass_size = 0
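

# Minimal usage sketch (hedged; not part of the collector itself): it runs a
# single agent through one collect cycle (init obs -> inference input_dict ->
# env steps -> train batch). It assumes `gym` is installed and that
# `ViewRequirement` accepts a `space` kwarg (with defaults for `data_col` and
# `shift`) in this RLlib version; both are assumptions, not guarantees.
if __name__ == "__main__":
    from gym.spaces import Box

    obs_space = Box(-1.0, 1.0, shape=(4, ), dtype=np.float32)
    action_space = Box(-1.0, 1.0, shape=(2, ), dtype=np.float32)

    # Inference only needs the latest obs; the train batch pulls all
    # collected columns (their buffers get built while stepping).
    inference_view_reqs = {
        SampleBatch.OBS: ViewRequirement(space=obs_space),
    }
    train_view_reqs = {
        SampleBatch.OBS: ViewRequirement(space=obs_space),
        SampleBatch.ACTIONS: ViewRequirement(space=action_space),
        SampleBatch.REWARDS: ViewRequirement(),
        SampleBatch.DONES: ViewRequirement(),
    }

    collector = _PerPolicySampleCollector(
        num_agents=2, num_timesteps=4, time_major=True)

    # Register the reset obs for agent 0 in episode 0 / env 0.
    collector.add_init_obs(
        episode_id=0, agent_id=0, env_id=0, chunk_num=0,
        init_obs=obs_space.sample())

    for t in range(4):
        # Normally a Policy's forward pass would consume this input_dict.
        input_dict = collector.get_inference_input_dict(inference_view_reqs)
        assert input_dict[SampleBatch.OBS].shape == (1, 4)
        done = t == 3
        # Fake step results stand in for a real policy/env here.
        collector.add_action_reward_next_obs(
            0, 0, 0, done, {
                SampleBatch.ACTIONS: action_space.sample(),
                SampleBatch.REWARDS: 1.0,
                SampleBatch.DONES: done,
                SampleBatch.NEXT_OBS: obs_space.sample(),
            })

    # Time-major train batch: e.g. the OBS column has shape (4, 1, 4) here
    # (T=4 timesteps, B=1 agent slot, 4 obs dims).
    train_batch = collector.get_train_sample_batch_and_reset(train_view_reqs)
    print(train_batch)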