[RLlib] Trajectory view API: Simple List Collector (on by default for PPO); LSTM-agnostic (#11056)

Sven Mika 2020-10-01 16:57:10 +02:00 committed by GitHub
parent 0d93b1de93
commit 36bda8432b
40 changed files with 1154 additions and 1173 deletions

View file

@@ -37,6 +37,9 @@ DEFAULT_CONFIG = with_common_config({
# Workers sample async. Note that this increases the effective
# rollout_fragment_length by up to 5x due to async buffering of batches.
"sample_async": True,
# Switch on Trajectory View API for A2/3C by default.
# NOTE: Only supported for PyTorch so far.
"_use_trajectory_view_api": True,
})
# __sphinx_doc_end__
# yapf: enable
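
For orientation, a minimal sketch (not part of this diff) of how the flag can be toggled from user code; the `A2CTrainer` import path and the CartPole env are assumed to be available in this Ray version:

import ray
from ray.rllib.agents.a3c import A2CTrainer

ray.init(ignore_reinit_error=True)
trainer = A2CTrainer(
    env="CartPole-v0",
    config={
        "framework": "torch",              # Trajectory View API is torch-only so far.
        "_use_trajectory_view_api": True,  # Already the default for A2/3C after this change.
        "num_workers": 1,
    })
print(trainer.train()["episode_reward_mean"])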

View file

@@ -1,8 +1,13 @@
import gym
from typing import Dict
import ray
from ray.rllib.evaluation.postprocessing import compute_advantages, \
Postprocessing
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.framework import try_import_torch
torch, nn = try_import_torch()
@@ -82,6 +87,42 @@ class ValueNetworkMixin:
return self.model.value_function()[0]
def view_requirements_fn(policy: Policy) -> Dict[str, ViewRequirement]:
"""Function defining the view requirements for training/postprocessing.
These go on top of the Policy's Model's own view requirements, which are
used for the action-computing forward passes.
Args:
policy (Policy): The Policy that requires the returned
ViewRequirements.
Returns:
Dict[str, ViewRequirement]: The Policy's view requirements.
"""
ret = {
# Next obs are needed for PPO postprocessing, but not in loss.
SampleBatch.NEXT_OBS: ViewRequirement(
SampleBatch.OBS, shift=1, used_for_training=False),
# Created during postprocessing.
Postprocessing.ADVANTAGES: ViewRequirement(shift=0),
Postprocessing.VALUE_TARGETS: ViewRequirement(shift=0),
# Needed for PPO's loss function.
SampleBatch.ACTION_DIST_INPUTS: ViewRequirement(shift=0),
SampleBatch.ACTION_LOGP: ViewRequirement(shift=0),
SampleBatch.VF_PREDS: ViewRequirement(shift=0),
}
# If policy is recurrent, have to add state_out for PPO postprocessing
# (calculating GAE from next-obs and last state-out).
if policy.is_recurrent():
init_state = policy.get_initial_state()
for i, s in enumerate(init_state):
ret["state_out_{}".format(i)] = ViewRequirement(
space=gym.spaces.Box(-1.0, 1.0, shape=(s.shape[0], )),
used_for_training=False)
return ret
A3CTorchPolicy = build_torch_policy(
name="A3CTorchPolicy",
get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
@@ -91,4 +132,6 @@ A3CTorchPolicy = build_torch_policy(
extra_action_out_fn=model_value_predictions,
extra_grad_process_fn=apply_grad_clipping,
optimizer_fn=torch_optimizer,
mixins=[ValueNetworkMixin])
mixins=[ValueNetworkMixin],
view_requirements_fn=view_requirements_fn,
)
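
To illustrate the shift semantics used in the dict above, a small numpy sketch (plain arrays standing in for RLlib's internal buffers; not the actual collector code):

import numpy as np

# ViewRequirement(SampleBatch.OBS, shift=1) means: the NEXT_OBS column at
# timestep t is simply the OBS column at timestep t + 1.
obs_buffer = np.array([0, 1, 2, 3, 4])   # OBS collected over 5 timesteps
next_obs_view = obs_buffer[1:]           # -> [1, 2, 3, 4]
# used_for_training=False marks the column as needed only for postprocessing
# (e.g. GAE), so it is not copied into the final train batch.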

View file

@@ -8,6 +8,7 @@ See `pg_[tf|torch]_policy.py` for the definition of the policy loss.
Detailed documentation: https://docs.ray.io/en/master/rllib-algorithms.html#pg
"""
import logging
from typing import Optional, Type
from ray.rllib.agents.trainer import with_common_config
@@ -17,6 +18,8 @@ from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import TrainerConfigDict
logger = logging.getLogger(__name__)
# yapf: disable
# __sphinx_doc_begin__
@@ -27,6 +30,9 @@ DEFAULT_CONFIG = with_common_config({
"num_workers": 0,
# Learning rate.
"lr": 0.0004,
# Switch on Trajectory View API for PG by default.
# NOTE: Only supported for PyTorch so far.
"_use_trajectory_view_api": True,
})
# __sphinx_doc_end__
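
A hedged alternative sketch using Tune instead of instantiating the trainer directly (stopping criterion and env are illustrative only):

from ray import tune

tune.run(
    "PG",
    stop={"training_iteration": 2},
    config={
        "env": "CartPole-v0",
        "framework": "torch",
        "_use_trajectory_view_api": True,  # default for PG after this change
    })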

View file

@@ -5,6 +5,7 @@ PyTorch policy class used for PG.
from typing import Dict, List, Type, Union
import ray
from ray.rllib.agents.a3c.a3c_torch_policy import view_requirements_fn
from ray.rllib.agents.pg.utils import post_process_advantages
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
@@ -77,4 +78,6 @@ PGTorchPolicy = build_torch_policy(
get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG,
loss_fn=pg_torch_loss,
stats_fn=pg_loss_stats,
postprocess_fn=post_process_advantages)
postprocess_fn=post_process_advantages,
view_requirements_fn=view_requirements_fn,
)

View file

@@ -72,6 +72,8 @@ DEFAULT_CONFIG = ppo.PPOTrainer.merge_trainer_configs(
"truncate_episodes": True,
# This is auto set based on sample batch size.
"train_batch_size": -1,
# Trajectory View API not supported yet for DD-PPO.
"_use_trajectory_view_api": False,
},
_allow_unknown_configs=True,
)

View file

@@ -89,6 +89,9 @@ DEFAULT_CONFIG = with_common_config({
# Whether to fake GPUs (using CPUs).
# Set this to True for debugging on non-GPU machines (set `num_gpus` > 0).
"_fake_gpus": False,
# Switch on Trajectory View API for PPO by default.
# NOTE: Only supported for PyTorch so far.
"_use_trajectory_view_api": True,
})
# __sphinx_doc_end__
@@ -126,7 +129,8 @@ def validate_config(config: TrainerConfigDict) -> None:
if config["batch_mode"] == "truncate_episodes" and not config["use_gae"]:
raise ValueError(
"Episode truncation is not supported without a value "
"function. Consider setting batch_mode=complete_episodes.")
"function (to estimate the return at the end of the truncated "
"trajectory). Consider setting batch_mode=complete_episodes.")
# Multi-gpu not supported for PyTorch and tf-eager.
if config["framework"] in ["tf2", "tfe", "torch"]:
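
For clarity, the config combinations this check distinguishes, written as plain dicts (standard PPO config keys; values illustrative):

bad = {"use_gae": False, "batch_mode": "truncate_episodes"}       # -> raises the ValueError above
ok = {"use_gae": False, "batch_mode": "complete_episodes"}        # full returns, no value bootstrap needed
also_ok = {"use_gae": True, "batch_mode": "truncate_episodes"}    # value function bootstraps truncated tails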

View file

@@ -7,7 +7,8 @@ import numpy as np
from typing import Dict, List, Type, Union
import ray
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping, \
view_requirements_fn
from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \
setup_config
from ray.rllib.evaluation.postprocessing import Postprocessing
@@ -18,7 +19,6 @@ from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import EntropyCoeffSchedule, \
LearningRateSchedule
from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import convert_to_torch_tensor, \
explained_variance, sequence_mask
@@ -255,34 +255,6 @@ def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
def training_view_requirements_fn(
policy: Policy) -> Dict[str, ViewRequirement]:
"""Function defining the view requirements for training the policy.
These go on top of the Policy's Model's own view requirements used for
action computing forward passes.
Args:
policy (Policy): The Policy that requires the returned
ViewRequirements.
Returns:
Dict[str, ViewRequirement]: The Policy's view requirements.
"""
return {
# Next obs are needed for PPO postprocessing.
SampleBatch.NEXT_OBS: ViewRequirement(SampleBatch.OBS, shift=1),
# VF preds are needed for the loss.
SampleBatch.VF_PREDS: ViewRequirement(shift=0),
# Needed for postprocessing.
SampleBatch.ACTION_DIST_INPUTS: ViewRequirement(shift=0),
SampleBatch.ACTION_LOGP: ViewRequirement(shift=0),
# Created during postprocessing.
Postprocessing.ADVANTAGES: ViewRequirement(shift=0),
Postprocessing.VALUE_TARGETS: ViewRequirement(shift=0),
}
# Build a child class of `TorchPolicy`, given the custom functions defined
# above.
PPOTorchPolicy = build_torch_policy(
@@ -299,5 +271,5 @@ PPOTorchPolicy = build_torch_policy(
LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
ValueNetworkMixin
],
training_view_requirements_fn=training_view_requirements_fn,
view_requirements_fn=view_requirements_fn,
)
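
Custom torch policies built via `build_torch_policy` can hook into the same (renamed) argument. A hedged sketch, with a placeholder loss and policy name that are illustrative only:

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.policy.view_requirement import ViewRequirement

def my_view_requirements_fn(policy):
    # Request only what the (placeholder) loss below actually needs.
    return {SampleBatch.VF_PREDS: ViewRequirement(shift=0)}

MyTorchPolicy = build_torch_policy(
    name="MyTorchPolicy",
    loss_fn=lambda policy, model, dist_class, train_batch: 0.0,  # placeholder loss
    view_requirements_fn=my_view_requirements_fn,
)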

View file

@@ -1047,9 +1047,10 @@ class Trainer(Trainable):
def _validate_config(config: PartialTrainerConfigDict):
if config.get("_use_trajectory_view_api") and \
config.get("framework") != "torch":
raise ValueError(
logger.info(
"`_use_trajectory_view_api` only supported for PyTorch so "
"far!")
"far! Will run w/o.")
config["_use_trajectory_view_api"] = False
elif not config.get("_use_trajectory_view_api") and \
config.get("model", {}).get("_time_major"):
raise ValueError("`model._time_major` only supported "
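
The new fallback behavior, restated as a self-contained sketch that mirrors the diff above (the real check lives in `Trainer._validate_config`):

import logging
logger = logging.getLogger(__name__)

def check_trajectory_view_flag(config: dict) -> None:
    # Non-torch frameworks: log and silently fall back instead of raising.
    if config.get("_use_trajectory_view_api") and config.get("framework") != "torch":
        logger.info("`_use_trajectory_view_api` only supported for PyTorch so "
                    "far! Will run w/o.")
        config["_use_trajectory_view_api"] = False

cfg = {"framework": "tf", "_use_trajectory_view_api": True}
check_trajectory_view_flag(cfg)
assert cfg["_use_trajectory_view_api"] is False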

View file

@@ -183,7 +183,8 @@ class BaseEnv:
reset the entire Env (i.e. all sub-envs).
Returns:
obs (dict|None): Resetted observation or None if not supported.
Optional[MultiAgentDict]: The reset (multi-agent) observation dict,
or None if reset is not supported.
"""
return None

View file

@@ -1,9 +1,10 @@
from abc import abstractmethod, ABCMeta
import logging
from typing import Dict, Optional
from typing import Dict, Union
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.utils.typing import AgentID, EpisodeID, PolicyID, \
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, \
TensorType
logger = logging.getLogger(__name__)
@@ -29,7 +30,7 @@ class _SampleCollector(metaclass=ABCMeta):
"""
@abstractmethod
def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID,
def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
policy_id: PolicyID, init_obs: TensorType) -> None:
"""Adds an initial obs (after reset) to this collector.
@@ -41,10 +42,11 @@ called for that same agent/episode-pair.
called for that same agent/episode-pair.
Args:
episode_id (EpisodeID): Unique id for the episode we are adding
values for.
episode (MultiAgentEpisode): The MultiAgentEpisode for which we
are adding an agent's initial observation.
agent_id (AgentID): Unique id for the agent we are adding
values for.
env_id (EnvID): The environment index (in a vectorized setup).
policy_id (PolicyID): Unique id for policy controlling the agent.
init_obs (TensorType): Initial observation (after env.reset()).
@@ -52,7 +54,7 @@ class _SampleCollector(metaclass=ABCMeta):
>>> obs = env.reset()
>>> collector.add_init_obs(12345, 0, "pol0", obs)
>>> obs, r, done, info = env.step(action)
>>> collector.add_action_reward_next_obs(12345, 0, "pol0", {
>>> collector.add_action_reward_next_obs(12345, 0, 0, "pol0", False, {
... "action": action, "obs": obs, "reward": r, "done": done
... })
"""
@@ -60,7 +62,8 @@ class _SampleCollector(metaclass=ABCMeta):
@abstractmethod
def add_action_reward_next_obs(self, episode_id: EpisodeID,
agent_id: AgentID, policy_id: PolicyID,
agent_id: AgentID, env_id: EnvID,
policy_id: PolicyID, agent_done: bool,
values: Dict[str, TensorType]) -> None:
"""Add the given dictionary (row) of values to this collector.
@@ -74,7 +77,10 @@ values for.
values for.
agent_id (AgentID): Unique id for the agent we are adding
values for.
env_id (EnvID): The environment index (in a vectorized setup).
policy_id (PolicyID): Unique id for policy controlling the agent.
agent_done (bool): Whether the given agent is done with its
trajectory (the multi-agent episode may still be ongoing).
values (Dict[str, TensorType]): Row of values to add for this
agent. This row must contain the keys SampleBatch.ACTIONS,
REWARDS, NEXT_OBS, and DONES.
@@ -83,12 +89,22 @@ class _SampleCollector(metaclass=ABCMeta):
>>> obs = env.reset()
>>> collector.add_init_obs(12345, 0, "pol0", obs)
>>> obs, r, done, info = env.step(action)
>>> collector.add_action_reward_next_obs(12345, 0, "pol0", {
>>> collector.add_action_reward_next_obs(12345, 0, 0, "pol0", False, {
... "action": action, "obs": obs, "reward": r, "done": done
... })
"""
raise NotImplementedError
@abstractmethod
def episode_step(self, episode_id: EpisodeID) -> None:
"""Increases the episode step counter (across all agents) by one.
Args:
episode_id (EpisodeID): Unique id for the episode we are stepping
through (across all agents in that episode).
"""
raise NotImplementedError
@abstractmethod
def total_env_steps(self) -> int:
"""Returns total number of steps taken in the env (sum of all agents).
@@ -126,19 +142,11 @@ raise NotImplementedError
raise NotImplementedError
@abstractmethod
def has_non_postprocessed_data(self) -> bool:
"""Returns whether there is pending, unprocessed data.
Returns:
bool: True if there is at least some data that has not been
postprocessed yet.
"""
raise NotImplementedError
@abstractmethod
def postprocess_trajectories_so_far(
self, episode: Optional[MultiAgentEpisode] = None) -> None:
"""Apply postprocessing to unprocessed data (in one or all episodes).
def postprocess_episode(self,
episode: MultiAgentEpisode,
is_done: bool = False,
check_dones: bool = False) -> None:
"""Postprocesses all agents' trajectories in a given episode.
Generates (single-trajectory) SampleBatches for all Policies/Agents and
calls Policy.postprocess_trajectory on each of these. Postprocessing
@@ -148,38 +156,46 @@ correctly added to the buffers.
correctly added to the buffers.
Args:
episode (Optional[MultiAgentEpisode]): The Episode object for which
to post-process data. If not provided, postprocess data for all
episodes.
episode (MultiAgentEpisode): The Episode object for which
to post-process data.
is_done (bool): Whether the given episode is actually terminated
(all agents are done).
check_dones (bool): Whether we need to check that all agents'
trajectories have dones=True at the end.
"""
raise NotImplementedError
@abstractmethod
def check_missing_dones(self, episode_id: EpisodeID) -> None:
"""Checks whether given episode is properly terminated with done=True.
This applies to all agents in the episode.
def build_multi_agent_batch(self, env_steps: int) -> \
Union[MultiAgentBatch, SampleBatch]:
"""Builds a MultiAgentBatch of size=env_steps from the collected data.
Args:
episode_id (EpisodeID): The episode ID to check for proper
termination.
env_steps (int): The sum of all env-steps (across all agents) taken
so far.
Raises:
ValueError: If `episode` has no done=True at the end.
Returns:
Union[MultiAgentBatch, SampleBatch]: Returns the accumulated
sample batches for each policy inside one MultiAgentBatch
object (or a simple SampleBatch if only one policy).
"""
raise NotImplementedError
@abstractmethod
def get_multi_agent_batch_and_reset(self):
"""Returns the accumulated sample batches for each policy.
def try_build_truncated_episode_multi_agent_batch(self) -> \
Union[MultiAgentBatch, SampleBatch, None]:
"""Tries to build an MA-batch, if `rollout_fragment_length` is reached.
Any unprocessed rows will be first postprocessed with a policy
postprocessor. The internal state of this builder will be reset to
start the next batch.
Any unprocessed data will be first postprocessed with a policy
postprocessor.
This is usually called to collect samples for policy training.
If not enough data has been collected yet (`rollout_fragment_length`),
returns None.
Returns:
MultiAgentBatch: Returns the accumulated sample batches for each
policy inside one MultiAgentBatch object.
Union[MultiAgentBatch, SampleBatch, None]: Returns the accumulated
sample batches for each policy inside one MultiAgentBatch
object (or a simple SampleBatch if only one policy) or None
if `self.rollout_fragment_length` has not been reached yet.
"""
raise NotImplementedError
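
To make the intended call protocol concrete, a toy single-agent stand-in for this interface; plain Python lists replace SampleBatches, method names follow the abstract API above, and everything else (IDs, values) is illustrative:

class ToyCollector:
    def __init__(self, rollout_fragment_length=2):
        self.rollout_fragment_length = rollout_fragment_length
        self.rows, self.steps = [], 0

    def add_init_obs(self, episode_id, agent_id, policy_id, init_obs):
        self.rows.append({"obs": init_obs})

    def add_action_reward_next_obs(self, episode_id, agent_id, env_id,
                                   policy_id, agent_done, values):
        self.rows.append(dict(values))

    def episode_step(self, episode_id):
        self.steps += 1

    def try_build_truncated_episode_multi_agent_batch(self):
        # Only emit a "batch" once enough env steps have accumulated.
        if self.steps < self.rollout_fragment_length:
            return None
        batch, self.rows, self.steps = list(self.rows), [], 0
        return batch

collector = ToyCollector(rollout_fragment_length=2)
collector.add_init_obs(12345, 0, "pol0", init_obs=0.0)
for t in range(4):
    collector.add_action_reward_next_obs(
        12345, 0, 0, "pol0", False,
        {"actions": 1, "rewards": 1.0, "new_obs": float(t + 1), "dones": False})
    collector.episode_step(12345)
    maybe_batch = collector.try_build_truncated_episode_multi_agent_batch()
    if maybe_batch is not None:
        print(len(maybe_batch), "rows ready for training")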

View file

@@ -0,0 +1,589 @@
import collections
import logging
import numpy as np
from typing import List, Any, Dict, Tuple, TYPE_CHECKING, Union
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.evaluation.collectors.sample_collector import _SampleCollector
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.typing import AgentID, EpisodeID, EnvID, PolicyID, \
TensorType
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.torch_ops import convert_to_non_torch_type
from ray.util.debug import log_once
_, tf, _ = try_import_tf()
torch, _ = try_import_torch()
if TYPE_CHECKING:
from ray.rllib.agents.callbacks import DefaultCallbacks
logger = logging.getLogger(__name__)
def to_float_np_array(v: List[Any]) -> np.ndarray:
if torch.is_tensor(v[0]):
raise ValueError
v = convert_to_non_torch_type(v)
arr = np.array(v)
if arr.dtype == np.float64:
return arr.astype(np.float32) # save some memory
return arr
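
A quick sanity check of the helper above (assumes torch is importable, since the helper calls `torch.is_tensor`):

arr = to_float_np_array([0.1, 0.2, 0.3])
assert arr.dtype == np.float32   # float64 inputs are downcast to save memory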
class _AgentCollector:
"""Collects samples for one agent in one trajectory (episode).
The agent may be part of a multi-agent environment. Samples are stored in
lists, with an optional automatic "shift" buffer at the beginning. This
saves memory when storing columns such as NEXT_OBS or PREV_REWARDS, which
are specified via the trajectory view API.
"""
_next_unroll_id = 0 # disambiguates unrolls within a single episode
def __init__(self, shift_before: int = 0):
self.shift_before = max(shift_before, 1)
self.buffers: Dict[str, List] = {}
# The simple timestep count for this agent. Gets increased by one
# each time a (non-initial!) observation is added.
self.count = 0
def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID,
env_id: EnvID, init_obs: TensorType,
view_requirements: Dict[str, ViewRequirement]) -> None:
"""Adds an initial observation (after reset) to the Agent's trajectory.
Args:
episode_id (EpisodeID): Unique ID for the episode we are adding the
initial observation for.
agent_id (AgentID): Unique ID for the agent we are adding the
initial observation for.
env_id (EnvID): The environment index (in a vectorized setup).
init_obs (TensorType): The initial observation tensor (after
`env.reset()`).
view_requirements (Dict[str, ViewRequirement]): The view requirements
dict of the policy that controls this agent.
"""
if SampleBatch.OBS not in self.buffers:
self._build_buffers(
single_row={
SampleBatch.OBS: init_obs,
SampleBatch.EPS_ID: episode_id,
SampleBatch.AGENT_INDEX: agent_id,
"env_id": env_id,
})
self.buffers[SampleBatch.OBS].append(init_obs)
def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \
None:
"""Adds the given dictionary (row) of values to the Agent's trajectory.
Args:
values (Dict[str, TensorType]): Data dict (interpreted as a single
row) to be added to buffer. Must contain keys:
SampleBatch.ACTIONS, REWARDS, DONES, and NEXT_OBS.
"""
assert SampleBatch.OBS not in values
values[SampleBatch.OBS] = values[SampleBatch.NEXT_OBS]
del values[SampleBatch.NEXT_OBS]
for k, v in values.items():
if k not in self.buffers:
self._build_buffers(single_row=values)
self.buffers[k].append(v)
self.count += 1
def build(self, view_requirements: Dict[str, ViewRequirement]) -> \
SampleBatch:
"""Builds a SampleBatch from the thus-far collected agent data.
If the episode/trajectory has no DONE=True at the end, will copy
the necessary n timesteps at the end of the trajectory back to the
beginning of the buffers and wait for new samples coming in.
SampleBatches created by this method will be ready for postprocessing
by a Policy.
Args:
view_requirements (Dict[str, ViewRequirement]): The view
requirements dict needed to build the SampleBatch from the raw
buffers (which may have data shifts as well as mappings from
view-col to data-col in them).
Returns:
SampleBatch: The built SampleBatch for this agent, ready to go into
postprocessing.
"""
# TODO: measure performance gains when using a UsageTrackingDict
# instead of a SampleBatch for postprocessing (this would eliminate
# copies (for creating this SampleBatch) of many unused columns for
# no reason (not used by postprocessor)).
batch_data = {}
np_data = {}
for view_col, view_req in view_requirements.items():
# Create the batch of data from the different buffers.
data_col = view_req.data_col or view_col
# Some columns don't exist yet (get created during postprocessing).
# -> skip.
if data_col not in self.buffers:
continue
shift = view_req.shift - \
(1 if data_col == SampleBatch.OBS else 0)
if data_col not in np_data:
np_data[data_col] = to_float_np_array(self.buffers[data_col])
if shift == 0:
batch_data[view_col] = np_data[data_col][self.shift_before:]
else:
batch_data[view_col] = np_data[data_col][self.shift_before +
shift:shift]
batch = SampleBatch(batch_data)
if SampleBatch.UNROLL_ID not in batch.data:
batch.data[SampleBatch.UNROLL_ID] = np.repeat(
_AgentCollector._next_unroll_id, batch.count)
_AgentCollector._next_unroll_id += 1
# This trajectory is continuing -> Copy data at the end (in the size of
# self.shift_before) to the beginning of buffers and erase everything
# else.
if not self.buffers[SampleBatch.DONES][-1]:
# Copy data to beginning of buffer and cut lists.
if self.shift_before > 0:
for k, data in self.buffers.items():
self.buffers[k] = data[-self.shift_before:]
self.count = 0
return batch
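
The shift arithmetic in `build()`, spelled out on a toy buffer (a numpy array standing in for `self.buffers[OBS]`, with `shift_before` assumed to be 1):

import numpy as np

shift_before = 1
obs_buffer = np.array([10, 11, 12, 13])   # [init_obs, obs_t1, obs_t2, obs_t3]

# View-col OBS (shift=0, data_col=OBS): shift -> 0 - 1 = -1
obs_view = obs_buffer[shift_before + (-1):(-1)]   # buffer[0:-1] -> [10, 11, 12]

# View-col NEXT_OBS (shift=1, data_col=OBS): shift -> 1 - 1 = 0
next_obs_view = obs_buffer[shift_before:]         # buffer[1:]  -> [11, 12, 13]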
def _build_buffers(self, single_row: Dict[str, TensorType]) -> None:
"""Builds the buffers for sample collection, given an example data row.
Args:
single_row (Dict[str, TensorType]): A single row (keys=column
names) of data to base the buffers on.
"""
for col, data in single_row.items():
if col in self.buffers:
continue
shift = self.shift_before - (1 if col == SampleBatch.OBS else 0)
# Python primitive.
if isinstance(data, (int, float, bool, str)):
self.buffers[col] = [0 for _ in range(shift)]
# np.ndarray, torch.Tensor, or tf.Tensor.
else:
shape = data.shape
dtype = data.dtype
if torch and isinstance(data, torch.Tensor):
self.buffers[col] = \
[torch.zeros(shape, dtype=dtype, device=data.device)
for _ in range(shift)]
elif tf and isinstance(data, tf.Tensor):
self.buffers[col] = \
[tf.zeros(shape=shape, dtype=dtype)
for _ in range(shift)]
else:
self.buffers[col] = \
[np.zeros(shape=shape, dtype=dtype)
for _ in range(shift)]
class _PolicyCollector:
"""Collects already postprocessed (single agent) samples for one policy.
Samples come in through already postprocessed SampleBatches, which
contain single episode/trajectory data for a single agent and are then
appended to this policy's buffers.
"""
def __init__(self):
"""Initializes a _PolicyCollector instance."""
self.buffers: Dict[str, List] = collections.defaultdict(list)
# The total timestep count for all agents that use this policy.
# NOTE: This is not an env-step count (across n agents). AgentA and
# agentB, both using this policy, acting in the same episode and both
# doing n steps would increase the count by 2*n.
self.count = 0
def add_postprocessed_batch_for_training(
self, batch: SampleBatch,
view_requirements: Dict[str, ViewRequirement]) -> None:
"""Adds a postprocessed SampleBatch (single agent) to our buffers.
Args:
batch (SampleBatch): A single agent (one trajectory) SampleBatch
to be added to the Policy's buffers.
view_requirements (Dict[str, ViewRequirement]): The view
requirements for the policy. This is so we know whether a
view-column needs to be copied at all (columns not used for
training are skipped).
"""
for view_col, data in batch.items():
# Skip columns that are not used for training.
if view_col in view_requirements and \
not view_requirements[view_col].used_for_training:
continue
self.buffers[view_col].extend(data)
# Add the agent's trajectory length to our count.
self.count += batch.count
def build(self):
"""Builds a SampleBatch for this policy from the collected data.
Also resets all buffers for further sample collection for this policy.
Returns:
SampleBatch: The SampleBatch with all thus-far collected data for
this policy.
"""
# Create batch from our buffers.
batch = SampleBatch(self.buffers)
assert SampleBatch.UNROLL_ID in batch.data
# Clear buffers for future samples.
self.buffers.clear()
# Reset count to 0.
self.count = 0
return batch
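
Conceptually, what `_PolicyCollector` does with incoming postprocessed agent batches, as a minimal dict/list sketch (not SampleBatch):

import collections

buffers = collections.defaultdict(list)
agent_a = {"obs": [1, 2], "advantages": [0.5, 0.7]}   # two steps from agent A
agent_b = {"obs": [3], "advantages": [0.1]}           # one step from agent B
for batch in (agent_a, agent_b):
    for col, data in batch.items():
        buffers[col].extend(data)
train_batch = dict(buffers)   # {'obs': [1, 2, 3], 'advantages': [0.5, 0.7, 0.1]}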
class _SimpleListCollector(_SampleCollector):
"""Util to build SampleBatches for each policy in a multi-agent env.
Input data is per-agent, while output data is per-policy. There is an M:N
mapping between agents and policies. We retain one local batch builder
per agent. When an agent is done, its local batch is appended to the
corresponding policy's batch.
"""
def __init__(self,
policy_map: Dict[PolicyID, Policy],
clip_rewards: Union[bool, float],
callbacks: "DefaultCallbacks",
multiple_episodes_in_batch: bool = True,
rollout_fragment_length: int = 200):
"""Initializes a _SimpleListCollector instance.
Args:
policy_map (Dict[str, Policy]): Maps policy ids to policy
instances.
clip_rewards (Union[bool, float]): True to clip rewards at +/-1.0
before postprocessing, or a float value to clip at +/- that value.
callbacks (DefaultCallbacks): RLlib callbacks.
multiple_episodes_in_batch (bool): Whether a built batch may
contain samples from more than one episode.
rollout_fragment_length (int): The rollout fragment length (in env
steps) after which a train batch can be built.
"""
self.policy_map = policy_map
self.clip_rewards = clip_rewards
self.callbacks = callbacks
self.multiple_episodes_in_batch = multiple_episodes_in_batch
self.rollout_fragment_length = rollout_fragment_length
self.large_batch_threshold: int = max(
1000, rollout_fragment_length *
10) if rollout_fragment_length != float("inf") else 5000
# Build one collector per Policy.
self.policy_collectors = {
pid: _PolicyCollector()
for pid in policy_map.keys()
}
self.policy_collectors_env_steps = 0
# Whenever we observe a new episode+agent combination, add a new
# _AgentCollector.
self.agent_collectors: Dict[Tuple[EpisodeID, AgentID],
_AgentCollector] = {}
# Internal agent-key-to-policy map.
self.agent_key_to_policy = {}
# Agents to collect data from for the next forward pass (per policy).
self.forward_pass_agent_keys = {pid: [] for pid in policy_map.keys()}
self.forward_pass_size = {pid: 0 for pid in policy_map.keys()}
# Maps episode ID to the episode's current (across-agents) env-step count
# and, below, to the MultiAgentEpisode object itself.
self.episode_steps: Dict[EpisodeID, int] = collections.defaultdict(int)
self.episodes: Dict[EpisodeID, MultiAgentEpisode] = {}
@override(_SampleCollector)
def episode_step(self, episode_id: EpisodeID) -> None:
self.episode_steps[episode_id] += 1
env_steps = \
self.policy_collectors_env_steps + self.episode_steps[episode_id]
if (env_steps > self.large_batch_threshold
and log_once("large_batch_warning")):
logger.warning(
"More than {} observations for {} env steps ".format(
env_steps, env_steps) +
"are buffered in the sampler. If this is more than you "
"expected, check that you set a horizon on your "
"environment correctly and that it terminates at some point. "
"Note: In multi-agent environments, `rollout_fragment_length` "
"sets the batch size based on (across-agents) environment "
"steps, not the steps of individual agents, which can result "
"in unexpectedly large batches. " +
("Also, you may be in evaluation, waiting for your Env to "
"terminate (batch_mode=`complete_episodes`). Make sure it "
"does at some point."
if not self.multiple_episodes_in_batch else ""))
@override(_SampleCollector)
def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
env_id: EnvID, policy_id: PolicyID,
init_obs: TensorType) -> None:
# Make sure our mappings are up to date.
agent_key = (episode.episode_id, agent_id)
if agent_key not in self.agent_key_to_policy:
self.agent_key_to_policy[agent_key] = policy_id
else:
assert self.agent_key_to_policy[agent_key] == policy_id
policy = self.policy_map[policy_id]
view_reqs = policy.model.inference_view_requirements if \
hasattr(policy, "model") else policy.view_requirements
# Add initial obs to Trajectory.
assert agent_key not in self.agent_collectors
# TODO: determine exact shift-before based on the view-req shifts.
self.agent_collectors[agent_key] = _AgentCollector()
self.agent_collectors[agent_key].add_init_obs(
episode_id=episode.episode_id,
agent_id=agent_id,
env_id=env_id,
init_obs=init_obs,
view_requirements=view_reqs)
self.episodes[episode.episode_id] = episode
self._add_to_next_inference_call(agent_key, env_id)
@override(_SampleCollector)
def add_action_reward_next_obs(self, episode_id: EpisodeID,
agent_id: AgentID, env_id: EnvID,
policy_id: PolicyID, agent_done: bool,
values: Dict[str, TensorType]) -> None:
# Make sure the episode/agent already has some (at least initial) data.
agent_key = (episode_id, agent_id)
assert self.agent_key_to_policy[agent_key] == policy_id
assert agent_key in self.agent_collectors
# Include the current agent id for multi-agent algorithms.
if agent_id != _DUMMY_AGENT_ID:
values["agent_id"] = agent_id
# Add action/reward/next-obs (and other data) to Trajectory.
self.agent_collectors[agent_key].add_action_reward_next_obs(values)
if not agent_done:
self._add_to_next_inference_call(agent_key, env_id)
@override(_SampleCollector)
def total_env_steps(self) -> int:
return sum(a.count for a in self.agent_collectors.values())
@override(_SampleCollector)
def get_inference_input_dict(self, policy_id: PolicyID) -> \
Dict[str, TensorType]:
policy = self.policy_map[policy_id]
keys = self.forward_pass_agent_keys[policy_id]
buffers = {k: self.agent_collectors[k].buffers for k in keys}
view_reqs = policy.model.inference_view_requirements if \
hasattr(policy, "model") else policy.view_requirements
input_dict = {}
for view_col, view_req in view_reqs.items():
# Create the batch of data from the different buffers.
data_col = view_req.data_col or view_col
time_indices = \
view_req.shift - (
1 if data_col in [SampleBatch.OBS, "t", "env_id",
SampleBatch.EPS_ID,
SampleBatch.AGENT_INDEX] else 0)
data_list = []
for k in keys:
if data_col not in buffers[k]:
self.agent_collectors[k]._build_buffers({
data_col: view_req.space.sample()
})
data_list.append(buffers[k][data_col][time_indices])
input_dict[view_col] = np.array(data_list)
self._reset_inference_calls(policy_id)
return input_dict
@override(_SampleCollector)
def postprocess_episode(self,
episode: MultiAgentEpisode,
is_done: bool = False,
check_dones: bool = False) -> None:
episode_id = episode.episode_id
# TODO: (sven) Once we implement multi-agent communication channels,
# we have to resolve the restriction of only sending other agent
# batches from the same policy to the postprocess methods.
# Build SampleBatches for the given episode.
pre_batches = {}
for (eps_id, agent_id), collector in self.agent_collectors.items():
# Build only if there is data and agent is part of given episode.
if collector.count == 0 or eps_id != episode_id:
continue
policy = self.policy_map[self.agent_key_to_policy[(eps_id,
agent_id)]]
pre_batch = collector.build(policy.view_requirements)
pre_batches[agent_id] = (policy, pre_batch)
# Apply postprocessor.
post_batches = {}
if self.clip_rewards is True:
for _, (_, pre_batch) in pre_batches.items():
pre_batch["rewards"] = np.sign(pre_batch["rewards"])
elif self.clip_rewards:
for _, (_, pre_batch) in pre_batches.items():
pre_batch["rewards"] = np.clip(
pre_batch["rewards"],
a_min=-self.clip_rewards,
a_max=self.clip_rewards)
for agent_id, (_, pre_batch) in pre_batches.items():
# Entire episode is said to be done.
if is_done:
# Error if no DONE at end of this agent's trajectory.
if check_dones and not pre_batch[SampleBatch.DONES][-1]:
raise ValueError(
"Episode {} terminated for all agents, but we still "
"don't have a last observation for agent {} (policy "
"{}). ".format(
episode_id, agent_id, self.agent_key_to_policy[(
episode_id, agent_id)]) +
"Please ensure that you include the last observations "
"of all live agents when setting done[__all__] to "
"True. Alternatively, set no_done_at_end=True to "
"allow this.")
# If (only this?) agent is done, erase its buffer entirely.
if pre_batch[SampleBatch.DONES][-1]:
del self.agent_collectors[(episode_id, agent_id)]
other_batches = pre_batches.copy()
del other_batches[agent_id]
policy = self.policy_map[self.agent_key_to_policy[(episode_id,
agent_id)]]
if any(pre_batch["dones"][:-1]) or len(set(
pre_batch["eps_id"])) > 1:
raise ValueError(
"Batches sent to postprocessing must only contain steps "
"from a single trajectory.", pre_batch)
# Call the Policy's Exploration's postprocess method.
post_batches[agent_id] = pre_batch
if getattr(policy, "exploration", None) is not None:
policy.exploration.postprocess_trajectory(
policy, post_batches[agent_id],
getattr(policy, "_sess", None))
post_batches[agent_id] = policy.postprocess_trajectory(
post_batches[agent_id], other_batches, episode)
if log_once("after_post"):
logger.info(
"Trajectory fragment after postprocess_trajectory():\n\n{}\n".
format(summarize(post_batches)))
# Append into policy batches and reset.
from ray.rllib.evaluation.rollout_worker import get_global_worker
for agent_id, post_batch in sorted(post_batches.items()):
pid = self.agent_key_to_policy[(episode_id, agent_id)]
policy = self.policy_map[pid]
self.callbacks.on_postprocess_trajectory(
worker=get_global_worker(),
episode=episode,
agent_id=agent_id,
policy_id=pid,
policies=self.policy_map,
postprocessed_batch=post_batch,
original_batches=pre_batches)
# Add the postprocessed SampleBatch to the policy collectors for
# training.
self.policy_collectors[pid].add_postprocessed_batch_for_training(
post_batch, policy.view_requirements)
env_steps = self.episode_steps[episode_id]
self.policy_collectors_env_steps += env_steps
if is_done:
del self.episode_steps[episode_id]
del self.episodes[episode_id]
else:
self.episode_steps[episode_id] = 0
@override(_SampleCollector)
def build_multi_agent_batch(self, env_steps: int) -> \
Union[MultiAgentBatch, SampleBatch]:
ma_batch = MultiAgentBatch.wrap_as_needed(
{
pid: collector.build()
for pid, collector in self.policy_collectors.items()
if collector.count > 0
},
env_steps=env_steps)
self.policy_collectors_env_steps = 0
return ma_batch
@override(_SampleCollector)
def try_build_truncated_episode_multi_agent_batch(self) -> \
Union[MultiAgentBatch, SampleBatch, None]:
# Have something to loop through, even if there are currently no
# ongoing episodes.
episode_steps = self.episode_steps or {"_fake_id": 0}
# Loop through ongoing episodes and see whether their length plus
# what's already in the policy collectors reaches the fragment-len.
for episode_id, count in episode_steps.items():
env_steps = self.policy_collectors_env_steps + count
# Reached the fragment-len -> We should build an MA-Batch.
if env_steps >= self.rollout_fragment_length:
# If we reached the fragment-len only because of `episode_id`
# (still ongoing) -> postprocess `episode_id` first.
if self.policy_collectors_env_steps < \
self.rollout_fragment_length:
self.postprocess_episode(
self.episodes[episode_id], is_done=False)
# Otherwise, create MA-batch only from what's already in our
# policy buffers (do not include `episode_id`'s data).
else:
env_steps = self.policy_collectors_env_steps
# Build the MA-batch and return.
ma_batch = self.build_multi_agent_batch(env_steps=env_steps)
return ma_batch
return None
def _add_to_next_inference_call(self, agent_key: Tuple[EpisodeID, AgentID],
env_id: EnvID) -> None:
"""Adds an Agent key (episode+agent IDs) to the next inference call.
This makes sure that the agent's current data (in the trajectory) is
used for generating the next input_dict for a
`Policy.compute_actions()` call.
Args:
agent_key (Tuple[EpisodeID, AgentID]: A unique agent key (across
vectorized environments).
env_id (EnvID): The environment index (in a vectorized setup).
"""
policy_id = self.agent_key_to_policy[agent_key]
idx = self.forward_pass_size[policy_id]
if idx == 0:
self.forward_pass_agent_keys[policy_id].clear()
self.forward_pass_agent_keys[policy_id].append(agent_key)
self.forward_pass_size[policy_id] += 1
def _reset_inference_calls(self, policy_id: PolicyID) -> None:
"""Resets internal inference input-dict registries.
Calling `self.get_inference_input_dict()` after this method is called
would return an empty input-dict.
Args:
policy_id (PolicyID): The policy ID for which to reset the
inference pointers.
"""
self.forward_pass_size[policy_id] = 0

View file

@@ -1,249 +0,0 @@
import logging
from typing import Dict, Optional, TYPE_CHECKING
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.per_policy_sample_collector import \
_PerPolicySampleCollector
from ray.rllib.evaluation.sample_collector import _SampleCollector
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import override
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, \
TensorType
from ray.util.debug import log_once
if TYPE_CHECKING:
from ray.rllib.agents.callbacks import DefaultCallbacks
logger = logging.getLogger(__name__)
class _MultiAgentSampleCollector(_SampleCollector):
"""Builds SampleBatches for each policy (and agent) in a multi-agent env.
Note: This is an experimental class only used when
`config._use_trajectory_view_api` = True.
Once `_use_trajectory_view_api` becomes the default in configs:
This class will deprecate the `SampleBatchBuilder` class.
Input data is collected in central per-policy buffers, which
efficiently pre-allocate memory (over n timesteps) and re-use the same
memory even for succeeding agents and episodes.
Input_dicts for action computations, SampleBatches for postprocessing, and
train_batch dicts are - if possible - created from the central per-policy
buffers via views, to avoid copying data.
"""
def __init__(
self,
policy_map: Dict[PolicyID, Policy],
callbacks: "DefaultCallbacks",
# TODO: (sven) make `num_agents` flexibly grow in size.
num_agents: int = 100,
num_timesteps=None,
time_major: Optional[bool] = False):
"""Initializes a _MultiAgentSampleCollector object.
Args:
policy_map (Dict[PolicyID,Policy]): Maps policy ids to policy
instances.
callbacks (DefaultCallbacks): RLlib callbacks (configured in the
Trainer config dict). Used for trajectory postprocessing event.
num_agents (int): The max number of agent slots to pre-allocate
in the buffer.
num_timesteps (int): The max number of timesteps to pre-allocate
in the buffer.
time_major (Optional[bool]): Whether to preallocate buffers and
collect samples in time-major fashion (TxBx...).
"""
self.policy_map = policy_map
self.callbacks = callbacks
if num_agents == float("inf") or num_agents is None:
num_agents = 1000
self.num_agents = int(num_agents)
# Collect SampleBatches per-policy in _PerPolicySampleCollectors.
self.policy_sample_collectors = {}
for pid, policy in policy_map.items():
# Figure out max-shifts (before and after).
view_reqs = policy.training_view_requirements
max_shift_before = 0
max_shift_after = 0
for vr in view_reqs.values():
shift = force_list(vr.shift)
if max_shift_before > shift[0]:
max_shift_before = shift[0]
if max_shift_after < shift[-1]:
max_shift_after = shift[-1]
# Figure out num_timesteps and num_agents.
kwargs = {"time_major": time_major}
if policy.is_recurrent():
kwargs["num_timesteps"] = \
policy.config["model"]["max_seq_len"]
kwargs["time_major"] = True
elif num_timesteps is not None:
kwargs["num_timesteps"] = num_timesteps
self.policy_sample_collectors[pid] = _PerPolicySampleCollector(
num_agents=self.num_agents,
shift_before=-max_shift_before,
shift_after=max_shift_after,
**kwargs)
# Internal agent-to-policy map.
self.agent_to_policy = {}
# Number of "inference" steps taken in the environment.
# Regardless of the number of agents involved in each of these steps.
self.count = 0
@override(_SampleCollector)
def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID,
env_id: EnvID, policy_id: PolicyID,
obs: TensorType) -> None:
# Make sure our mappings are up to date.
if agent_id not in self.agent_to_policy:
self.agent_to_policy[agent_id] = policy_id
else:
assert self.agent_to_policy[agent_id] == policy_id
# Add initial obs to Trajectory.
self.policy_sample_collectors[policy_id].add_init_obs(
episode_id, agent_id, env_id, chunk_num=0, init_obs=obs)
@override(_SampleCollector)
def add_action_reward_next_obs(self, episode_id: EpisodeID,
agent_id: AgentID, env_id: EnvID,
policy_id: PolicyID, agent_done: bool,
values: Dict[str, TensorType]) -> None:
assert policy_id in self.policy_sample_collectors
# Make sure our mappings are up to date.
if agent_id not in self.agent_to_policy:
self.agent_to_policy[agent_id] = policy_id
else:
assert self.agent_to_policy[agent_id] == policy_id
# Include the current agent id for multi-agent algorithms.
if agent_id != _DUMMY_AGENT_ID:
values["agent_id"] = agent_id
# Add action/reward/next-obs (and other data) to Trajectory.
self.policy_sample_collectors[policy_id].add_action_reward_next_obs(
episode_id, agent_id, env_id, agent_done, values)
@override(_SampleCollector)
def total_env_steps(self) -> int:
return sum(a.timesteps_since_last_reset
for a in self.policy_sample_collectors.values())
def total(self):
# TODO: (sven) deprecate; use `self.total_env_steps`, instead.
# Sampler is currently still using `total()`.
return self.total_env_steps()
@override(_SampleCollector)
def get_inference_input_dict(self, policy_id: PolicyID) -> \
Dict[str, TensorType]:
policy = self.policy_map[policy_id]
view_reqs = policy.model.inference_view_requirements
return self.policy_sample_collectors[
policy_id].get_inference_input_dict(view_reqs)
@override(_SampleCollector)
def has_non_postprocessed_data(self) -> bool:
return self.total_env_steps() > 0
@override(_SampleCollector)
def postprocess_trajectories_so_far(
self, episode: Optional[MultiAgentEpisode] = None) -> None:
# Loop through each per-policy collector and create a view (for each
# agent as SampleBatch) from its buffers for post-processing
all_agent_batches = {}
for pid, rc in self.policy_sample_collectors.items():
policy = self.policy_map[pid]
view_reqs = policy.training_view_requirements
agent_batches = rc.get_postprocessing_sample_batches(
episode, view_reqs)
for agent_key, batch in agent_batches.items():
other_batches = None
if len(agent_batches) > 1:
other_batches = agent_batches.copy()
del other_batches[agent_key]
agent_batches[agent_key] = policy.postprocess_trajectory(
batch, other_batches, episode)
# Call the Policy's Exploration's postprocess method.
if getattr(policy, "exploration", None) is not None:
agent_batches[
agent_key] = policy.exploration.postprocess_trajectory(
policy, agent_batches[agent_key],
getattr(policy, "_sess", None))
# Add new columns' data to buffers.
for col in agent_batches[agent_key].new_columns:
data = agent_batches[agent_key].data[col]
rc._build_buffers({col: data[0]})
timesteps = data.shape[0]
rc.buffers[col][rc.shift_before:rc.shift_before +
timesteps, rc.agent_key_to_slot[
agent_key]] = data
all_agent_batches.update(agent_batches)
if log_once("after_post"):
logger.info("Trajectory fragment after postprocess_trajectory():"
"\n\n{}\n".format(summarize(all_agent_batches)))
# Append into policy batches and reset
from ray.rllib.evaluation.rollout_worker import get_global_worker
for agent_key, batch in sorted(all_agent_batches.items()):
self.callbacks.on_postprocess_trajectory(
worker=get_global_worker(),
episode=episode,
agent_id=agent_key[0],
policy_id=self.agent_to_policy[agent_key[0]],
policies=self.policy_map,
postprocessed_batch=batch,
original_batches=None) # TODO: (sven) do we really need this?
@override(_SampleCollector)
def check_missing_dones(self, episode_id: EpisodeID) -> None:
for pid, rc in self.policy_sample_collectors.items():
for agent_key in rc.agent_key_to_slot.keys():
# Only check for given episode and only for last chunk
# (all previous chunks for that agent in the episode are
# non-terminal).
if (agent_key[1] == episode_id
and rc.agent_key_to_chunk_num[agent_key[:2]] ==
agent_key[2]):
t = rc.agent_key_to_timestep[agent_key] - 1
b = rc.agent_key_to_slot[agent_key]
if not rc.buffers["dones"][t][b]:
raise ValueError(
"Episode {} terminated for all agents, but we "
"still don't have a last observation for "
"agent {} (policy {}). ".format(agent_key[1], agent_key[0], pid)
+ "Please ensure that you include the last "
"observations of all live agents when setting "
"'__all__' done to True. Alternatively, set "
"no_done_at_end=True to allow this.")
@override(_SampleCollector)
def get_multi_agent_batch_and_reset(self):
self.postprocess_trajectories_so_far()
policy_batches = {}
for pid, rc in self.policy_sample_collectors.items():
policy = self.policy_map[pid]
view_reqs = policy.training_view_requirements
policy_batches[pid] = rc.get_train_sample_batch_and_reset(
view_reqs)
ma_batch = MultiAgentBatch.wrap_as_needed(policy_batches, self.count)
# Reset our across-all-agents env step count.
self.count = 0
return ma_batch

View file

@@ -1,499 +0,0 @@
import logging
import numpy as np
from typing import Dict, Optional
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, TensorType
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
logger = logging.getLogger(__name__)
class _PerPolicySampleCollector:
"""A class for efficiently collecting samples for a single (fixed) policy.
Can be used by a _MultiAgentSampleCollector for its different policies.
"""
def __init__(self,
num_agents: Optional[int] = None,
num_timesteps: Optional[int] = None,
time_major: bool = True,
shift_before: int = 0,
shift_after: int = 0):
"""Initializes a _PerPolicySampleCollector object.
Args:
num_agents (int): The max number of agent slots to pre-allocate
in the buffer.
num_timesteps (int): The max number of timesteps to pre-allocate
in the buffer.
time_major (Optional[bool]): Whether to preallocate buffers and
collect samples in time-major fashion (TxBx...).
shift_before (int): The additional number of time slots to
pre-allocate at the beginning of a time window (for possible
underlying data column shifts, e.g. PREV_ACTIONS).
shift_after (int): The additional number of time slots to
pre-allocate at the end of a time window (for possible
underlying data column shifts, e.g. NEXT_OBS).
"""
self.num_agents = num_agents or 100
self.num_timesteps = num_timesteps
self.time_major = time_major
# `shift_before` must be at least 1 for the initial-obs timestep.
self.shift_before = max(shift_before, 1)
self.shift_after = shift_after
# The offset on the agent dim to start the next SampleBatch build from.
self.sample_batch_offset = 0
# The actual underlying data-buffers.
self.buffers = {}
self.postprocessed_agents = [False] * self.num_agents
# Next agent-slot to be used by a new agent/env combination.
self.agent_slot_cursor = 0
# Maps agent/episode ID/chunk-num to an agent slot.
self.agent_key_to_slot = {}
# Maps agent/episode ID to the last chunk-num.
self.agent_key_to_chunk_num = {}
# Maps agent slot number to agent keys.
self.slot_to_agent_key = [None] * self.num_agents
# Maps agent/episode ID/chunk-num to a time step cursor.
self.agent_key_to_timestep = {}
# Total timesteps taken in the env over all agents since last reset.
self.timesteps_since_last_reset = 0
# Indices (T,B) to pick from the buffers for the next forward pass.
self.forward_pass_indices = [[], []]
self.forward_pass_size = 0
# Maps index from the forward pass batch to (agent_id, episode_id,
# env_id) tuple.
self.forward_pass_index_to_agent_info = {}
self.agent_key_to_forward_pass_index = {}
def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID,
env_id: EnvID, chunk_num: int,
init_obs: TensorType) -> None:
"""Adds a single initial observation (after env.reset()) to the buffer.
Args:
episode_id (EpisodeID): Unique ID for the episode we are adding the
initial observation for.
agent_id (AgentID): Unique ID for the agent we are adding the
initial observation for.
env_id (EnvID): The env ID to which `init_obs` belongs.
chunk_num (int): The time-chunk number (0-based). Some episodes
may last for longer than self.num_timesteps and therefore
have to be chopped into chunks.
init_obs (TensorType): Initial observation (after env.reset()).
"""
agent_key = (agent_id, episode_id, chunk_num)
agent_slot = self.agent_slot_cursor
self.agent_key_to_slot[agent_key] = agent_slot
self.agent_key_to_chunk_num[agent_key[:2]] = chunk_num
self.slot_to_agent_key[agent_slot] = agent_key
self._next_agent_slot()
if SampleBatch.OBS not in self.buffers:
self._build_buffers(
single_row={
SampleBatch.OBS: init_obs,
SampleBatch.EPS_ID: episode_id,
SampleBatch.AGENT_INDEX: agent_id,
"env_id": env_id,
})
if self.time_major:
self.buffers[SampleBatch.OBS][self.shift_before-1, agent_slot] = \
init_obs
else:
self.buffers[SampleBatch.OBS][agent_slot, self.shift_before-1] = \
init_obs
self.agent_key_to_timestep[agent_key] = self.shift_before
self._add_to_next_inference_call(agent_key, env_id, agent_slot,
self.shift_before - 1)
def add_action_reward_next_obs(
self, episode_id: EpisodeID, agent_id: AgentID, env_id: EnvID,
agent_done: bool, values: Dict[str, TensorType]) -> None:
"""Add the given dictionary (row) of values to this batch.
Args:
episode_id (EpisodeID): Unique ID for the episode we are adding the
values for.
agent_id (AgentID): Unique ID for the agent we are adding the
values for.
env_id (EnvID): The env ID to which the given data belongs.
agent_done (bool): Whether next obs should not be used for an
upcoming inference call. Default: False = next-obs should be
used for upcoming inference.
values (Dict[str, TensorType]): Data dict (interpreted as a single
row) to be added to buffer. Must contain keys:
SampleBatch.ACTIONS, REWARDS, DONES, and NEXT_OBS.
"""
assert (SampleBatch.ACTIONS in values and SampleBatch.REWARDS in values
and SampleBatch.NEXT_OBS in values
and SampleBatch.DONES in values)
assert SampleBatch.OBS not in values
values[SampleBatch.OBS] = values[SampleBatch.NEXT_OBS]
del values[SampleBatch.NEXT_OBS]
chunk_num = self.agent_key_to_chunk_num[(agent_id, episode_id)]
agent_key = (agent_id, episode_id, chunk_num)
agent_slot = self.agent_key_to_slot[agent_key]
ts = self.agent_key_to_timestep[agent_key]
for k, v in values.items():
if k not in self.buffers:
self._build_buffers(single_row=values)
if self.time_major:
self.buffers[k][ts, agent_slot] = v
else:
self.buffers[k][agent_slot, ts] = v
self.agent_key_to_timestep[agent_key] += 1
# Time-axis is "full" -> Cut-over to new chunk (only if not DONE).
if self.agent_key_to_timestep[
agent_key] - self.shift_before == self.num_timesteps and \
not values[SampleBatch.DONES]:
self._new_chunk_from(agent_slot, agent_key,
self.agent_key_to_timestep[agent_key])
self.timesteps_since_last_reset += 1
if not agent_done:
self._add_to_next_inference_call(agent_key, env_id, agent_slot, ts)
def get_inference_input_dict(self, view_reqs: Dict[str, ViewRequirement]
) -> Dict[str, TensorType]:
"""Returns an input_dict for an (inference) forward pass.
The input_dict can then be used for action computations inside a
Policy via `Policy.compute_actions_from_input_dict()`.
Args:
view_reqs (Dict[str, ViewRequirement]): The view requirements
dict to use.
Returns:
Dict[str, TensorType]: The input_dict to be passed into the ModelV2
for inference/training.
Examples:
>>> obs, r, done, info = env.step(action)
>>> collector.add_action_reward_next_obs(12345, 0, "pol0", {
... "action": action, "obs": obs, "reward": r, "done": done
... })
>>> input_dict = collector.get_inference_input_dict(policy.model)
>>> action = policy.compute_actions_from_input_dict(input_dict)
>>> # repeat
"""
input_dict = {}
for view_col, view_req in view_reqs.items():
# Create the batch of data from the different buffers.
data_col = view_req.data_col or view_col
if data_col not in self.buffers:
self._build_buffers({data_col: view_req.space.sample()})
indices = self.forward_pass_indices
if self.time_major:
input_dict[view_col] = self.buffers[data_col][indices]
else:
if isinstance(view_req.shift, (list, tuple)):
time_indices = \
np.array(view_req.shift) + np.array(indices[0])
input_dict[view_col] = self.buffers[data_col][indices[1],
time_indices]
else:
input_dict[view_col] = \
self.buffers[data_col][indices[1], indices[0]]
self._reset_inference_call()
return input_dict
def get_postprocessing_sample_batches(
self,
episode: MultiAgentEpisode,
view_reqs: Dict[str, ViewRequirement]) -> \
Dict[AgentID, SampleBatch]:
"""Returns a SampleBatch object ready for postprocessing.
Args:
episode (MultiAgentEpisode): The MultiAgentEpisode object to
get the to-be-postprocessed SampleBatches for.
view_reqs (Dict[str, ViewRequirement]): The view requirements dict
to use for creating the SampleBatch from our buffers.
Returns:
Dict[AgentID, SampleBatch]: The sample batch objects to be passed
to `Policy.postprocess_trajectory()`.
"""
# Loop through all agents and create a SampleBatch
# (as "view"; no copying).
# Construct the SampleBatch-dict.
sample_batch_data = {}
range_ = self.agent_slot_cursor - self.sample_batch_offset
if range_ < 0:
range_ = self.num_agents + range_
for i in range(range_):
agent_slot = self.sample_batch_offset + i
if agent_slot >= self.num_agents:
agent_slot = agent_slot % self.num_agents
# Do not postprocess the same slot twice.
if self.postprocessed_agents[agent_slot]:
continue
agent_key = self.slot_to_agent_key[agent_slot]
# Skip other episodes (if episode provided).
if episode and agent_key[1] != episode.episode_id:
continue
end = self.agent_key_to_timestep[agent_key]
# Do not build any empty SampleBatches.
if end == self.shift_before:
continue
self.postprocessed_agents[agent_slot] = True
assert agent_key not in sample_batch_data
sample_batch_data[agent_key] = {}
batch = sample_batch_data[agent_key]
for view_col, view_req in view_reqs.items():
data_col = view_req.data_col or view_col
# Skip columns that will only get added through postprocessing
# (these may not even exist yet).
if data_col not in self.buffers:
continue
shift = view_req.shift
if data_col == SampleBatch.OBS:
shift -= 1
batch[view_col] = self.buffers[data_col][
self.shift_before + shift:end + shift, agent_slot]
batches = {}
for agent_key, data in sample_batch_data.items():
batches[agent_key] = SampleBatch(data)
return batches
def get_train_sample_batch_and_reset(self, view_reqs) -> SampleBatch:
"""Returns the accumulated sample batch for this policy.
This is usually called to collect samples for policy training.
Returns:
SampleBatch: Returns the accumulated sample batch for this
policy.
"""
seq_lens_w_0s = [
self.agent_key_to_timestep[k] - self.shift_before
for k in self.slot_to_agent_key if k is not None
]
# We have an agent-axis buffer "rollover" (new SampleBatch will be
# built from last n agent records plus first m agent records in
# buffer).
if self.agent_slot_cursor < self.sample_batch_offset:
rollover = -(self.num_agents - self.sample_batch_offset)
seq_lens_w_0s = seq_lens_w_0s[rollover:] + seq_lens_w_0s[:rollover]
first_zero_len = len(seq_lens_w_0s)
if seq_lens_w_0s[-1] == 0:
first_zero_len = seq_lens_w_0s.index(0)
# Assert that all zeros lie at the end of the seq_lens array.
assert all(seq_lens_w_0s[i] == 0
for i in range(first_zero_len, len(seq_lens_w_0s)))
t_start = self.shift_before
t_end = t_start + self.num_timesteps
# The agent_slot cursor that points to the newest agent-slot that
# actually already has at least 1 timestep of data (thus it excludes
# just-rolled over chunks (which only have the initial obs in them)).
valid_agent_cursor = \
(self.agent_slot_cursor -
(len(seq_lens_w_0s) - first_zero_len)) % self.num_agents
# Construct the view dict.
view = {}
for view_col, view_req in view_reqs.items():
data_col = view_req.data_col or view_col
assert data_col in self.buffers
# For OBS, indices must be shifted by -1.
shift = view_req.shift
shift += 0 if data_col != SampleBatch.OBS else -1
# If agent_slot has been rolled-over to beginning, we have to copy
# here.
if valid_agent_cursor < self.sample_batch_offset:
time_slice = self.buffers[data_col][t_start + shift:t_end +
shift]
one_ = time_slice[:, self.sample_batch_offset:]
two_ = time_slice[:, :valid_agent_cursor]
if torch and isinstance(time_slice, torch.Tensor):
view[view_col] = torch.cat([one_, two_], dim=1)
else:
view[view_col] = np.concatenate([one_, two_], axis=1)
else:
view[view_col] = \
self.buffers[data_col][
t_start + shift:t_end + shift,
self.sample_batch_offset:valid_agent_cursor]
# Copy all still ongoing trajectories to new agent slots
# (including the ones that just started (are seq_len=0)).
new_chunk_args = []
for i, seq_len in enumerate(seq_lens_w_0s):
if seq_len < self.num_timesteps:
agent_slot = (self.sample_batch_offset + i) % self.num_agents
if not self.buffers[SampleBatch.
DONES][seq_len - 1 +
self.shift_before][agent_slot]:
agent_key = self.slot_to_agent_key[agent_slot]
new_chunk_args.append(
(agent_slot, agent_key,
self.agent_key_to_timestep[agent_key]))
# Cut out all 0 seq-lens.
seq_lens = seq_lens_w_0s[:first_zero_len]
batch = SampleBatch(
view, _seq_lens=np.array(seq_lens), _time_major=self.time_major)
# Reset everything for new data.
self.postprocessed_agents = [False] * self.num_agents
self.agent_key_to_slot.clear()
self.agent_key_to_chunk_num.clear()
self.slot_to_agent_key = [None] * self.num_agents
self.agent_key_to_timestep.clear()
self.timesteps_since_last_reset = 0
self.forward_pass_size = 0
self.sample_batch_offset = self.agent_slot_cursor
for args in new_chunk_args:
self._new_chunk_from(*args)
return batch
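# Illustrative, standalone sketch (not part of this class) of the agent-axis
# "rollover" handled above: when the newest agent slots have wrapped around to
# index 0, the tail and head of the agent axis are re-joined into one
# contiguous view. All names and sizes below are assumptions for illustration.
import numpy as np

time_slice = np.arange(24).reshape(4, 6)     # [T=4, B=6 agent slots]
sample_batch_offset = 4                      # oldest slot of the current batch
valid_agent_cursor = 2                       # newest filled slot (wrapped)
one_ = time_slice[:, sample_batch_offset:]   # slots 4..5 (old end of buffer)
two_ = time_slice[:, :valid_agent_cursor]    # slots 0..1 (wrapped beginning)
view = np.concatenate([one_, two_], axis=1)  # contiguous [T, 4] agent view
assert view.shape == (4, 4)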
def _build_buffers(self, single_row: Dict[str, TensorType]) -> None:
"""Builds the internal data buffers based on a single given row.
This may be called several times in the lifetime of this instance
to add new columns to the buffer. Columns in `single_row` that already
exist in the buffer will be ignored.
Args:
single_row (Dict[str, TensorType]): A single data row with one or
more columns (str as key, np.ndarray|tensor as data) to be used
as a template to build the pre-allocated buffer.
"""
time_size = self.num_timesteps + self.shift_before + self.shift_after
for col, data in single_row.items():
if col in self.buffers:
continue
base_shape = (time_size, self.num_agents) if self.time_major else \
(self.num_agents, time_size)
# Python primitive -> np.array.
if isinstance(data, (int, float, bool)):
t_ = type(data)
dtype = np.float32 if t_ == float else \
    np.int32 if t_ == int else np.bool_
self.buffers[col] = np.zeros(shape=base_shape, dtype=dtype)
# np.ndarray, torch.Tensor, or tf.Tensor.
else:
shape = base_shape + data.shape
dtype = data.dtype
if torch and isinstance(data, torch.Tensor):
self.buffers[col] = torch.zeros(
*shape, dtype=dtype, device=data.device)
elif tf and isinstance(data, tf.Tensor):
self.buffers[col] = tf.zeros(shape=shape, dtype=dtype)
else:
self.buffers[col] = np.zeros(shape=shape, dtype=dtype)
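# Illustrative, standalone sketch of the pre-allocation rule above: a single
# row's shape and dtype are appended to the (time, agent) base shape. The
# concrete sizes below are assumptions for illustration only.
import numpy as np

num_timesteps, shift_before, shift_after, num_agents = 100, 1, 0, 8
time_size = num_timesteps + shift_before + shift_after
single_row = {"obs": np.zeros(4, dtype=np.float32), "rewards": 1.0}
buffers = {}
for col, data in single_row.items():
    base_shape = (time_size, num_agents)  # time-major layout
    if isinstance(data, (int, float, bool)):
        # Python primitive -> one scalar per (t, agent) entry.
        buffers[col] = np.zeros(shape=base_shape, dtype=np.float32)
    else:
        # Array-like -> append its own shape to the base shape.
        buffers[col] = np.zeros(shape=base_shape + data.shape,
                                dtype=data.dtype)
assert buffers["obs"].shape == (101, 8, 4)
assert buffers["rewards"].shape == (101, 8)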
def _next_agent_slot(self):
"""Starts a new agent slot at the end of the agent-axis.
Also makes sure the new slot is not yet taken.
"""
self.agent_slot_cursor += 1
if self.agent_slot_cursor >= self.num_agents:
self.agent_slot_cursor = 0
# Just make sure there is space in our buffer.
assert self.slot_to_agent_key[self.agent_slot_cursor] is None
def _new_chunk_from(self, agent_slot, agent_key, timestep):
"""Creates a new time-window (chunk) given an agent.
The agent may already have an unfinished episode going on (in a
previous chunk). The end of that previous chunk will be copied to the
beginning of the new one for proper data-shift handling (e.g.
PREV_ACTIONS/REWARDS).
Args:
agent_slot (int): The buffer slot of the agent for which to start
a new chunk (continuing an ongoing episode).
agent_key (Tuple[AgentID, EpisodeID, int]): The internal key to
identify an active agent in some episode.
timestep (int): The timestep in the old chunk being continued.
"""
new_agent_slot = self.agent_slot_cursor
# Increase chunk num by 1.
new_agent_key = agent_key[:2] + (agent_key[2] + 1, )
# Copy relevant timesteps at end of old chunk into new one.
if self.time_major:
for k in self.buffers.keys():
self.buffers[k][0:self.shift_before, new_agent_slot] = \
self.buffers[k][
timestep - self.shift_before:timestep, agent_slot]
else:
for k in self.buffers.keys():
self.buffers[k][new_agent_slot, 0:self.shift_before] = \
self.buffers[k][
agent_slot, timestep - self.shift_before:timestep]
self.agent_key_to_slot[new_agent_key] = new_agent_slot
self.agent_key_to_chunk_num[new_agent_key[:2]] = new_agent_key[2]
self.slot_to_agent_key[new_agent_slot] = new_agent_key
self._next_agent_slot()
self.agent_key_to_timestep[new_agent_key] = self.shift_before
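# Illustrative, standalone sketch of the chunk-continuation copy above: the
# last `shift_before` rows of the old agent slot seed the new slot, so that
# shifted views (e.g. PREV_ACTIONS/REWARDS at shift=-1) stay valid across
# chunk boundaries. Sizes and slot indices are assumptions for illustration.
import numpy as np

buf = np.zeros((101, 8))    # [T, B] time-major buffer for one column
buf[:, 3] = np.arange(101)  # old agent slot 3 holds a full chunk of data
shift_before, timestep, new_slot = 1, 101, 4
buf[0:shift_before, new_slot] = buf[timestep - shift_before:timestep, 3]
assert buf[0, new_slot] == 100.0  # last old timestep seeds the new chunk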
def _add_to_next_inference_call(self, agent_key, env_id, agent_slot,
timestep):
"""Registers given T and B (agent_slot) for get_inference_input_dict.
Calling `get_inference_input_dict` will produce an input_dict (for
Policy.compute_actions_from_input_dict) with all registered agent/time
indices and then automatically reset the registry.
Args:
agent_key (Tuple[AgentID, EpisodeID, int]): The internal key to
identify an active agent in some episode.
env_id (EnvID): The env ID of the given agent.
agent_slot (int): The agent_slot to register (B axis).
timestep (int): The timestep to register (T axis).
"""
idx = self.forward_pass_size
self.forward_pass_index_to_agent_info[idx] = (agent_key[0],
agent_key[1], env_id)
self.agent_key_to_forward_pass_index[agent_key[:2]] = idx
if self.forward_pass_size == 0:
self.forward_pass_indices[0].clear()
self.forward_pass_indices[1].clear()
self.forward_pass_indices[0].append(timestep)
self.forward_pass_indices[1].append(agent_slot)
self.forward_pass_size += 1
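# Illustrative, standalone sketch of how the registered (T, B) index lists are
# later consumed: the next inference call gathers one row per registered agent
# via advanced indexing into the [T, B, ...] buffers. Shapes are assumptions.
import numpy as np

buffer = np.arange(101 * 8 * 4).reshape(101, 8, 4)  # [T, B, obs_dim]
forward_pass_indices = ([5, 7, 7], [0, 2, 3])       # (timesteps, agent_slots)
input_batch = buffer[forward_pass_indices]          # one row per agent
assert input_batch.shape == (3, 4)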
def _reset_inference_call(self):
"""Resets indices for the next inference call.
After calling this, new calls to `add_init_obs()` and
`add_action_reward_next_obs()` will count for the next input_dict
returned by `get_inference_input_dict()`.
"""
self.forward_pass_size = 0

View file

@ -9,13 +9,14 @@ from typing import Any, Callable, Dict, List, Iterable, Optional, Set, Tuple,\
TYPE_CHECKING, Union
from ray.util.debug import log_once
from ray.rllib.evaluation.collectors.sample_collector import \
_SampleCollector
from ray.rllib.evaluation.collectors.simple_list_collector import \
_SimpleListCollector
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.multi_agent_sample_collector import \
_MultiAgentSampleCollector
from ray.rllib.evaluation.rollout_metrics import RolloutMetrics
from ray.rllib.evaluation.sample_batch_builder import \
MultiAgentSampleBatchBuilder
from ray.rllib.evaluation.sample_collector import _SampleCollector
from ray.rllib.policy.policy import clip_action, Policy
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.models.preprocessors import Preprocessor
@ -188,8 +189,9 @@ class SyncSampler(SamplerInput):
self.extra_batches = queue.Queue()
self.perf_stats = _PerfStats()
if _use_trajectory_view_api:
self.sample_collector = _MultiAgentSampleCollector(
policies, callbacks)
self.sample_collector = _SimpleListCollector(
policies, clip_rewards, callbacks, multiple_episodes_in_batch,
rollout_fragment_length)
else:
self.sample_collector = None
@ -333,8 +335,9 @@ class AsyncSampler(threading.Thread, SamplerInput):
self.observation_fn = observation_fn
self._use_trajectory_view_api = _use_trajectory_view_api
if _use_trajectory_view_api:
self.sample_collector = _MultiAgentSampleCollector(
policies, callbacks)
self.sample_collector = _SimpleListCollector(
policies, clip_rewards, callbacks, multiple_episodes_in_batch,
rollout_fragment_length)
else:
self.sample_collector = None
@ -537,7 +540,6 @@ def _env_runner(
active_episodes: Dict[str, MultiAgentEpisode] = \
NewEpisodeDefaultDict(new_episode)
eval_results = None
while True:
perf_stats.iters += 1
@ -564,7 +566,6 @@ def _env_runner(
base_env=base_env,
policies=policies,
active_episodes=active_episodes,
prev_policy_outputs=eval_results,
unfiltered_obs=unfiltered_obs,
rewards=rewards,
dones=dones,
@ -572,13 +573,11 @@ def _env_runner(
horizon=horizon,
preprocessors=preprocessors,
obs_filters=obs_filters,
rollout_fragment_length=rollout_fragment_length,
multiple_episodes_in_batch=multiple_episodes_in_batch,
callbacks=callbacks,
soft_horizon=soft_horizon,
no_done_at_end=no_done_at_end,
observation_fn=observation_fn,
perf_stats=perf_stats,
_sample_collector=_sample_collector,
)
else:
@ -601,7 +600,6 @@ def _env_runner(
soft_horizon=soft_horizon,
no_done_at_end=no_done_at_end,
observation_fn=observation_fn,
perf_stats=perf_stats,
)
perf_stats.raw_obs_processing_time += time.time() - t1
for o in outputs:
@ -669,7 +667,6 @@ def _process_observations(
soft_horizon: bool,
no_done_at_end: bool,
observation_fn: "ObservationFunction",
perf_stats: _PerfStats,
) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[
RolloutMetrics, SampleBatchType]]]:
"""Record new data from the environment and prepare for policy evaluation.
@ -682,8 +679,6 @@ def _process_observations(
SampleBatchBuilder object for recycling.
active_episodes (Dict[str, MultiAgentEpisode]): Mapping from
episode ID to currently ongoing MultiAgentEpisode object.
prev_policy_outputs (Dict[str,List]): The prev policy output dict
(by policy-id -> List[action, state outs, extra fetches]).
unfiltered_obs (dict): Doubly keyed dict of env-ids -> agent ids
-> unfiltered observation tensor, returned by a `BaseEnv.poll()`
call.
@ -862,6 +857,7 @@ def _process_observations(
# and add it to "outputs".
if (all_agents_done and not multiple_episodes_in_batch) or \
batch_builder.count >= rollout_fragment_length:
batch_builder.postprocess_batch_so_far(episode)
outputs.append(batch_builder.build_and_reset(episode))
# Make sure postprocessor stays within one episode.
elif all_agents_done:
@ -887,10 +883,14 @@ def _process_observations(
episode=episode,
env_index=env_id,
)
# Horizon hit and we have a soft horizon (no hard env reset).
if hit_horizon and soft_horizon:
episode.soft_reset()
resetted_obs: Dict[AgentID, EnvObsType] = agent_obs
# Env actually ended OR horizon hit and no soft horizon ->
# Try hard env-reset.
else:
# Remove episode from active ones.
del active_episodes[env_id]
resetted_obs: Dict[AgentID, EnvObsType] = base_env.try_reset(
env_id)
@ -939,8 +939,6 @@ def _process_observations_w_trajectory_view_api(
base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
active_episodes: Dict[str, MultiAgentEpisode],
prev_policy_outputs: Dict[PolicyID, Tuple[TensorStructType, StateBatch,
dict]],
unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
rewards: Dict[EnvID, Dict[AgentID, float]],
dones: Dict[EnvID, Dict[AgentID, bool]],
@ -948,13 +946,11 @@ def _process_observations_w_trajectory_view_api(
horizon: int,
preprocessors: Dict[PolicyID, Preprocessor],
obs_filters: Dict[PolicyID, Filter],
rollout_fragment_length: int,
multiple_episodes_in_batch: bool,
callbacks: "DefaultCallbacks",
soft_horizon: bool,
no_done_at_end: bool,
observation_fn: "ObservationFunction",
perf_stats: _PerfStats,
_sample_collector: _SampleCollector,
) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[
RolloutMetrics, SampleBatchType]]]:
@ -964,41 +960,20 @@ def _process_observations_w_trajectory_view_api(
# Output objects.
active_envs: Set[EnvID] = set()
to_eval: Set[PolicyID] = set()
to_eval: Dict[PolicyID, List[PolicyEvalData]] = defaultdict(list)
outputs: List[Union[RolloutMetrics, SampleBatchType]] = []
large_batch_threshold: int = max(1000, rollout_fragment_length * 10) if \
rollout_fragment_length != float("inf") else 5000
# For each environment.
# For each (vectorized) sub-environment.
# type: EnvID, Dict[AgentID, EnvObsType]
for env_id, agent_obs in unfiltered_obs.items():
for env_id, all_agents_obs in unfiltered_obs.items():
is_new_episode: bool = env_id not in active_episodes
episode: MultiAgentEpisode = active_episodes[env_id]
if not is_new_episode:
_sample_collector.episode_step(episode.episode_id)
episode.length += 1
_sample_collector.count += 1
episode._add_agent_rewards(rewards[env_id])
if (_sample_collector.total_env_steps() > large_batch_threshold
and log_once("large_batch_warning")):
logger.warning(
"More than {} observations for {} env steps ".format(
_sample_collector.total_env_steps(),
_sample_collector.count) +
"are buffered in the sampler. If this is more than you "
"expected, check that that you set a horizon on your "
"environment correctly and that it terminates at some point. "
"Note: In multi-agent environments, `rollout_fragment_length` "
"sets the batch size based on (across-agents) environment "
"steps, not the steps of individual agents, which can result "
"in unexpectedly large batches." +
("Also, you may be in evaluation waiting for your Env to "
"terminate (batch_mode=`complete_episodes`). Make sure it "
"does at some point."
if not multiple_episodes_in_batch else ""))
# Check episode termination conditions.
if dones[env_id]["__all__"] or episode.length >= horizon:
hit_horizon = (episode.length >= horizon
@ -1023,19 +998,19 @@ def _process_observations_w_trajectory_view_api(
# Custom observation function is applied before preprocessing.
if observation_fn:
agent_obs: Dict[AgentID, EnvObsType] = observation_fn(
agent_obs=agent_obs,
all_agents_obs: Dict[AgentID, EnvObsType] = observation_fn(
agent_obs=all_agents_obs,
worker=worker,
base_env=base_env,
policies=policies,
episode=episode)
if not isinstance(agent_obs, dict):
if not isinstance(all_agents_obs, dict):
raise ValueError(
"observe() must return a dict of agent observations")
# For each agent in the environment.
# type: AgentID, EnvObsType
for agent_id, raw_obs in agent_obs.items():
for agent_id, raw_obs in all_agents_obs.items():
assert agent_id != "__all__"
policy_id: PolicyID = episode.policy_for(agent_id)
prep_obs: EnvObsType = _get_or_raise(preprocessors,
@ -1058,38 +1033,41 @@ def _process_observations_w_trajectory_view_api(
# Record transition info if applicable.
if last_observation is None:
_sample_collector.add_init_obs(episode.episode_id, agent_id,
env_id, policy_id, filtered_obs)
_sample_collector.add_init_obs(episode, agent_id, env_id,
policy_id, filtered_obs)
else:
rc = _sample_collector.policy_sample_collectors[policy_id]
eval_idx = rc.agent_key_to_forward_pass_index[(
agent_id, episode.episode_id)]
# Add actions, rewards, next-obs to collectors.
values_dict = {
"t": episode.length - 1,
"eps_id": episode.episode_id,
"env_id": env_id,
"agent_index": episode._agent_index(agent_id),
# Action (slot 0) taken at timestep t.
"actions": prev_policy_outputs[policy_id][0][eval_idx],
"actions": episode.last_action_for(agent_id),
# Reward received after taking a at timestep t.
"rewards": rewards[env_id][agent_id],
# After taking a, did we reach terminal?
# After taking action=a, did we reach terminal?
"dones": (False if (no_done_at_end
or (hit_horizon and soft_horizon)) else
agent_done),
# Next observation.
"new_obs": filtered_obs,
}
# TODO: (sven) add env infos to buffers as well.
for k, v in prev_policy_outputs[policy_id][2].items():
values_dict[k] = v[eval_idx]
for i, v in enumerate(prev_policy_outputs[policy_id][1]):
values_dict["state_out_{}".format(i)] = v[eval_idx]
# Add extra-action-fetches to collectors.
values_dict.update(**episode.last_pi_info_for(agent_id))
_sample_collector.add_action_reward_next_obs(
episode.episode_id, agent_id, env_id, policy_id,
agent_done, values_dict)
if not agent_done:
to_eval.add(policy_id)
item = PolicyEvalData(
env_id, agent_id, filtered_obs, infos[env_id].get(
agent_id, {}), None if last_observation is None else
episode.rnn_state_for(agent_id), None
if last_observation is None else
episode.last_action_for(agent_id),
rewards[env_id][agent_id] or 0.0)
to_eval[policy_id].append(item)
# Invoke the step callback after the step is logged to the episode
callbacks.on_episode_step(
@ -1098,36 +1076,22 @@ def _process_observations_w_trajectory_view_api(
episode=episode,
env_index=env_id)
# Cut the batch if ...
# - all-agents-done and not packing multiple episodes into one
# (batch_mode="complete_episodes")
# - or if we've exceeded the rollout_fragment_length.
if _sample_collector.has_non_postprocessed_data():
# Sanity check: if done[__all__] is True, all agents should have
# done=True as well.
if dones[env_id]["__all__"] and not no_done_at_end:
_sample_collector.check_missing_dones(
episode_id=episode.episode_id)
# Reached end of episode and we are not allowed to pack the
# next episode into the same SampleBatch -> Build the SampleBatch
# and add it to "outputs".
if (all_agents_done and not multiple_episodes_in_batch) or \
_sample_collector.count >= rollout_fragment_length:
# TODO: (sven) Case: rollout_fragment_length reached: Do not
# store any data in `episode` anymore
# (useless for get_view_requirements when t<<-1, e.g.
# attention), but keep last episode data around in
# SampleBatchBuilder
# to be able to still reference into it
# should a model require this.
outputs.append(_sample_collector.get_multi_agent_batch_and_reset())
# Episode is done for all agents
# (dones[__all__] == True or hit horizon).
# Make sure postprocessor stays within one episode.
elif all_agents_done:
_sample_collector.postprocess_trajectories_so_far(episode)
# Episode is done.
if all_agents_done:
is_done = dones[env_id]["__all__"]
check_dones = is_done and not no_done_at_end
_sample_collector.postprocess_episode(
episode, is_done=is_done, check_dones=check_dones)
# We are not allowed to pack the next episode into the same
# SampleBatch (batch_mode=complete_episodes) -> Build the
# MultiAgentBatch from a single episode and add it to "outputs".
if not multiple_episodes_in_batch:
ma_sample_batch = \
_sample_collector.build_multi_agent_batch(episode.length)
outputs.append(ma_sample_batch)
# Call each policy's Exploration.on_episode_end method.
for p in policies.values():
if getattr(p, "exploration", None) is not None:
@ -1144,44 +1108,56 @@ def _process_observations_w_trajectory_view_api(
episode=episode,
env_index=env_id,
)
# Horizon hit and we have a soft horizon (no hard env reset).
if hit_horizon and soft_horizon:
episode.soft_reset()
resetted_obs: Dict[AgentID, EnvObsType] = agent_obs
resetted_obs: Dict[AgentID, EnvObsType] = all_agents_obs
else:
del active_episodes[env_id]
resetted_obs: Dict[AgentID, EnvObsType] = base_env.try_reset(
env_id)
if resetted_obs is None:
# Reset not supported, drop this env from the ready list.
if resetted_obs is None:
if horizon != float("inf"):
raise ValueError(
"Setting episode horizon requires reset() support "
"from the environment.")
elif resetted_obs != ASYNC_RESET_RETURN:
# Creates a new episode if this is not async return.
# If reset is async, we will get its result in some future poll
episode: MultiAgentEpisode = active_episodes[env_id]
# If reset is async, we will get its result in some future poll.
elif resetted_obs != ASYNC_RESET_RETURN:
new_episode: MultiAgentEpisode = active_episodes[env_id]
if observation_fn:
resetted_obs: Dict[AgentID, EnvObsType] = observation_fn(
agent_obs=resetted_obs,
worker=worker,
base_env=base_env,
policies=policies,
episode=episode)
episode=new_episode)
# type: AgentID, EnvObsType
for agent_id, raw_obs in resetted_obs.items():
policy_id: PolicyID = episode.policy_for(agent_id)
policy_id: PolicyID = new_episode.policy_for(agent_id)
prep_obs: EnvObsType = _get_or_raise(
preprocessors, policy_id).transform(raw_obs)
filtered_obs: EnvObsType = _get_or_raise(
obs_filters, policy_id)(prep_obs)
episode._set_last_observation(agent_id, filtered_obs)
new_episode._set_last_observation(agent_id, filtered_obs)
# Add initial obs to buffer.
_sample_collector.add_init_obs(episode.episode_id,
agent_id, env_id, policy_id,
filtered_obs)
to_eval.add(policy_id)
_sample_collector.add_init_obs(
new_episode, agent_id, env_id, policy_id, filtered_obs)
item = PolicyEvalData(
env_id, agent_id, filtered_obs,
episode.last_info_for(agent_id) or {},
episode.rnn_state_for(agent_id), None, 0.0)
to_eval[policy_id].append(item)
# Try to build something.
if multiple_episodes_in_batch:
sample_batch = \
_sample_collector.try_build_truncated_episode_multi_agent_batch()
if sample_batch is not None:
outputs.append(sample_batch)
return active_envs, to_eval, outputs
@ -1306,7 +1282,7 @@ def _do_policy_eval_w_trajectory_view_api(
logger.info("Inputs to compute_actions():\n\n{}\n".format(
summarize(to_eval)))
for policy_id in to_eval:
for policy_id in to_eval.keys():
policy: Policy = _get_or_raise(policies, policy_id)
input_dict = _sample_collector.get_inference_input_dict(policy_id)
eval_results[policy_id] = \
@ -1373,7 +1349,7 @@ def _process_policy_eval_results(
actions_to_send[env_id] = {} # at minimum send empty dict
# type: PolicyID, List[PolicyEvalData]
for policy_id in to_eval:
for policy_id, eval_data in to_eval.items():
actions: TensorStructType = eval_results[policy_id][0]
actions = convert_to_numpy(actions)
@ -1385,10 +1361,11 @@ def _process_policy_eval_results(
if isinstance(actions, list):
actions = np.array(actions)
# Add RNN state info.
eval_data = None
if not _use_trajectory_view_api:
eval_data = to_eval[policy_id]
# Store RNN state ins/outs and extra-action fetches to episode.
if _use_trajectory_view_api:
for f_i, column in enumerate(rnn_out_cols):
pi_info_cols["state_out_{}".format(f_i)] = column
else:
rnn_in_cols: StateBatch = _to_column_format(
[t.rnn_state for t in eval_data])
@ -1413,14 +1390,6 @@ def _process_policy_eval_results(
else:
clipped_action = action
# Trajectory View API: Do not store data directly in episode
# (entire episode is stored in Trajectory and kept until
# end of episode).
if _use_trajectory_view_api:
agent_id, episode_id, env_id = \
_sample_collector.policy_sample_collectors[
policy_id].forward_pass_index_to_agent_info[i]
else:
env_id: int = eval_data[i].env_id
agent_id: AgentID = eval_data[i].agent_id
episode: MultiAgentEpisode = active_episodes[env_id]
@ -1430,8 +1399,8 @@ def _process_policy_eval_results(
for k, v in pi_info_cols.items()})
if env_id in off_policy_actions and \
agent_id in off_policy_actions[env_id]:
episode._set_last_action(
agent_id, off_policy_actions[env_id][agent_id])
episode._set_last_action(agent_id,
off_policy_actions[env_id][agent_id])
else:
episode._set_last_action(agent_id, action)

View file

@ -9,8 +9,9 @@ from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.examples.policy.episode_env_aware_policy import \
EpisodeEnvAwarePolicy
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.test_utils import framework_iterator
from ray.rllib.utils.test_utils import framework_iterator, check
class TestTrajectoryViewAPI(unittest.TestCase):
@ -30,9 +31,9 @@ class TestTrajectoryViewAPI(unittest.TestCase):
trainer = ppo.PPOTrainer(config, env="CartPole-v0")
policy = trainer.get_policy()
view_req_model = policy.model.inference_view_requirements
view_req_policy = policy.training_view_requirements
assert len(view_req_model) == 1
assert len(view_req_policy) == 10
view_req_policy = policy.view_requirements
assert len(view_req_model) == 1, view_req_model
assert len(view_req_policy) == 11, view_req_policy
for key in [
SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS,
SampleBatch.DONES, SampleBatch.NEXT_OBS,
@ -62,9 +63,9 @@ class TestTrajectoryViewAPI(unittest.TestCase):
trainer = ppo.PPOTrainer(config, env="CartPole-v0")
policy = trainer.get_policy()
view_req_model = policy.model.inference_view_requirements
view_req_policy = policy.training_view_requirements
assert len(view_req_model) == 7 # obs, prev_a, prev_r, 4xstates
assert len(view_req_policy) == 16
view_req_policy = policy.view_requirements
assert len(view_req_model) == 7, view_req_model
assert len(view_req_policy) == 17, view_req_policy
for key in [
SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS,
SampleBatch.DONES, SampleBatch.NEXT_OBS,
@ -90,7 +91,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
assert view_req_policy[key].shift == 1
trainer.stop()
def test_traj_view_lstm_performance(self):
def test_traj_view_simple_performance(self):
"""Test whether PPOTrainer runs faster w/ `_use_trajectory_view_api`.
"""
config = copy.deepcopy(ppo.DEFAULT_CONFIG)
@ -102,17 +103,15 @@ class TestTrajectoryViewAPI(unittest.TestCase):
from ray.tune import register_env
register_env("ma_env", lambda c: RandomMultiAgentEnv({
"num_agents": 2,
"p_done": 0.01,
"p_done": 0.0,
"max_episode_len": 104,
"action_space": action_space,
"observation_space": obs_space
}))
config["num_workers"] = 3
config["num_envs_per_worker"] = 8
config["num_sgd_iter"] = 6
config["model"]["use_lstm"] = True
config["model"]["lstm_use_prev_action_reward"] = True
config["model"]["max_seq_len"] = 100
config["num_sgd_iter"] = 1 # Put less weight on training.
policies = {
"pol0": (None, obs_space, action_space, {}),
@ -125,72 +124,80 @@ class TestTrajectoryViewAPI(unittest.TestCase):
"policies": policies,
"policy_mapping_fn": policy_fn,
}
num_iterations = 1
num_iterations = 2
# Only works in torch so far.
for _ in framework_iterator(config, frameworks="torch"):
print("w/ traj. view API (and time-major)")
print("w/ traj. view API")
config["_use_trajectory_view_api"] = True
config["model"]["_time_major"] = True
trainer = ppo.PPOTrainer(config=config, env="ma_env")
learn_time_w = 0.0
sampler_perf = {}
sampler_perf_w = {}
start = time.time()
for i in range(num_iterations):
out = trainer.train()
ts = out["timesteps_total"]
sampler_perf_ = out["sampler_perf"]
sampler_perf = {
k: sampler_perf.get(k, 0.0) + sampler_perf_[k]
sampler_perf_w = {
k:
sampler_perf_w.get(k, 0.0) + (sampler_perf_[k] * 1000 / ts)
for k, v in sampler_perf_.items()
}
delta = out["timers"]["learn_time_ms"] / 1000
delta = out["timers"]["learn_time_ms"] / ts
learn_time_w += delta
print("{}={}s".format(i, delta))
sampler_perf = {
k: sampler_perf[k] / (num_iterations if "mean_" in k else 1)
for k, v in sampler_perf.items()
sampler_perf_w = {
k: sampler_perf_w[k] / (num_iterations if "mean_" in k else 1)
for k, v in sampler_perf_w.items()
}
duration_w = time.time() - start
print("Duration: {}s "
"sampler-perf.={} learn-time/iter={}s".format(
duration_w, sampler_perf, learn_time_w / num_iterations))
duration_w, sampler_perf_w,
learn_time_w / num_iterations))
trainer.stop()
print("w/o traj. view API (and w/o time-major)")
print("w/o traj. view API")
config["_use_trajectory_view_api"] = False
config["model"]["_time_major"] = False
trainer = ppo.PPOTrainer(config=config, env="ma_env")
learn_time_wo = 0.0
sampler_perf = {}
sampler_perf_wo = {}
start = time.time()
for i in range(num_iterations):
out = trainer.train()
ts = out["timesteps_total"]
sampler_perf_ = out["sampler_perf"]
sampler_perf = {
k: sampler_perf.get(k, 0.0) + sampler_perf_[k]
sampler_perf_wo = {
k: sampler_perf_wo.get(k, 0.0) +
(sampler_perf_[k] * 1000 / ts)
for k, v in sampler_perf_.items()
}
delta = out["timers"]["learn_time_ms"] / 1000
delta = out["timers"]["learn_time_ms"] / ts
learn_time_wo += delta
print("{}={}s".format(i, delta))
sampler_perf = {
k: sampler_perf[k] / (num_iterations if "mean_" in k else 1)
for k, v in sampler_perf.items()
sampler_perf_wo = {
k: sampler_perf_wo[k] / (num_iterations if "mean_" in k else 1)
for k, v in sampler_perf_wo.items()
}
duration_wo = time.time() - start
print("Duration: {}s "
"sampler-perf.={} learn-time/iter={}s".format(
duration_wo, sampler_perf,
duration_wo, sampler_perf_wo,
learn_time_wo / num_iterations))
trainer.stop()
# Assert `_use_trajectory_view_api` is much faster.
# Assert `_use_trajectory_view_api` is faster.
self.assertLess(sampler_perf_w["mean_raw_obs_processing_ms"],
sampler_perf_wo["mean_raw_obs_processing_ms"])
self.assertLess(sampler_perf_w["mean_action_processing_ms"],
sampler_perf_wo["mean_action_processing_ms"])
self.assertLess(duration_w, duration_wo)
self.assertLess(learn_time_w, learn_time_wo * 0.6)
def test_traj_view_lstm_functionality(self):
action_space = Box(-float("inf"), float("inf"), shape=(2, ))
action_space = Box(-float("inf"), float("inf"), shape=(3, ))
obs_space = Box(float("-inf"), float("inf"), (4, ))
max_seq_len = 50
rollout_fragment_length = 200
assert rollout_fragment_length % max_seq_len == 0
policies = {
"pol0": (EpisodeEnvAwarePolicy, obs_space, action_space, {}),
}
@ -198,77 +205,162 @@ class TestTrajectoryViewAPI(unittest.TestCase):
def policy_fn(agent_id):
return "pol0"
rollout_worker = RolloutWorker(
env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
policy_config={
config = {
"multiagent": {
"policies": policies,
"policy_mapping_fn": policy_fn,
},
"_use_trajectory_view_api": True,
"model": {
"use_lstm": True,
"_time_major": True,
"max_seq_len": max_seq_len,
},
},
rollout_worker_w_api = RolloutWorker(
env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
policy_config=dict(config, **{"_use_trajectory_view_api": True}),
rollout_fragment_length=rollout_fragment_length,
policy=policies,
policy_mapping_fn=policy_fn,
num_envs=1,
)
for i in range(100):
pc = rollout_worker.sampler.sample_collector. \
policy_sample_collectors["pol0"]
sample_batch_offset_before = pc.sample_batch_offset
buffers = pc.buffers
result = rollout_worker.sample()
pol_batch = result.policy_batches["pol0"]
rollout_worker_wo_api = RolloutWorker(
env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
policy_config=dict(config, **{"_use_trajectory_view_api": False}),
rollout_fragment_length=rollout_fragment_length,
policy=policies,
policy_mapping_fn=policy_fn,
num_envs=1,
)
for iteration in range(20):
result = rollout_worker_w_api.sample()
check(result.count, rollout_fragment_length)
pol_batch_w = result.policy_batches["pol0"]
assert pol_batch_w.count >= rollout_fragment_length
analyze_rnn_batch(pol_batch_w, max_seq_len)
result = rollout_worker_wo_api.sample()
pol_batch_wo = result.policy_batches["pol0"]
check(pol_batch_w.data, pol_batch_wo.data)
def analyze_rnn_batch(batch, max_seq_len):
count = batch.count
self.assertTrue(result.count == 100)
self.assertTrue(pol_batch.count >= 100)
self.assertFalse(0 in pol_batch.seq_lens)
# Check prev_reward/action, next_obs consistency.
for t in range(max_seq_len):
obs_t = pol_batch["obs"][t]
r_t = pol_batch["rewards"][t]
if t > 0:
next_obs_t_m_1 = pol_batch["new_obs"][t - 1]
self.assertTrue((obs_t == next_obs_t_m_1).all())
if t < max_seq_len - 1:
prev_rewards_t_p_1 = pol_batch["prev_rewards"][t + 1]
self.assertTrue((r_t == prev_rewards_t_p_1).all())
for idx in range(count):
# If timestep tracked by batch, good.
if "t" in batch:
ts = batch["t"][idx]
# Else, infer ts from the obs (the 4th obs component stores the
# timestep).
else:
ts = batch["obs"][idx][3]
obs_t = batch["obs"][idx]
a_t = batch["actions"][idx]
r_t = batch["rewards"][idx]
state_in_0 = batch["state_in_0"][idx]
state_in_1 = batch["state_in_1"][idx]
# Check the sanity of all the buffers in the underlying
# PerPolicy collector.
for sample_batch_slot, agent_slot in enumerate(
range(sample_batch_offset_before, pc.sample_batch_offset)):
t_buf = buffers["t"][:, agent_slot]
obs_buf = buffers["obs"][:, agent_slot]
# Skip empty seqs at end (these won't be part of the batch
# and have been copied to new agent-slots (even if seq-len=0)).
if sample_batch_slot < len(pol_batch.seq_lens):
seq_len = pol_batch.seq_lens[sample_batch_slot]
# Make sure timesteps are always increasing within the seq.
assert all(t_buf[1] + j == n + 1
for j, n in enumerate(t_buf)
if j < seq_len and j != 0)
# Make sure all obs within seq are non-0.0.
assert all(
any(obs_buf[j] != 0.0) for j in range(1, seq_len + 1))
# Check postprocessing outputs.
if "postprocessed_column" in batch:
postprocessed_col_t = batch["postprocessed_column"][idx]
assert (obs_t == postprocessed_col_t / 2.0).all()
# Check seq-lens.
for agent_slot, seq_len in enumerate(pol_batch.seq_lens):
if seq_len < max_seq_len - 1:
# At least in the beginning, the next slots should always
# be empty (once all agent slots have been used once, these
# may be filled with "old" values (from longer sequences)).
if i < 10:
self.assertTrue(
(pol_batch["obs"][seq_len +
1][agent_slot] == 0.0).all())
print(end="")
self.assertFalse(
(pol_batch["obs"][seq_len][agent_slot] == 0.0).all())
# Check state-in/out and next-obs values.
if idx > 0:
next_obs_t_m_1 = batch["new_obs"][idx - 1]
state_out_0_t_m_1 = batch["state_out_0"][idx - 1]
state_out_1_t_m_1 = batch["state_out_1"][idx - 1]
# Same trajectory as for t-1 -> Should be able to match.
if (batch[SampleBatch.AGENT_INDEX][idx] ==
batch[SampleBatch.AGENT_INDEX][idx - 1]
and batch[SampleBatch.EPS_ID][idx] ==
batch[SampleBatch.EPS_ID][idx - 1]):
assert batch["unroll_id"][idx - 1] == batch["unroll_id"][idx]
assert (obs_t == next_obs_t_m_1).all()
assert (state_in_0 == state_out_0_t_m_1).all()
assert (state_in_1 == state_out_1_t_m_1).all()
# Different trajectory.
else:
assert batch["unroll_id"][idx - 1] != batch["unroll_id"][idx]
assert not (obs_t == next_obs_t_m_1).all()
assert not (state_in_0 == state_out_0_t_m_1).all()
assert not (state_in_1 == state_out_1_t_m_1).all()
# Check initial 0-internal states.
if ts == 0:
assert (state_in_0 == 0.0).all()
assert (state_in_1 == 0.0).all()
# Check initial 0-internal states (at ts=0).
if ts == 0:
assert (state_in_0 == 0.0).all()
assert (state_in_1 == 0.0).all()
# Check prev. a/r values.
if idx < count - 1:
prev_actions_t_p_1 = batch["prev_actions"][idx + 1]
prev_rewards_t_p_1 = batch["prev_rewards"][idx + 1]
# Same trajectory as for t+1 -> Should be able to match.
if batch[SampleBatch.AGENT_INDEX][idx] == \
batch[SampleBatch.AGENT_INDEX][idx + 1] and \
batch[SampleBatch.EPS_ID][idx] == \
batch[SampleBatch.EPS_ID][idx + 1]:
assert (a_t == prev_actions_t_p_1).all()
assert r_t == prev_rewards_t_p_1
# Different (new) trajectory: prev-actions/rewards at the start of
# a trajectory (ts == 0) are expected to be all 0.0s.
elif ts == 0:
assert (prev_actions_t_p_1 == 0).all()
assert prev_rewards_t_p_1 == 0.0
pad_batch_to_sequences_of_same_size(
batch,
max_seq_len=max_seq_len,
shuffle=False,
batch_divisibility_req=1)
# Check after seq-len 0-padding.
cursor = 0
for i, seq_len in enumerate(batch["seq_lens"]):
state_in_0 = batch["state_in_0"][i]
state_in_1 = batch["state_in_1"][i]
for j in range(seq_len):
k = cursor + j
ts = batch["t"][k]
obs_t = batch["obs"][k]
a_t = batch["actions"][k]
r_t = batch["rewards"][k]
# Check postprocessing outputs.
if "postprocessed_column" in batch:
postprocessed_col_t = batch["postprocessed_column"][k]
assert (obs_t == postprocessed_col_t / 2.0).all()
# Check state-in/out and next-obs values.
if j > 0:
next_obs_t_m_1 = batch["new_obs"][k - 1]
# state_out_0_t_m_1 = batch["state_out_0"][k - 1]
# state_out_1_t_m_1 = batch["state_out_1"][k - 1]
# Always same trajectory as for t-1.
assert batch["unroll_id"][k - 1] == batch["unroll_id"][k]
assert (obs_t == next_obs_t_m_1).all()
# assert (state_in_0 == state_out_0_t_m_1).all()
# assert (state_in_1 == state_out_1_t_m_1).all()
# Check initial 0-internal states.
elif ts == 0:
assert (state_in_0 == 0.0).all()
assert (state_in_1 == 0.0).all()
for j in range(seq_len, max_seq_len):
k = cursor + j
obs_t = batch["obs"][k]
a_t = batch["actions"][k]
r_t = batch["rewards"][k]
assert (obs_t == 0.0).all()
assert (a_t == 0.0).all()
assert (r_t == 0.0).all()
cursor += max_seq_len
if __name__ == "__main__":

View file

@ -22,7 +22,7 @@ parser.add_argument("--stop-reward", type=float, default=150)
if __name__ == "__main__":
args = parser.parse_args()
ray.init(local_mode=True)
ray.init()
ModelCatalog.register_custom_model(
"bn_model", TorchBatchNormModel if args.torch else BatchNormModel)

View file

@ -50,7 +50,7 @@ if __name__ == "__main__":
"episode_reward_mean": args.stop_reward,
}
results = tune.run(args.run, config=config, stop=stop)
results = tune.run(args.run, config=config, stop=stop, verbose=1)
if args.as_test:
check_learning_achieved(results, args.stop_reward)

View file

@ -108,7 +108,7 @@ if __name__ == "__main__":
"episode_reward_mean": args.stop_reward,
}
results = tune.run("PPO", config=config, stop=stop)
results = tune.run("PPO", config=config, stop=stop, verbose=1)
if args.as_test:
check_learning_achieved(results, args.stop_reward)

View file

@ -9,7 +9,6 @@ For PyTorch / TF eager mode, use the --torch and --eager flags.
import argparse
import ray
from ray import tune
from ray.rllib.models import ModelCatalog
from ray.rllib.examples.env.simple_rpg import SimpleRPG
@ -21,18 +20,13 @@ parser.add_argument(
"--framework", choices=["tf", "tfe", "torch"], default="tf")
if __name__ == "__main__":
ray.init(local_mode=True)
args = parser.parse_args()
if args.framework == "torch":
ModelCatalog.register_custom_model("my_model", CustomTorchRPGModel)
else:
ModelCatalog.register_custom_model("my_model", CustomTFRPGModel)
tune.run(
"PG",
stop={
"timesteps_total": 1,
},
config={
config = {
"framework": args.framework,
"env": SimpleRPG,
"rollout_fragment_length": 1,
@ -41,5 +35,10 @@ if __name__ == "__main__":
"model": {
"custom_model": "my_model",
},
},
)
}
stop = {
"timesteps_total": 1,
}
tune.run("PG", config=config, stop=stop, verbose=1)

View file

@ -55,7 +55,7 @@ if __name__ == "__main__":
"episode_reward_mean": args.stop_reward,
}
results = tune.run(args.run, config=config, stop=stop)
results = tune.run(args.run, config=config, stop=stop, verbose=1)
if args.as_test:
check_learning_achieved(results, args.stop_reward)

View file

@ -29,7 +29,7 @@ class DebugCounterEnv(gym.Env):
class MultiAgentDebugCounterEnv(MultiAgentEnv):
def __init__(self, config):
self.num_agents = config["num_agents"]
self.p_done = config.get("p_done", 0.02)
self.base_episode_len = config.get("base_episode_len", 103)
# Actions are always:
# (episodeID, envID) as floats.
self.action_space = \
@ -45,6 +45,7 @@ class MultiAgentDebugCounterEnv(MultiAgentEnv):
self.dones = set()
def reset(self):
self.timesteps = [0] * self.num_agents
self.dones = set()
return {
i: np.array([i, 0.0, 0.0, 0.0], dtype=np.float32)
@ -57,9 +58,8 @@ class MultiAgentDebugCounterEnv(MultiAgentEnv):
self.timesteps[i] += 1
obs[i] = np.array([i, action[0], action[1], self.timesteps[i]])
rew[i] = self.timesteps[i] % 3
done[i] = bool(
np.random.choice(
[True, False], p=[self.p_done, 1.0 - self.p_done]))
done[i] = self.timesteps[i] > self.base_episode_len + i
if done[i]:
self.dones.add(i)
done["__all__"] = len(self.dones) == self.num_agents

View file

@ -96,7 +96,7 @@ if __name__ == "__main__":
"framework": "torch" if args.torch else "tf",
}
results = tune.run("PPO", stop=stop, config=config)
results = tune.run("PPO", stop=stop, config=config, verbose=1)
if args.as_test:
check_learning_achieved(results, args.stop_reward)

View file

@ -73,7 +73,7 @@ class BinaryAutoregressiveDistribution(ActionDistribution):
return a2_dist
@staticmethod
def required_model_output_shape(self, model_config):
def required_model_output_shape(action_space, model_config):
return 16 # controls model output feature vector size
@ -143,5 +143,5 @@ class TorchBinaryAutoregressiveDistribution(TorchDistributionWrapper):
return a2_dist
@staticmethod
def required_model_output_shape(self, model_config):
def required_model_output_shape(action_space, model_config):
return 16 # controls model output feature vector size

View file

@ -180,7 +180,7 @@ class TorchBatchNormModel(TorchModelV2, nn.Module):
def forward(self, input_dict, state, seq_lens):
# Set the correct train-mode for our hidden module (only important
# b/c we have some batch-norm layers).
self._hidden_layers.train(mode=input_dict["is_training"])
self._hidden_layers.train(mode=input_dict.get("is_training", False))
self._hidden_out = self._hidden_layers(input_dict["obs"])
logits = self._logits(self._hidden_out)
return logits, []

View file

@ -1,9 +1,11 @@
from gym.spaces import Box
import numpy as np
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf, try_import_torch
@ -101,8 +103,18 @@ class TorchRNNModel(TorchRNN, nn.Module):
# Holds the current "base" output (before logits layer).
self._features = None
# Add state-ins to this model's view.
for i in range(2):
self.inference_view_requirements["state_in_{}".format(i)] = \
ViewRequirement(
"state_out_{}".format(i),
shift=-1,
space=Box(-1.0, 1.0, shape=(self.lstm_state_size,)))
@override(ModelV2)
def get_initial_state(self):
# TODO: (sven): Get rid of `get_initial_state` once Trajectory
# View API is supported across all of RLlib.
# Place hidden states on same device as model.
h = [
self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0),

View file

@ -47,10 +47,7 @@ if __name__ == "__main__":
"timesteps_total": args.stop_timesteps,
}
results = tune.run(
"PG",
stop=stop,
config={
config = {
"env": "multi_agent_cartpole",
"multiagent": {
"policies": {
@ -63,8 +60,9 @@ if __name__ == "__main__":
lambda agent_id: ["pg_policy", "random"][agent_id % 2]),
},
"framework": "torch" if args.torch else "tf",
},
)
}
results = tune.run("PG", config=config, stop=stop, verbose=1)
if args.as_test:
check_learning_achieved(results, args.stop_reward)

View file

@ -52,7 +52,7 @@ if __name__ == "__main__":
"timesteps_total": args.stop_timesteps,
}
results = tune.run(args.run, config=config, stop=stop)
results = tune.run(args.run, config=config, stop=stop, verbose=1)
if args.as_test:
check_learning_achieved(results, args.stop_reward)

View file

@ -1,3 +1,4 @@
from gym.spaces import Box
import numpy as np
from ray.rllib.examples.policy.random_policy import RandomPolicy
@ -13,8 +14,7 @@ class EpisodeEnvAwarePolicy(RandomPolicy):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.episode_id = None
self.env_id = None
self.state_space = Box(-1.0, 1.0, (1, ))
class _fake_model:
pass
@ -22,15 +22,25 @@ class EpisodeEnvAwarePolicy(RandomPolicy):
self.model = _fake_model()
self.model.time_major = True
self.model.inference_view_requirements = {
SampleBatch.AGENT_INDEX: ViewRequirement(),
SampleBatch.EPS_ID: ViewRequirement(),
"env_id": ViewRequirement(),
"t": ViewRequirement(),
SampleBatch.OBS: ViewRequirement(),
SampleBatch.PREV_ACTIONS: ViewRequirement(
SampleBatch.ACTIONS, space=self.action_space, shift=-1),
SampleBatch.PREV_REWARDS: ViewRequirement(
SampleBatch.REWARDS, shift=-1),
}
self.training_view_requirements = dict(
for i in range(2):
self.model.inference_view_requirements["state_in_{}".format(i)] = \
ViewRequirement(
"state_out_{}".format(i), shift=-1, space=self.state_space)
self.model.inference_view_requirements[
"state_out_{}".format(i)] = \
ViewRequirement(space=self.state_space)
self.view_requirements = dict(
**{
SampleBatch.NEXT_OBS: ViewRequirement(
SampleBatch.OBS, shift=1),
@ -50,17 +60,23 @@ class EpisodeEnvAwarePolicy(RandomPolicy):
explore=None,
timestep=None,
**kwargs):
self.episode_id = input_dict[SampleBatch.EPS_ID][0]
self.env_id = input_dict["env_id"][0]
# Always return (episodeID, envID)
return [
np.array([self.episode_id, self.env_id]) for _ in input_dict["obs"]
], [], {}
ts = input_dict["t"]
print(ts)
# Always return [agent_index, episodeID, envID] as actions.
actions = np.array([[
input_dict[SampleBatch.AGENT_INDEX][i],
input_dict[SampleBatch.EPS_ID][i], input_dict["env_id"][i]
] for i, _ in enumerate(input_dict["obs"])])
states = [
np.array([[ts[i]] for i in range(len(input_dict["obs"]))])
for _ in range(2)
]
return actions, states, {}
@override(Policy)
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
sample_batch["postprocessed_column"] = sample_batch["obs"] + 1.0
sample_batch["postprocessed_column"] = sample_batch["obs"] * 2.0
return sample_batch

View file

@ -1,7 +1,10 @@
import gym
import numpy as np
import random
from ray.rllib.examples.env.rock_paper_scissors import RockPaperScissors
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.view_requirement import ViewRequirement
class AlwaysSameHeuristic(Policy):
@ -10,6 +13,12 @@ class AlwaysSameHeuristic(Policy):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.exploration = self._create_exploration()
self.view_requirements.update({
"state_in_0": ViewRequirement(
"state_out_0",
shift=-1,
space=gym.spaces.Box(0, 100, shape=(), dtype=np.int32))
})
def get_initial_state(self):
return [
@ -27,6 +36,9 @@ class AlwaysSameHeuristic(Policy):
info_batch=None,
episodes=None,
**kwargs):
if self.config["_use_trajectory_view_api"]:
return state_batches[0][0], [s[0] for s in state_batches], {}
else:
return state_batches[0], state_batches, {}

View file

@ -38,7 +38,7 @@ def run_same_policy(args, stop):
"framework": "torch" if args.torch else "tf",
}
results = tune.run("PG", config=config, stop=stop)
results = tune.run("PG", config=config, stop=stop, verbose=1)
if args.as_test:
# Check vs 0.0 as we are playing a zero-sum game.

View file

@ -63,6 +63,8 @@ class ModelV2:
SampleBatch.OBS: ViewRequirement(shift=0),
}
# TODO: (sven): Get rid of `get_initial_state` once Trajectory
# View API is supported across all of RLlib.
@PublicAPI
def get_initial_state(self) -> List[np.ndarray]:
"""Get the initial recurrent state values for the model.

View file

@ -135,25 +135,20 @@ class LSTMWrapper(RecurrentNetwork, nn.Module):
activation_fn=None,
initializer=torch.nn.init.xavier_uniform_)
self.inference_view_requirements.update(
dict(
**{
SampleBatch.OBS: ViewRequirement(shift=0),
SampleBatch.PREV_REWARDS: ViewRequirement(
SampleBatch.REWARDS, shift=-1),
SampleBatch.PREV_ACTIONS: ViewRequirement(
SampleBatch.ACTIONS, space=self.action_space,
shift=-1),
}))
# Add prev-a/r to this model's view, if required.
if model_config["lstm_use_prev_action_reward"]:
self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
ViewRequirement(SampleBatch.REWARDS, shift=-1)
self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
shift=-1)
# Add state-ins to this model's view.
for i in range(2):
self.inference_view_requirements["state_in_{}".format(i)] = \
ViewRequirement(
"state_out_{}".format(i),
shift=-1,
space=Box(-1.0, 1.0, shape=(self.cell_size,)))
self.inference_view_requirements["state_out_{}".format(i)] = \
ViewRequirement(
space=Box(-1.0, 1.0, shape=(self.cell_size,)))
@override(RecurrentNetwork)
def forward(self, input_dict, state, seq_lens):

View file

@ -5,6 +5,7 @@ import tree
from typing import Dict, List, Optional
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.exploration.exploration import Exploration
from ray.rllib.utils.framework import try_import_torch
@ -74,7 +75,14 @@ class Policy(metaclass=ABCMeta):
# Child classes need to add their specific requirements here (usually
# a combination of a Model's inference_view_- and the
# Policy's loss function-requirements.
self.training_view_requirements = {}
self.view_requirements = {
SampleBatch.OBS: ViewRequirement(),
SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
SampleBatch.REWARDS: ViewRequirement(),
SampleBatch.DONES: ViewRequirement(),
SampleBatch.EPS_ID: ViewRequirement(),
SampleBatch.AGENT_INDEX: ViewRequirement(),
}
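# Illustrative, standalone sketch (hypothetical names, using the SampleBatch
# and ViewRequirement imports already present in this module) of how a child
# policy typically extends these defaults: merge the model's inference view
# requirements and any extra columns its loss or postprocessing needs.
base_view_requirements = {
    SampleBatch.OBS: ViewRequirement(),
    SampleBatch.ACTIONS: ViewRequirement(),
}
model_inference_views = {
    SampleBatch.PREV_REWARDS: ViewRequirement(SampleBatch.REWARDS, shift=-1),
}
loss_specific_views = {
    SampleBatch.NEXT_OBS: ViewRequirement(SampleBatch.OBS, shift=1),
}
combined_view_requirements = {
    **base_view_requirements,
    **model_inference_views,
    **loss_specific_views,
}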
@abstractmethod
@DeveloperAPI
@ -255,7 +263,23 @@ class Policy(metaclass=ABCMeta):
shape like
{"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
"""
raise NotImplementedError
# Default implementation just passes obs, prev-a/r, and states on to
# `self.compute_actions()`.
state_batches = [
s.unsqueeze(0)
if torch and isinstance(s, torch.Tensor) else np.expand_dims(s, 0)
for k, s in input_dict.items() if k[:9] == "state_in_"
]
return self.compute_actions(
input_dict[SampleBatch.OBS],
state_batches,
prev_action_batch=input_dict.get(SampleBatch.PREV_ACTIONS),
prev_reward_batch=input_dict.get(SampleBatch.PREV_REWARDS),
info_batch=None,
explore=explore,
timestep=timestep,
**kwargs,
)
@DeveloperAPI
def compute_log_likelihoods(

View file

@ -35,7 +35,6 @@ def pad_batch_to_sequences_of_same_size(
shuffle: bool = False,
batch_divisibility_req: int = 1,
feature_keys: Optional[List[str]] = None,
_use_trajectory_view_api: bool = False,
):
"""Applies padding to `batch` so it's choppable into same-size sequences.
@ -56,26 +55,7 @@ def pad_batch_to_sequences_of_same_size(
feature_keys (Optional[List[str]]): An optional list of keys to apply
sequence-chopping to. If None, use all keys in batch that are not
"state_in/out_"-type keys.
_use_trajectory_view_api (bool): Whether we are using the Trajectory
View API to collect and process samples.
"""
if _use_trajectory_view_api:
if batch.time_major is not None:
batch["seq_lens"] = torch.tensor(batch.seq_lens)
t = 0 if batch.time_major else 1
for col in batch.data.keys():
# Cut time-dim from states.
if "state_" in col[:6]:
batch[col] = batch[col][t]
# Flatten all other data.
else:
# Cut time-dim at `max_seq_len`.
if batch.time_major:
batch[col] = batch[col][:batch.max_seq_len]
batch[col] = batch[col].reshape((-1, ) +
batch[col].shape[2:])
return
if batch_divisibility_req > 1:
meets_divisibility_reqs = (
len(batch[SampleBatch.CUR_OBS]) % batch_divisibility_req == 0

View file

@ -11,7 +11,6 @@ from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_torch
@ -31,8 +30,6 @@ logger = logging.getLogger(__name__)
class TorchPolicy(Policy):
"""Template for a PyTorch policy and loss to use with RLlib.
This is similar to TFPolicy, but for PyTorch.
Attributes:
observation_space (gym.Space): observation space of the policy.
action_space (gym.Space): action space of the policy.
@ -114,14 +111,7 @@ class TorchPolicy(Policy):
self.device = torch.device("cpu")
self.model = model.to(self.device)
# Combine view_requirements for Model and Policy.
self.training_view_requirements = dict(
**{
SampleBatch.ACTIONS: ViewRequirement(
space=self.action_space, shift=0),
SampleBatch.REWARDS: ViewRequirement(shift=0),
SampleBatch.DONES: ViewRequirement(shift=0),
},
**self.model.inference_view_requirements)
self.view_requirements.update(self.model.inference_view_requirements)
self.exploration = self._create_exploration()
self.unwrapped_model = model # used to support DistributedDataParallel
@ -202,9 +192,11 @@ class TorchPolicy(Policy):
with torch.no_grad():
# Pass lazy (torch) tensor dict to Model as `input_dict`.
input_dict = self._lazy_tensor_dict(input_dict)
# Pack internal state inputs into (separate) list.
state_batches = [
input_dict[k] for k in input_dict.keys() if "state_" in k[:6]
input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
]
# Calculate RNN sequence lengths.
seq_lens = np.array([1] * len(input_dict["obs"])) \
if state_batches else None
@ -217,7 +209,8 @@ class TorchPolicy(Policy):
extra_fetches[SampleBatch.ACTION_PROB] = torch.exp(logp)
extra_fetches[SampleBatch.ACTION_LOGP] = logp
return actions, state_out, extra_fetches
return convert_to_non_torch_type((actions, state_out,
extra_fetches))
def _compute_action_helper(self, input_dict, state_batches, seq_lens,
explore, timestep):
@ -342,7 +335,6 @@ class TorchPolicy(Policy):
max_seq_len=self.max_seq_len,
shuffle=False,
batch_divisibility_req=self.batch_divisibility_req,
_use_trajectory_view_api=self.config["_use_trajectory_view_api"],
)
train_batch = self._lazy_tensor_dict(postprocessed_batch)
@ -479,7 +471,7 @@ class TorchPolicy(Policy):
@DeveloperAPI
def get_initial_state(self) -> List[TensorType]:
return [
s.cpu().detach().numpy() for s in self.model.get_initial_state()
s.detach().cpu().numpy() for s in self.model.get_initial_state()
]
@override(Policy)

View file

@ -64,7 +64,7 @@ def build_torch_policy(
apply_gradients_fn: Optional[Callable[
[Policy, "torch.optim.Optimizer"], None]] = None,
mixins: Optional[List[type]] = None,
training_view_requirements_fn: Optional[Callable[[], Dict[
view_requirements_fn: Optional[Callable[[], Dict[
str, ViewRequirement]]] = None,
get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None):
"""Helper function for creating a torch policy class at runtime.
@ -159,7 +159,7 @@ def build_torch_policy(
mixins (Optional[List[type]]): Optional list of any class mixins for
the returned policy class. These mixins will be applied in order
and will have higher precedence than the TorchPolicy class.
training_view_requirements_fn (Callable[[],
view_requirements_fn (Callable[[],
Dict[str, ViewRequirement]]): An optional callable to retrieve
additional train view requirements for this policy.
get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
@ -226,9 +226,8 @@ def build_torch_policy(
get_batch_divisibility_req=get_batch_divisibility_req,
)
if callable(training_view_requirements_fn):
self.training_view_requirements.update(
training_view_requirements_fn(self))
if callable(view_requirements_fn):
self.view_requirements.update(view_requirements_fn(self))
if after_init:
after_init(self, obs_space, action_space, config)

View file

@ -29,7 +29,8 @@ class ViewRequirement:
def __init__(self,
data_col: Optional[str] = None,
space: gym.Space = None,
shift: Union[int, List[int]] = 0):
shift: Union[int, List[int]] = 0,
used_for_training: bool = True):
"""Initializes a ViewRequirement object.
Args:
@ -46,8 +47,12 @@ class ViewRequirement:
Example: For a view column "obs" in an Atari framestacking
fashion, you can set `data_col="obs"` and
`shift=[-3, -2, -1, 0]`.
used_for_training (bool): Whether the data will be used for
training. If False, the column will not be copied into the
final train batch.
"""
self.data_col = data_col
self.space = space or gym.spaces.Box(
float("-inf"), float("inf"), shape=())
self.shift = shift
self.used_for_training = used_for_training
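# Illustrative, standalone usage sketch (hypothetical view columns; plain
# string keys are used here instead of the SampleBatch constants):
example_view_requirements = {
    # Atari-style framestack: view the last 4 observations at each step.
    "framestacked_obs": ViewRequirement(data_col="obs",
                                        shift=[-3, -2, -1, 0]),
    # Previous reward as a shifted view of the "rewards" column.
    "prev_rewards": ViewRequirement("rewards", shift=-1),
    # RNN state-in as a shifted (t-1) view of the state-out column.
    "state_in_0": ViewRequirement(
        "state_out_0", shift=-1,
        space=gym.spaces.Box(-1.0, 1.0, shape=(256, ))),
    # Column only needed transiently; not copied into the train batch.
    "debug_col": ViewRequirement(used_for_training=False),
}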

View file

@ -1,4 +1,3 @@
import numpy as np
import gym
from ray.rllib.utils.annotations import PublicAPI
@ -17,14 +16,9 @@ class Repeated(gym.Space):
"""
def __init__(self, child_space: gym.Space, max_len: int):
self.np_random = np.random.RandomState()
super().__init__()
self.child_space = child_space
self.max_len = max_len
super().__init__()
def seed(self, seed=None):
self.np_random = np.random.RandomState()
self.np_random.seed(seed)
def sample(self):
return [

View file

@ -34,7 +34,7 @@ def convert_to_non_torch_type(stats):
def mapping(item):
if isinstance(item, torch.Tensor):
return item.cpu().item() if len(item.size()) == 0 else \
item.cpu().detach().numpy()
item.detach().cpu().numpy()
else:
return item