import collections
import logging

import numpy as np

from ray.util.debug import log_once
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
from ray.rllib.utils.debug import summarize

logger = logging.getLogger(__name__)

def to_float_array(v):
    arr = np.array(v)
    if arr.dtype == np.float64:
        return arr.astype(np.float32)  # save some memory
    return arr

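# Illustrative note: np.array() on a list of Python floats defaults to
# float64, so e.g. to_float_array([1.0, 2.0]) comes back as a float32 array,
# while int or bool inputs pass through with their original dtype.
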
@PublicAPI
class SampleBatchBuilder:
    """Util to build a SampleBatch incrementally.

    For efficiency, SampleBatches hold values in column form (as arrays).
    However, it is useful to add data one row (dict) at a time.
    """

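    # Illustrative usage sketch (field names here are arbitrary examples,
    # not required keys):
    #
    #   builder = SampleBatchBuilder()
    #   builder.add_values(obs=[0.1, 0.2], actions=1, rewards=1.0)
    #   builder.add_values(obs=[0.3, 0.4], actions=0, rewards=0.5)
    #   batch = builder.build_and_reset()
    #   # -> SampleBatch with count == 2; float64 columns are cast to
    #   #    float32 and a SampleBatch.UNROLL_ID column is added.
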
    @PublicAPI
    def __init__(self):
        self.buffers = collections.defaultdict(list)
        self.count = 0
        self.unroll_id = 0  # disambiguates unrolls within a single episode

    @PublicAPI
    def add_values(self, **values):
        """Add the given dictionary (row) of values to this batch."""

        for k, v in values.items():
            self.buffers[k].append(v)
        self.count += 1

    @PublicAPI
    def add_batch(self, batch):
        """Add the given batch of values to this batch."""

        for k, column in batch.items():
            self.buffers[k].extend(column)
        self.count += batch.count

    @PublicAPI
    def build_and_reset(self):
        """Returns a sample batch including all previously added values."""

        batch = SampleBatch(
            {k: to_float_array(v)
             for k, v in self.buffers.items()})
        batch.data[SampleBatch.UNROLL_ID] = np.repeat(self.unroll_id,
                                                      batch.count)
        self.buffers.clear()
        self.count = 0
        self.unroll_id += 1
        return batch

@DeveloperAPI
class MultiAgentSampleBatchBuilder:
    """Util to build SampleBatches for each policy in a multi-agent env.

    Input data is per-agent, while output data is per-policy. There is an
    M:N mapping between agents and policies. We retain one local batch
    builder per agent. When an agent is done, its local batch is appended
    to the corresponding policy batch for the agent's policy.
    """

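    # Illustrative usage sketch. The ids "agent_1"/"pol_0" are hypothetical,
    # and the policy's own postprocessor may require additional fields beyond
    # those shown here:
    #
    #   builder = MultiAgentSampleBatchBuilder(
    #       policy_map, clip_rewards=False, postp_callback=None)
    #   builder.add_values("agent_1", "pol_0", obs=[0.1], rewards=1.0,
    #                      dones=True, eps_id=0)
    #   builder.count += 1  # env steps are counted manually by the caller
    #   ma_batch = builder.build_and_reset(episode=None)
    #   # -> MultiAgentBatch keyed by policy id (wrap_as_needed() may return
    #   #    a plain SampleBatch if only the default policy is present).
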
    def __init__(self, policy_map, clip_rewards, postp_callback):
        """Initialize a MultiAgentSampleBatchBuilder.

        Arguments:
            policy_map (dict): Maps policy ids to policy instances.
            clip_rewards (bool): Whether to clip rewards before
                postprocessing.
            postp_callback: function to call on each postprocessed batch.
        """

        self.policy_map = policy_map
        self.clip_rewards = clip_rewards
        self.policy_builders = {
            k: SampleBatchBuilder()
            for k in policy_map.keys()
        }
        self.agent_builders = {}
        self.agent_to_policy = {}
        self.postp_callback = postp_callback
        self.count = 0  # increment this manually

    def total(self):
        """Returns summed number of steps across all agent buffers."""

        return sum(a.count for a in self.agent_builders.values())

    def has_pending_agent_data(self):
        """Returns whether there is pending unprocessed data."""

        return len(self.agent_builders) > 0

    @DeveloperAPI
    def add_values(self, agent_id, policy_id, **values):
        """Add the given dictionary (row) of values to this batch.

        Arguments:
            agent_id (obj): Unique id for the agent we are adding values for.
            policy_id (obj): Unique id for policy controlling the agent.
            values (dict): Row of values to add for this agent.
        """

        if agent_id not in self.agent_builders:
            self.agent_builders[agent_id] = SampleBatchBuilder()
            self.agent_to_policy[agent_id] = policy_id
        builder = self.agent_builders[agent_id]
        builder.add_values(**values)

    def postprocess_batch_so_far(self, episode):
        """Apply policy postprocessors to any unprocessed rows.

        This pushes the postprocessed per-agent batches onto the per-policy
        builders, clearing per-agent state.

        Arguments:
            episode: current MultiAgentEpisode object or None
        """

        # Materialize the batches so far
        pre_batches = {}
        for agent_id, builder in self.agent_builders.items():
            pre_batches[agent_id] = (
                self.policy_map[self.agent_to_policy[agent_id]],
                builder.build_and_reset())

        # Apply postprocessor
        post_batches = {}
        if self.clip_rewards:
            for _, (_, pre_batch) in pre_batches.items():
                pre_batch["rewards"] = np.sign(pre_batch["rewards"])
        for agent_id, (_, pre_batch) in pre_batches.items():
            other_batches = pre_batches.copy()
            del other_batches[agent_id]
            policy = self.policy_map[self.agent_to_policy[agent_id]]
            if any(pre_batch["dones"][:-1]) or len(set(
                    pre_batch["eps_id"])) > 1:
                raise ValueError(
                    "Batches sent to postprocessing must only contain steps "
                    "from a single trajectory.", pre_batch)
            post_batches[agent_id] = policy.postprocess_trajectory(
                pre_batch, other_batches, episode)
            # Call the Policy's Exploration's postprocess method.
            policy.exploration.postprocess_trajectory(
                policy, post_batches[agent_id], getattr(policy, "_sess", None))

        if log_once("after_post"):
            logger.info(
                "Trajectory fragment after postprocess_trajectory():\n\n{}\n".
                format(summarize(post_batches)))

        # Append into policy batches and reset
        for agent_id, post_batch in sorted(post_batches.items()):
            self.policy_builders[self.agent_to_policy[agent_id]].add_batch(
                post_batch)
            if self.postp_callback:
                self.postp_callback({
                    "episode": episode,
                    "agent_id": agent_id,
                    "pre_batch": pre_batches[agent_id],
                    "post_batch": post_batch,
                    "all_pre_batches": pre_batches,
                })

        self.agent_builders.clear()
        self.agent_to_policy.clear()

    def check_missing_dones(self):
        """Checks that every pending per-agent batch ends with done=True.

        Raises a ValueError otherwise; called once the environment reports
        that it is done for all agents.
        """

        for agent_id, builder in self.agent_builders.items():
            if builder.buffers["dones"][-1] is not True:
                raise ValueError(
                    "The environment terminated for all agents, but we still "
                    "don't have a last observation for "
                    "agent {} (policy {}). ".format(
                        agent_id, self.agent_to_policy[agent_id]) +
                    "Please ensure that you include the last observations "
                    "of all live agents when setting '__all__' done to True. "
                    "Alternatively, set no_done_at_end=True to allow this.")

    @DeveloperAPI
    def build_and_reset(self, episode):
        """Returns the accumulated sample batches for each policy.

        Any unprocessed rows will be first postprocessed with a policy
        postprocessor. The internal state of this builder will be reset.

        Arguments:
            episode: current MultiAgentEpisode object or None
        """

        self.postprocess_batch_so_far(episode)
        policy_batches = {}
        for policy_id, builder in self.policy_builders.items():
            if builder.count > 0:
                policy_batches[policy_id] = builder.build_and_reset()
        old_count = self.count
        self.count = 0
        return MultiAgentBatch.wrap_as_needed(policy_batches, old_count)