import collections
import logging
import numpy as np
from typing import List, Any, Dict, Optional, TYPE_CHECKING, Union

from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils.annotations import Deprecated, DeveloperAPI
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.typing import PolicyID, AgentID
from ray.util.debug import log_once

if TYPE_CHECKING:
    from ray.rllib.agents.callbacks import DefaultCallbacks

logger = logging.getLogger(__name__)


def to_float_array(v: List[Any]) -> np.ndarray:
    arr = np.array(v)
    if arr.dtype == np.float64:
        return arr.astype(np.float32)  # save some memory
    return arr


@Deprecated(new="a child class of `SampleCollector`", error=False)
class SampleBatchBuilder:
    """Util to build a SampleBatch incrementally.

    For efficiency, SampleBatches hold values in column form (as arrays).
    However, it is useful to add data one row (dict) at a time.
    """

    _next_unroll_id = 0  # disambiguates unrolls within a single episode

    def __init__(self):
        self.buffers: Dict[str, List] = collections.defaultdict(list)
        self.count = 0

    def add_values(self, **values: Any) -> None:
        """Add the given dictionary (row) of values to this batch."""

        for k, v in values.items():
            self.buffers[k].append(v)
        self.count += 1

    def add_batch(self, batch: SampleBatch) -> None:
        """Add the given batch of values to this batch."""

        for k, column in batch.items():
            self.buffers[k].extend(column)
        self.count += batch.count

    def build_and_reset(self) -> SampleBatch:
        """Returns a sample batch including all previously added values."""

        batch = SampleBatch(
            {k: to_float_array(v)
             for k, v in self.buffers.items()})
        if SampleBatch.UNROLL_ID not in batch:
            batch[SampleBatch.UNROLL_ID] = np.repeat(
                SampleBatchBuilder._next_unroll_id, batch.count)
            SampleBatchBuilder._next_unroll_id += 1
        self.buffers.clear()
        self.count = 0
        return batch
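

# A minimal usage sketch for SampleBatchBuilder (illustrative only, not part
# of the original module; kept as a comment so importing this file has no
# side effects; the column names below are arbitrary examples):
#
#     builder = SampleBatchBuilder()
#     builder.add_values(obs=0, actions=1, rewards=1.0, dones=False)
#     builder.add_values(obs=1, actions=0, rewards=0.5, dones=True)
#     batch = builder.build_and_reset()  # SampleBatch with count == 2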


# Deprecated class: Use a child class of `SampleCollector` instead
# (which handles multi-agent setups as well).
@DeveloperAPI
class MultiAgentSampleBatchBuilder:
    """Util to build SampleBatches for each policy in a multi-agent env.

    Input data is per-agent, while output data is per-policy. There is an M:N
    mapping between agents and policies. We retain one local batch builder
    per agent. When an agent is done, its local batch is appended into the
    corresponding policy batch for the agent's policy.
    """
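
    # Illustrative sketch of the M:N agent-to-policy mapping described
    # above (agent and policy names here are hypothetical, not RLlib's):
    #
    #     policy_map = {"shared": shared_policy, "scout": scout_policy}
    #     # "agent_0", "agent_1" -> "shared"; "agent_2" -> "scout":
    #     # each agent's rows are merged into its policy's batch builder.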

    def __init__(self, policy_map: Dict[PolicyID, Policy],
                 clip_rewards: Union[bool, float],
                 callbacks: "DefaultCallbacks"):
        """Initialize a MultiAgentSampleBatchBuilder.

        Args:
            policy_map (Dict[str,Policy]): Maps policy ids to policy
                instances.
            clip_rewards (Union[bool,float]): Whether to clip rewards before
                postprocessing (at +/-1.0) or the actual value to +/- clip.
            callbacks (DefaultCallbacks): RLlib callbacks.
        """
        if log_once("MultiAgentSampleBatchBuilder"):
            deprecation_warning(
                old="MultiAgentSampleBatchBuilder", error=False)
        self.policy_map = policy_map
        self.clip_rewards = clip_rewards
        # Build the Policies' SampleBatchBuilders.
        self.policy_builders = {
            k: SampleBatchBuilder()
            for k in policy_map.keys()
        }
        # Whenever we observe a new agent, add a new SampleBatchBuilder for
        # this agent.
        self.agent_builders = {}
        # Internal agent-to-policy map.
        self.agent_to_policy = {}
        self.callbacks = callbacks
        # Number of "inference" steps taken in the environment,
        # regardless of the number of agents involved in each of these steps.
        self.count = 0

    def total(self) -> int:
        """Returns the total number of steps taken in the env (all agents).

        Returns:
            int: The number of steps taken in total in the environment over
                all agents.
        """

        return sum(a.count for a in self.agent_builders.values())

    def has_pending_agent_data(self) -> bool:
        """Returns whether there is pending unprocessed data.

        Returns:
            bool: True if there is at least one per-agent builder (with data
                in it).
        """

        return len(self.agent_builders) > 0

    @DeveloperAPI
    def add_values(self, agent_id: AgentID, policy_id: PolicyID,
                   **values: Any) -> None:
        """Add the given dictionary (row) of values to this batch.

        Args:
            agent_id (obj): Unique id for the agent we are adding values for.
            policy_id (obj): Unique id for policy controlling the agent.
            values (dict): Row of values to add for this agent.
        """

        if agent_id not in self.agent_builders:
            self.agent_builders[agent_id] = SampleBatchBuilder()
            self.agent_to_policy[agent_id] = policy_id

        # Include the current agent id for multi-agent algorithms.
        if agent_id != _DUMMY_AGENT_ID:
            values["agent_id"] = agent_id

        self.agent_builders[agent_id].add_values(**values)

    def postprocess_batch_so_far(
            self, episode: Optional[MultiAgentEpisode] = None) -> None:
        """Apply policy postprocessors to any unprocessed rows.

        This pushes the postprocessed per-agent batches onto the per-policy
        builders, clearing per-agent state.

        Args:
            episode (Optional[MultiAgentEpisode]): The Episode object that
                holds this MultiAgentSampleBatchBuilder object.
        """

        # Materialize the batches so far.
        pre_batches = {}
        for agent_id, builder in self.agent_builders.items():
            pre_batches[agent_id] = (
                self.policy_map[self.agent_to_policy[agent_id]],
                builder.build_and_reset())

        # Apply postprocessor.
        post_batches = {}
        if self.clip_rewards is True:
            for _, (_, pre_batch) in pre_batches.items():
                pre_batch["rewards"] = np.sign(pre_batch["rewards"])
        elif self.clip_rewards:
            for _, (_, pre_batch) in pre_batches.items():
                pre_batch["rewards"] = np.clip(
                    pre_batch["rewards"],
                    a_min=-self.clip_rewards,
                    a_max=self.clip_rewards)
        for agent_id, (_, pre_batch) in pre_batches.items():
            other_batches = pre_batches.copy()
            del other_batches[agent_id]
            policy = self.policy_map[self.agent_to_policy[agent_id]]
            if any(pre_batch["dones"][:-1]) or len(set(
                    pre_batch["eps_id"])) > 1:
                raise ValueError(
                    "Batches sent to postprocessing must only contain steps "
                    "from a single trajectory.", pre_batch)
            # Call the Policy's Exploration's postprocess method.
            post_batches[agent_id] = pre_batch
            if getattr(policy, "exploration", None) is not None:
                policy.exploration.postprocess_trajectory(
                    policy, post_batches[agent_id], policy.get_session())
            post_batches[agent_id] = policy.postprocess_trajectory(
                post_batches[agent_id], other_batches, episode)

        if log_once("after_post"):
            logger.info(
                "Trajectory fragment after postprocess_trajectory():\n\n{}\n".
                format(summarize(post_batches)))

        # Append into policy batches and reset.
        from ray.rllib.evaluation.rollout_worker import get_global_worker
        for agent_id, post_batch in sorted(post_batches.items()):
            self.callbacks.on_postprocess_trajectory(
                worker=get_global_worker(),
                episode=episode,
                agent_id=agent_id,
                policy_id=self.agent_to_policy[agent_id],
                policies=self.policy_map,
                postprocessed_batch=post_batch,
                original_batches=pre_batches)
            self.policy_builders[self.agent_to_policy[agent_id]].add_batch(
                post_batch)

        self.agent_builders.clear()
        self.agent_to_policy.clear()

    def check_missing_dones(self) -> None:
        """Raises an error if any agent's last collected row is not done."""
        for agent_id, builder in self.agent_builders.items():
            if builder.buffers["dones"][-1] is not True:
                raise ValueError(
                    "The environment terminated for all agents, but we still "
                    "don't have a last observation for "
                    "agent {} (policy {}). ".format(
                        agent_id, self.agent_to_policy[agent_id]) +
                    "Please ensure that you include the last observations "
                    "of all live agents when setting '__all__' done to True. "
                    "Alternatively, set no_done_at_end=True to allow this.")

    @DeveloperAPI
    def build_and_reset(self, episode: Optional[MultiAgentEpisode] = None
                        ) -> MultiAgentBatch:
        """Returns the accumulated sample batches for each policy.

        Any unprocessed rows will be first postprocessed with a policy
        postprocessor. The internal state of this builder will be reset.

        Args:
            episode (Optional[MultiAgentEpisode]): The Episode object that
                holds this MultiAgentSampleBatchBuilder object, or None.

        Returns:
            MultiAgentBatch: Returns the accumulated sample batches for each
                policy.
        """

        self.postprocess_batch_so_far(episode)
        policy_batches = {}
        for policy_id, builder in self.policy_builders.items():
            if builder.count > 0:
                policy_batches[policy_id] = builder.build_and_reset()
        old_count = self.count
        self.count = 0
        return MultiAgentBatch.wrap_as_needed(policy_batches, old_count)
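

# A hedged end-to-end usage sketch for MultiAgentSampleBatchBuilder
# (illustrative only and not part of the original module; `policy_map`,
# `episode` and the per-step column values are assumed to be supplied by
# the caller, e.g. a RolloutWorker's sampler, which also increments
# `builder.count` once per environment step):
#
#     builder = MultiAgentSampleBatchBuilder(
#         policy_map, clip_rewards=False, callbacks=DefaultCallbacks())
#     builder.add_values("agent_0", "policy_0", t=0, eps_id=0, obs=0,
#                        actions=1, rewards=1.0, dones=True)
#     builder.add_values("agent_1", "policy_0", t=0, eps_id=0, obs=3,
#                        actions=0, rewards=0.5, dones=True)
#     builder.count += 1  # one env step, regardless of how many agents acted
#     ma_batch = builder.build_and_reset(episode)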