[RLlib] Trajectory view API - 03 Fast LSTM + prev actions/rewards (#9950)

Sven Mika 2020-08-21 12:35:16 +02:00 committed by GitHub
parent 92664249e8
commit e968b52cb7
25 changed files with 1230 additions and 413 deletions

View file

@ -22,13 +22,14 @@
# (problems: 10min timeout, not respecting ray/ci/keep_alive.sh, or even
# `travis_wait n`, etc..).
# Our travis.yml file executes all these tests in 6 different jobs, which are:
# Our travis.yml file executes all these tests in 7 different jobs, which are:
# 1) everything in a) using tf2.x
# 2) everything in a) using tf1.x
# 3) everything in b) c) d) and e)
# 4) everything in g)
# 5) f), BUT only those tagged `tests_dir_A` to `tests_dir_L`
# 6) f), BUT only those tagged `tests_dir_M` to `tests_dir_Z`
# 3) everything in a) using torch
# 4) everything in b) c) d) and e)
# 5) everything in g)
# 6) f), BUT only those tagged `tests_dir_A` to `tests_dir_[some letter]`
# 7) f), BUT only those tagged `tests_dir_[some letter]` to `tests_dir_Z`
# --------------------------------------------------------------------
@ -1024,6 +1025,22 @@ py_test(
srcs = ["models/tests/test_attention_nets.py"]
)
# --------------------------------------------------------------------
# Evaluation components
# rllib/evaluation/
#
# Tag: evaluation
# --------------------------------------------------------------------
# mysteriously times out on travis.
#py_test(
# name = "evaluation/tests/test_trajectory_view_api",
# tags = ["evaluation"],
# size = "medium",
# srcs = ["evaluation/tests/test_trajectory_view_api.py"]
#)
# --------------------------------------------------------------------
# Optimizers and Memories
# rllib/execution/
@ -1059,13 +1076,6 @@ py_test(
srcs = ["policy/tests/test_compute_log_likelihoods.py"]
)
py_test(
name = "policy/tests/test_trajectory_view_api",
tags = ["policy"],
size = "small",
srcs = ["policy/tests/test_trajectory_view_api.py"]
)
# --------------------------------------------------------------------
# Utils:
# rllib/utils/

View file

@ -1,13 +1,16 @@
from typing import Dict
from typing import Dict, TYPE_CHECKING
from ray.rllib.env import BaseEnv
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
from ray.rllib.evaluation import MultiAgentEpisode
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.typing import AgentID, PolicyID
if TYPE_CHECKING:
from ray.rllib.evaluation import RolloutWorker
@PublicAPI
class DefaultCallbacks:
@ -27,7 +30,7 @@ class DefaultCallbacks:
"a class extending rllib.agents.callbacks.DefaultCallbacks")
self.legacy_callbacks = legacy_callbacks_dict or {}
def on_episode_start(self, worker: RolloutWorker, base_env: BaseEnv,
def on_episode_start(self, worker: "RolloutWorker", base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
episode: MultiAgentEpisode, **kwargs):
"""Callback run on the rollout worker before each episode starts.
@ -52,7 +55,7 @@ class DefaultCallbacks:
"episode": episode,
})
def on_episode_step(self, worker: RolloutWorker, base_env: BaseEnv,
def on_episode_step(self, worker: "RolloutWorker", base_env: BaseEnv,
episode: MultiAgentEpisode, **kwargs):
"""Runs on each episode step.
@ -73,7 +76,7 @@ class DefaultCallbacks:
"episode": episode
})
def on_episode_end(self, worker: RolloutWorker, base_env: BaseEnv,
def on_episode_end(self, worker: "RolloutWorker", base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
episode: MultiAgentEpisode, **kwargs):
"""Runs when an episode is done.
@ -99,7 +102,7 @@ class DefaultCallbacks:
})
def on_postprocess_trajectory(
self, worker: RolloutWorker, episode: MultiAgentEpisode,
self, worker: "RolloutWorker", episode: MultiAgentEpisode,
agent_id: AgentID, policy_id: PolicyID,
policies: Dict[PolicyID, Policy], postprocessed_batch: SampleBatch,
original_batches: Dict[AgentID, SampleBatch], **kwargs):
@ -133,7 +136,7 @@ class DefaultCallbacks:
"all_pre_batches": original_batches,
})
def on_sample_end(self, worker: RolloutWorker, samples: SampleBatch,
def on_sample_end(self, worker: "RolloutWorker", samples: SampleBatch,
**kwargs):
"""Called at the end RolloutWorker.sample().

View file

@ -117,7 +117,10 @@ def ppo_surrogate_loss(policy, model, dist_class, train_batch):
mask = None
if state:
max_seq_len = torch.max(train_batch["seq_lens"])
mask = sequence_mask(train_batch["seq_lens"], max_seq_len)
mask = sequence_mask(
train_batch["seq_lens"],
max_seq_len,
time_major=model.is_time_major())
mask = torch.reshape(mask, [-1])
policy.loss_obj = PPOLoss(
@ -221,6 +224,12 @@ def training_view_requirements_fn(policy):
SampleBatch.NEXT_OBS: ViewRequirement(SampleBatch.OBS, shift=1),
# VF preds are needed for the loss.
SampleBatch.VF_PREDS: ViewRequirement(shift=0),
# Needed for postprocessing.
SampleBatch.ACTION_DIST_INPUTS: ViewRequirement(shift=0),
SampleBatch.ACTION_LOGP: ViewRequirement(shift=0),
# Created during postprocessing.
Postprocessing.ADVANTAGES: ViewRequirement(shift=0),
Postprocessing.VALUE_TARGETS: ViewRequirement(shift=0),
}
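
The requirements above express each training column as a (possibly shifted) view of an underlying data column: NEXT_OBS is OBS with shift=1, and in the LSTM case further below PREV_ACTIONS/PREV_REWARDS are ACTIONS/REWARDS with shift=-1. A toy numpy sketch of what such shifted views mean over one collected column (the shifted_view helper is illustrative only; RLlib's padding/bootstrapping at trajectory edges differs):

import numpy as np

def shifted_view(column: np.ndarray, shift: int) -> np.ndarray:
    """Toy stand-in for ViewRequirement(data_col=..., shift=shift):
    return `column` shifted by `shift` timesteps, edge-padded."""
    if shift == 0:
        return column.copy()
    if shift > 0:
        return np.concatenate([column[shift:], np.repeat(column[-1:], shift)])
    return np.concatenate([np.repeat(column[:1], -shift), column[:shift]])

obs = np.arange(5.0)                    # [0, 1, 2, 3, 4]
next_obs = shifted_view(obs, shift=1)   # [1, 2, 3, 4, 4]  (like NEXT_OBS)
prev_obs = shifted_view(obs, shift=-1)  # [0, 0, 1, 2, 3]  (like PREV_*)
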

View file

@ -1082,6 +1082,11 @@ class Trainer(Trainable):
raise ValueError(
"`_use_trajectory_view_api` only supported for PyTorch so "
"far!")
elif not config.get("_use_trajectory_view_api") and \
config.get("model", {}).get("_time_major"):
raise ValueError("`model._time_major` only supported "
"iff `_use_trajectory_view_api` is True!")
if "policy_graphs" in config["multiagent"]:
deprecation_warning("policy_graphs", "policies")
# Backwards compatibility.
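
The new check only rejects the invalid combination; the valid one mirrors what the test added at the bottom of this commit sets up. A sketch of the relevant config keys (not a complete PPO config):

config = {
    # Experimental sample-collection path (torch only at this point).
    "_use_trajectory_view_api": True,
    "model": {
        "use_lstm": True,
        "lstm_use_prev_action_reward": True,
        # Time-major batches are only allowed together with the flag above;
        # setting this without `_use_trajectory_view_api` raises a ValueError.
        "_time_major": True,
        "max_seq_len": 100,
    },
}
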

View file

@ -10,7 +10,6 @@ import time
from typing import Union, Optional
import ray.cloudpickle as pickle
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.env import ExternalEnv, MultiAgentEnv, ExternalMultiAgentEnv
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.annotations import PublicAPI
@ -337,6 +336,7 @@ def _create_embedded_rollout_worker(kwargs, send_fn):
real_env_creator = kwargs["env_creator"]
kwargs["env_creator"] = _auto_wrap_external(real_env_creator)
from ray.rllib.evaluation.rollout_worker import RolloutWorker
rollout_worker = RolloutWorker(**kwargs)
inference_thread = _LocalInferenceThread(rollout_worker, send_fn)
inference_thread.start()

View file

@ -1,7 +1,6 @@
import logging
from typing import Dict, Optional
from typing import Dict, Optional, TYPE_CHECKING
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.per_policy_sample_collector import \
@ -16,6 +15,9 @@ from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, \
TensorType
from ray.util.debug import log_once
if TYPE_CHECKING:
from ray.rllib.agents.callbacks import DefaultCallbacks
logger = logging.getLogger(__name__)
@ -38,7 +40,7 @@ class _MultiAgentSampleCollector(_SampleCollector):
def __init__(
self,
policy_map: Dict[PolicyID, Policy],
callbacks: DefaultCallbacks,
callbacks: "DefaultCallbacks",
# TODO: (sven) make `num_agents` flexibly grow in size.
num_agents: int = 100,
num_timesteps=None,
@ -64,8 +66,8 @@ class _MultiAgentSampleCollector(_SampleCollector):
num_agents = 1000
self.num_agents = int(num_agents)
# Collect SampleBatches per-policy in PolicyTrajectories objects.
self.rollout_sample_collectors = {}
# Collect SampleBatches per-policy in _PerPolicySampleCollectors.
self.policy_sample_collectors = {}
for pid, policy in policy_map.items():
# Figure out max-shifts (before and after).
view_reqs = policy.training_view_requirements
@ -86,7 +88,7 @@ class _MultiAgentSampleCollector(_SampleCollector):
elif num_timesteps is not None:
kwargs["num_timesteps"] = num_timesteps
self.rollout_sample_collectors[pid] = _PerPolicySampleCollector(
self.policy_sample_collectors[pid] = _PerPolicySampleCollector(
num_agents=self.num_agents,
shift_before=-max_shift_before,
shift_after=max_shift_after,
@ -109,7 +111,7 @@ class _MultiAgentSampleCollector(_SampleCollector):
assert self.agent_to_policy[agent_id] == policy_id
# Add initial obs to Trajectory.
self.rollout_sample_collectors[policy_id].add_init_obs(
self.policy_sample_collectors[policy_id].add_init_obs(
episode_id, agent_id, env_id, chunk_num=0, init_obs=obs)
@override(_SampleCollector)
@ -117,7 +119,7 @@ class _MultiAgentSampleCollector(_SampleCollector):
agent_id: AgentID, env_id: EnvID,
policy_id: PolicyID, agent_done: bool,
values: Dict[str, TensorType]) -> None:
assert policy_id in self.rollout_sample_collectors
assert policy_id in self.policy_sample_collectors
# Make sure our mappings are up to date.
if agent_id not in self.agent_to_policy:
@ -130,13 +132,13 @@ class _MultiAgentSampleCollector(_SampleCollector):
values["agent_id"] = agent_id
# Add action/reward/next-obs (and other data) to Trajectory.
self.rollout_sample_collectors[policy_id].add_action_reward_next_obs(
self.policy_sample_collectors[policy_id].add_action_reward_next_obs(
episode_id, agent_id, env_id, agent_done, values)
@override(_SampleCollector)
def total_env_steps(self) -> int:
return sum(a.timesteps_since_last_reset
for a in self.rollout_sample_collectors.values())
for a in self.policy_sample_collectors.values())
def total(self):
# TODO: (sven) deprecate; use `self.total_env_steps`, instead.
@ -148,7 +150,7 @@ class _MultiAgentSampleCollector(_SampleCollector):
Dict[str, TensorType]:
policy = self.policy_map[policy_id]
view_reqs = policy.model.inference_view_requirements
return self.rollout_sample_collectors[
return self.policy_sample_collectors[
policy_id].get_inference_input_dict(view_reqs)
@override(_SampleCollector)
@ -161,7 +163,7 @@ class _MultiAgentSampleCollector(_SampleCollector):
# Loop through each per-policy collector and create a view (for each
# agent as SampleBatch) from its buffers for post-processing
all_agent_batches = {}
for pid, rc in self.rollout_sample_collectors.items():
for pid, rc in self.policy_sample_collectors.items():
policy = self.policy_map[pid]
view_reqs = policy.training_view_requirements
agent_batches = rc.get_postprocessing_sample_batches(
@ -211,7 +213,7 @@ class _MultiAgentSampleCollector(_SampleCollector):
@override(_SampleCollector)
def check_missing_dones(self, episode_id: EpisodeID) -> None:
for pid, rc in self.rollout_sample_collectors.items():
for pid, rc in self.policy_sample_collectors.items():
for agent_key in rc.agent_key_to_slot.keys():
# Only check for given episode and only for last chunk
# (all previous chunks for that agent in the episode are
@ -235,7 +237,7 @@ class _MultiAgentSampleCollector(_SampleCollector):
def get_multi_agent_batch_and_reset(self):
self.postprocess_trajectories_so_far()
policy_batches = {}
for pid, rc in self.rollout_sample_collectors.items():
for pid, rc in self.policy_sample_collectors.items():
policy = self.policy_map[pid]
view_reqs = policy.training_view_requirements
policy_batches[pid] = rc.get_train_sample_batch_and_reset(

View file

@ -103,7 +103,13 @@ class _PerPolicySampleCollector:
self._next_agent_slot()
if SampleBatch.OBS not in self.buffers:
self._build_buffers(single_row={SampleBatch.OBS: init_obs})
self._build_buffers(
single_row={
SampleBatch.OBS: init_obs,
SampleBatch.EPS_ID: episode_id,
SampleBatch.AGENT_INDEX: agent_id,
"env_id": env_id,
})
if self.time_major:
self.buffers[SampleBatch.OBS][self.shift_before-1, agent_slot] = \
init_obs
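
For orientation, the buffers written above are pre-allocated per column and indexed [time, agent_slot] in the time-major case, with `shift_before` rows reserved in front of t=0 for data such as the initial observation or prev-actions. A small numpy sketch of that layout (toy sizes; not the RLlib class itself):

import numpy as np

num_timesteps, shift_before, shift_after, num_agents = 4, 1, 0, 3
time_size = num_timesteps + shift_before + shift_after

# One pre-allocated, time-major buffer per column: [time, agent_slot].
obs_buffer = np.zeros((time_size, num_agents), dtype=np.float32)

agent_slot = 0
init_obs = 7.0
# The initial observation lands just before t=0 (row shift_before - 1).
obs_buffer[shift_before - 1, agent_slot] = init_obs

# Subsequent observations for that agent fill rows shift_before + t.
for t, obs in enumerate([1.0, 2.0, 3.0]):
    obs_buffer[shift_before + t, agent_slot] = obs
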
@ -262,12 +268,12 @@ class _PerPolicySampleCollector:
batch = sample_batch_data[agent_key]
for view_col, view_req in view_reqs.items():
data_col = view_req.data_col or view_col
# Skip columns that will only get added through postprocessing
# (these may not even exist yet).
if view_req.created_during_postprocessing:
if data_col not in self.buffers:
continue
data_col = view_req.data_col or view_col
shift = view_req.shift
if data_col == SampleBatch.OBS:
shift -= 1
@ -289,20 +295,22 @@ class _PerPolicySampleCollector:
SampleBatch: Returns the accumulated sample batch for this
policy.
"""
seq_lens = [
seq_lens_w_0s = [
self.agent_key_to_timestep[k] - self.shift_before
for k in self.slot_to_agent_key if k is not None
]
first_zero_len = len(seq_lens)
if seq_lens[-1] == 0:
first_zero_len = seq_lens.index(0)
# We have an agent-axis buffer "rollover" (new SampleBatch will be
# built from last n agent records plus first m agent records in
# buffer).
if self.agent_slot_cursor < self.sample_batch_offset:
rollover = -(self.num_agents - self.sample_batch_offset)
seq_lens_w_0s = seq_lens_w_0s[rollover:] + seq_lens_w_0s[:rollover]
first_zero_len = len(seq_lens_w_0s)
if seq_lens_w_0s[-1] == 0:
first_zero_len = seq_lens_w_0s.index(0)
# Assert that all zeros lie at the end of the seq_lens array.
try:
assert all(seq_lens[i] == 0
for i in range(first_zero_len, len(seq_lens)))
except AssertionError as e:
print()
raise e
assert all(seq_lens_w_0s[i] == 0
for i in range(first_zero_len, len(seq_lens_w_0s)))
t_start = self.shift_before
t_end = t_start + self.num_timesteps
@ -311,8 +319,8 @@ class _PerPolicySampleCollector:
# actually already has at least 1 timestep of data (thus it excludes
# just-rolled over chunks (which only have the initial obs in them)).
valid_agent_cursor = \
(self.agent_slot_cursor - (len(seq_lens) - first_zero_len)) % \
self.num_agents
(self.agent_slot_cursor -
(len(seq_lens_w_0s) - first_zero_len)) % self.num_agents
# Construct the view dict.
view = {}
@ -320,12 +328,13 @@ class _PerPolicySampleCollector:
data_col = view_req.data_col or view_col
assert data_col in self.buffers
# For OBS, indices must be shifted by -1.
extra_shift = 0 if data_col != SampleBatch.OBS else -1
shift = view_req.shift
shift += 0 if data_col != SampleBatch.OBS else -1
# If agent_slot has been rolled-over to beginning, we have to copy
# here.
if valid_agent_cursor < self.sample_batch_offset:
time_slice = self.buffers[data_col][t_start + extra_shift:
t_end + extra_shift]
time_slice = self.buffers[data_col][t_start + shift:t_end +
shift]
one_ = time_slice[:, self.sample_batch_offset:]
two_ = time_slice[:, :valid_agent_cursor]
if torch and isinstance(time_slice, torch.Tensor):
@ -335,17 +344,15 @@ class _PerPolicySampleCollector:
else:
view[view_col] = \
self.buffers[data_col][
t_start + extra_shift:t_end + extra_shift,
t_start + shift:t_end + shift,
self.sample_batch_offset:valid_agent_cursor]
# Copy all still ongoing trajectories to new agent slots
# (including the ones that just started (are seq_len=0)).
new_chunk_args = []
for i, seq_len in enumerate(seq_lens):
for i, seq_len in enumerate(seq_lens_w_0s):
if seq_len < self.num_timesteps:
agent_slot = self.sample_batch_offset + i
if agent_slot >= self.num_agents:
agent_slot = agent_slot % self.num_agents
agent_slot = (self.sample_batch_offset + i) % self.num_agents
if not self.buffers[SampleBatch.
DONES][seq_len - 1 +
self.shift_before][agent_slot]:
@ -354,9 +361,9 @@ class _PerPolicySampleCollector:
(agent_slot, agent_key,
self.agent_key_to_timestep[agent_key]))
# Cut out all 0 seq-lens.
seq_lens = seq_lens[:first_zero_len]
seq_lens = seq_lens_w_0s[:first_zero_len]
batch = SampleBatch(
view, _seq_lens=np.array(seq_lens), _time_major=True)
view, _seq_lens=np.array(seq_lens), _time_major=self.time_major)
# Reset everything for new data.
self.postprocessed_agents = [False] * self.num_agents
@ -376,9 +383,14 @@ class _PerPolicySampleCollector:
def _build_buffers(self, single_row: Dict[str, TensorType]) -> None:
"""Builds the internal data buffers based on a single given row.
This may be called several times in the lifetime of this instance
to add new columns to the buffer. Columns in `single_row` that already
exist in the buffer will be ignored.
Args:
single_row (Dict[str, TensorType]): A single datarow with one or
more columns (str as key, np.ndarray|tensor as data).
more columns (str as key, np.ndarray|tensor as data) to be used
as a template to build the pre-allocated buffer.
"""
time_size = self.num_timesteps + self.shift_before + self.shift_after
for col, data in single_row.items():

View file

@ -515,7 +515,7 @@ class RolloutWorker(ParallelIteratorWorker):
rollout_fragment_length=rollout_fragment_length,
callbacks=self.callbacks,
horizon=episode_horizon,
pack_multiple_episodes_in_batch=pack,
multiple_episodes_in_batch=pack,
tf_sess=self.tf_sess,
clip_actions=clip_actions,
blackhole_outputs="simulation" in input_evaluation,
@ -538,7 +538,7 @@ class RolloutWorker(ParallelIteratorWorker):
rollout_fragment_length=rollout_fragment_length,
callbacks=self.callbacks,
horizon=episode_horizon,
pack_multiple_episodes_in_batch=pack,
multiple_episodes_in_batch=pack,
tf_sess=self.tf_sess,
clip_actions=clip_actions,
soft_horizon=soft_horizon,

View file

@ -5,14 +5,17 @@ import numpy as np
import queue
import threading
import time
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, \
from typing import Any, Callable, Dict, List, Iterable, Optional, Set, Tuple,\
TYPE_CHECKING, Union
from ray.util.debug import log_once
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.multi_agent_sample_collector import \
_MultiAgentSampleCollector
from ray.rllib.evaluation.rollout_metrics import RolloutMetrics
from ray.rllib.evaluation.sample_batch_builder import \
MultiAgentSampleBatchBuilder
from ray.rllib.evaluation.sample_collector import _SampleCollector
from ray.rllib.policy.policy import clip_action, Policy
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.models.preprocessors import Preprocessor
@ -22,6 +25,7 @@ from ray.rllib.env.atari_wrappers import get_wrapper_by_cls, MonitorEnv
from ray.rllib.offline import InputReader
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray, \
unbatch
from ray.rllib.utils.tf_run_builder import TFRunBuilder
@ -51,14 +55,23 @@ class _PerfStats:
def __init__(self):
self.iters = 0
self.env_wait_time = 0.0
self.processing_time = 0.0
self.raw_obs_processing_time = 0.0
self.inference_time = 0.0
self.action_processing_time = 0.0
def get(self):
# Mean multiplier (1000/iters converts accumulated seconds to mean ms).
factor = 1000 / self.iters
return {
"mean_env_wait_ms": self.env_wait_time * 1000 / self.iters,
"mean_processing_ms": self.processing_time * 1000 / self.iters,
"mean_inference_ms": self.inference_time * 1000 / self.iters
# Waiting for environment (during poll).
"mean_env_wait_ms": self.env_wait_time * factor,
# Raw observation preprocessing.
"mean_raw_obs_processing_ms": self.raw_obs_processing_time *
factor,
# Computing actions through policy.
"mean_inference_ms": self.inference_time * factor,
# Processing actions (to be sent to env, e.g. clipping).
"mean_action_processing_ms": self.action_processing_time * factor,
}
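
The refactored stats split the old `processing_time` into raw-obs processing and action processing and report per-iteration means in milliseconds (seconds are accumulated, then scaled by 1000/iters). A quick numeric sketch of the conversion (made-up totals):

iters = 200
env_wait_time = 1.5            # total seconds spent polling envs
raw_obs_processing_time = 0.8
inference_time = 2.4
action_processing_time = 0.3

factor = 1000 / iters          # accumulated seconds -> mean ms per iter
stats = {
    "mean_env_wait_ms": env_wait_time * factor,                       # 7.5
    "mean_raw_obs_processing_ms": raw_obs_processing_time * factor,   # 4.0
    "mean_inference_ms": inference_time * factor,                     # 12.0
    "mean_action_processing_ms": action_processing_time * factor,     # 1.5
}
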
@ -108,7 +121,7 @@ class SyncSampler(SamplerInput):
rollout_fragment_length: int,
callbacks: "DefaultCallbacks",
horizon: int = None,
pack_multiple_episodes_in_batch: bool = False,
multiple_episodes_in_batch: bool = False,
tf_sess=None,
clip_actions: bool = True,
soft_horizon: bool = False,
@ -136,7 +149,7 @@ class SyncSampler(SamplerInput):
callbacks (Callbacks): The Callbacks object to use when episode
events happen during rollout.
horizon (Optional[int]): Hard-reset the Env after this many timesteps.
pack_multiple_episodes_in_batch (bool): Whether to pack multiple
multiple_episodes_in_batch (bool): Whether to pack multiple
episodes into each batch. This guarantees batches will be
exactly `rollout_fragment_length` in size.
tf_sess (Optional[tf.Session]): A tf.Session object to use (only if
@ -165,14 +178,20 @@ class SyncSampler(SamplerInput):
self.obs_filters = obs_filters
self.extra_batches = queue.Queue()
self.perf_stats = _PerfStats()
if _use_trajectory_view_api:
self.sample_collector = _MultiAgentSampleCollector(
policies, callbacks)
else:
self.sample_collector = None
# Create the rollout generator to use for calls to `get_data()`.
self.rollout_provider = _env_runner(
worker, self.base_env, self.extra_batches.put, self.policies,
self.policy_mapping_fn, self.rollout_fragment_length, self.horizon,
self.preprocessors, self.obs_filters, clip_rewards, clip_actions,
pack_multiple_episodes_in_batch, callbacks, tf_sess,
self.perf_stats, soft_horizon, no_done_at_end, observation_fn,
_use_trajectory_view_api)
multiple_episodes_in_batch, callbacks, tf_sess, self.perf_stats,
soft_horizon, no_done_at_end, observation_fn,
_use_trajectory_view_api, self.sample_collector)
self.metrics_queue = queue.Queue()
@override(SamplerInput)
@ -226,7 +245,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
rollout_fragment_length: int,
callbacks: "DefaultCallbacks",
horizon: int = None,
pack_multiple_episodes_in_batch: bool = False,
multiple_episodes_in_batch: bool = False,
tf_sess=None,
clip_actions: bool = True,
blackhole_outputs: bool = False,
@ -255,7 +274,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
callbacks (Callbacks): The Callbacks object to use when episode
events happen during rollout.
horizon (Optional[int]): Hard-reset the Env after this many timesteps.
pack_multiple_episodes_in_batch (bool): Whether to pack multiple
multiple_episodes_in_batch (bool): Whether to pack multiple
episodes into each batch. This guarantees batches will be
exactly `rollout_fragment_length` in size.
tf_sess (Optional[tf.Session]): A tf.Session object to use (only if
@ -293,7 +312,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
self.obs_filters = obs_filters
self.clip_rewards = clip_rewards
self.daemon = True
self.pack_multiple_episodes_in_batch = pack_multiple_episodes_in_batch
self.multiple_episodes_in_batch = multiple_episodes_in_batch
self.tf_sess = tf_sess
self.callbacks = callbacks
self.clip_actions = clip_actions
@ -304,6 +323,11 @@ class AsyncSampler(threading.Thread, SamplerInput):
self.shutdown = False
self.observation_fn = observation_fn
self._use_trajectory_view_api = _use_trajectory_view_api
if _use_trajectory_view_api:
self.sample_collector = _MultiAgentSampleCollector(
policies, callbacks)
else:
self.sample_collector = None
@override(threading.Thread)
def run(self):
@ -325,8 +349,8 @@ class AsyncSampler(threading.Thread, SamplerInput):
self.worker, self.base_env, extra_batches_putter, self.policies,
self.policy_mapping_fn, self.rollout_fragment_length, self.horizon,
self.preprocessors, self.obs_filters, self.clip_rewards,
self.clip_actions, self.pack_multiple_episodes_in_batch,
self.callbacks, self.tf_sess, self.perf_stats, self.soft_horizon,
self.clip_actions, self.multiple_episodes_in_batch, self.callbacks,
self.tf_sess, self.perf_stats, self.soft_horizon,
self.no_done_at_end, self.observation_fn,
self._use_trajectory_view_api)
while not self.shutdown:
@ -385,14 +409,16 @@ def _env_runner(
obs_filters: Dict[PolicyID, Filter],
clip_rewards: bool,
clip_actions: bool,
pack_multiple_episodes_in_batch: bool,
multiple_episodes_in_batch: bool,
callbacks: "DefaultCallbacks",
tf_sess: Optional["tf.Session"],
perf_stats: _PerfStats,
soft_horizon: bool,
no_done_at_end: bool,
observation_fn: "ObservationFunction",
_use_trajectory_view_api: bool = False) -> Iterable[SampleBatchType]:
_use_trajectory_view_api: bool = False,
_sample_collector: Optional[_SampleCollector] = None,
) -> Iterable[SampleBatchType]:
"""This implements the common experience collection logic.
Args:
@ -413,7 +439,7 @@ def _env_runner(
obs_filters (dict): Map of policy id to filter used to process
observations for the policy.
clip_rewards (bool): Whether to clip rewards before postprocessing.
pack_multiple_episodes_in_batch (bool): Whether to pack multiple
multiple_episodes_in_batch (bool): Whether to pack multiple
episodes into each batch. This guarantees batches will be exactly
`rollout_fragment_length` in size.
clip_actions (bool): Whether to clip actions to the space range.
@ -430,6 +456,8 @@ def _env_runner(
_use_trajectory_view_api (bool): Whether to use the (experimental)
`_use_trajectory_view_api` to make generic trajectory views
available to Models. Default: False.
_sample_collector (Optional[_SampleCollector]): An optional
_SampleCollector object to use (only relevant if
`_use_trajectory_view_api` is True).
Yields:
rollout (SampleBatch): Object containing state, action, reward,
@ -471,6 +499,8 @@ def _env_runner(
def get_batch_builder():
if batch_builder_pool:
return batch_builder_pool.pop()
elif _use_trajectory_view_api:
return None
else:
return MultiAgentSampleBatchBuilder(policies, clip_rewards,
callbacks)
@ -495,6 +525,7 @@ def _env_runner(
return episode
active_episodes: Dict[str, MultiAgentEpisode] = defaultdict(new_episode)
eval_results = None
while True:
perf_stats.iters += 1
@ -514,39 +545,73 @@ def _env_runner(
t1 = time.time()
# type: Set[EnvID], Dict[PolicyID, List[PolicyEvalData]],
# List[Union[RolloutMetrics, SampleBatchType]]
active_envs, to_eval, outputs = _process_observations(
worker=worker,
base_env=base_env,
policies=policies,
batch_builder_pool=batch_builder_pool,
active_episodes=active_episodes,
unfiltered_obs=unfiltered_obs,
rewards=rewards,
dones=dones,
infos=infos,
horizon=horizon,
preprocessors=preprocessors,
obs_filters=obs_filters,
rollout_fragment_length=rollout_fragment_length,
pack_multiple_episodes_in_batch=pack_multiple_episodes_in_batch,
callbacks=callbacks,
soft_horizon=soft_horizon,
no_done_at_end=no_done_at_end,
observation_fn=observation_fn,
_use_trajectory_view_api=_use_trajectory_view_api)
perf_stats.processing_time += time.time() - t1
if _use_trajectory_view_api:
active_envs, to_eval, outputs = \
_process_observations_w_trajectory_view_api(
worker=worker,
base_env=base_env,
policies=policies,
active_episodes=active_episodes,
prev_policy_outputs=eval_results,
unfiltered_obs=unfiltered_obs,
rewards=rewards,
dones=dones,
infos=infos,
horizon=horizon,
preprocessors=preprocessors,
obs_filters=obs_filters,
rollout_fragment_length=rollout_fragment_length,
multiple_episodes_in_batch=multiple_episodes_in_batch,
callbacks=callbacks,
soft_horizon=soft_horizon,
no_done_at_end=no_done_at_end,
observation_fn=observation_fn,
perf_stats=perf_stats,
_sample_collector=_sample_collector,
)
else:
active_envs, to_eval, outputs = _process_observations(
worker=worker,
base_env=base_env,
policies=policies,
batch_builder_pool=batch_builder_pool,
active_episodes=active_episodes,
unfiltered_obs=unfiltered_obs,
rewards=rewards,
dones=dones,
infos=infos,
horizon=horizon,
preprocessors=preprocessors,
obs_filters=obs_filters,
rollout_fragment_length=rollout_fragment_length,
multiple_episodes_in_batch=multiple_episodes_in_batch,
callbacks=callbacks,
soft_horizon=soft_horizon,
no_done_at_end=no_done_at_end,
observation_fn=observation_fn,
perf_stats=perf_stats,
)
perf_stats.raw_obs_processing_time += time.time() - t1
for o in outputs:
yield o
# Do batched policy eval (across vectorized envs).
t2 = time.time()
# type: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]
eval_results = _do_policy_eval(
to_eval=to_eval,
policies=policies,
active_episodes=active_episodes,
tf_sess=tf_sess,
_use_trajectory_view_api=_use_trajectory_view_api)
if _use_trajectory_view_api:
eval_results = _do_policy_eval_w_trajectory_view_api(
to_eval=to_eval,
policies=policies,
_sample_collector=_sample_collector,
tf_sess=tf_sess,
)
else:
eval_results = _do_policy_eval(
to_eval=to_eval,
policies=policies,
active_episodes=active_episodes,
tf_sess=tf_sess,
)
perf_stats.inference_time += time.time() - t2
# Process results and update episode state.
@ -560,8 +625,10 @@ def _env_runner(
off_policy_actions=off_policy_actions,
policies=policies,
clip_actions=clip_actions,
_use_trajectory_view_api=_use_trajectory_view_api)
perf_stats.processing_time += time.time() - t3
_use_trajectory_view_api=_use_trajectory_view_api,
_sample_collector=_sample_collector,
)
perf_stats.action_processing_time += time.time() - t3
# Return computed actions to ready envs. We also send to envs that have
# taken off-policy actions; those envs are free to ignore the action.
@ -571,6 +638,7 @@ def _env_runner(
def _process_observations(
*,
worker: "RolloutWorker",
base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
@ -584,12 +652,12 @@ def _process_observations(
preprocessors: Dict[PolicyID, Preprocessor],
obs_filters: Dict[PolicyID, Filter],
rollout_fragment_length: int,
pack_multiple_episodes_in_batch: bool,
multiple_episodes_in_batch: bool,
callbacks: "DefaultCallbacks",
soft_horizon: bool,
no_done_at_end: bool,
observation_fn: "ObservationFunction",
_use_trajectory_view_api: bool = False
perf_stats: _PerfStats,
) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[
RolloutMetrics, SampleBatchType]]]:
"""Record new data from the environment and prepare for policy evaluation.
@ -602,8 +670,11 @@ def _process_observations(
SampleBatchBuilder object for recycling.
active_episodes (Dict[str, MultiAgentEpisode]): Mapping from
episode ID to currently ongoing MultiAgentEpisode object.
unfiltered_obs (dict): Doubly keyed dict of env-ids -> agent ids ->
unfiltered observation tensor, returned by a `BaseEnv.poll()` call.
prev_policy_outputs (Dict[str,List]): The prev policy output dict
(by policy-id -> List[action, state outs, extra fetches]).
unfiltered_obs (dict): Doubly keyed dict of env-ids -> agent ids
-> unfiltered observation tensor, returned by a `BaseEnv.poll()`
call.
rewards (dict): Doubly keyed dict of env-ids -> agent ids ->
rewards tensor, returned by a `BaseEnv.poll()` call.
dones (dict): Doubly keyed dict of env-ids -> agent ids ->
@ -618,7 +689,7 @@ def _process_observations(
rollout_fragment_length (int): Number of episode steps before
`SampleBatch` is yielded. Set to infinity to yield complete
episodes.
pack_multiple_episodes_in_batch (bool): Whether to pack multiple
multiple_episodes_in_batch (bool): Whether to pack multiple
episodes into each batch. This guarantees batches will be exactly
`rollout_fragment_length` in size.
callbacks (DefaultCallbacks): User callbacks to run on episode events.
@ -628,9 +699,6 @@ def _process_observations(
and instead record done=False.
observation_fn (ObservationFunction): Optional multi-agent
observation func to use for preprocessing observations.
_use_trajectory_view_api (bool): Whether to use the (experimental)
`_use_trajectory_view_api` to make generic trajectory views
available to Models. Default: False.
Returns:
Tuple:
@ -652,20 +720,21 @@ def _process_observations(
for env_id, agent_obs in unfiltered_obs.items():
is_new_episode: bool = env_id not in active_episodes
episode: MultiAgentEpisode = active_episodes[env_id]
batch_builder = episode.batch_builder
if not is_new_episode:
episode.length += 1
episode.batch_builder.count += 1
batch_builder.count += 1
episode._add_agent_rewards(rewards[env_id])
if (episode.batch_builder.total() > large_batch_threshold
if (batch_builder.total() > large_batch_threshold
and log_once("large_batch_warning")):
logger.warning(
"More than {} observations for {} env steps ".format(
episode.batch_builder.total(),
episode.batch_builder.count) + "are buffered in "
"the sampler. If this is more than you expected, check "
"that you set a horizon on your environment correctly and "
"that it terminates at some point. "
batch_builder.total(), batch_builder.count) +
"are buffered in "
"the sampler. If this is more than you expected, check that "
"that you set a horizon on your environment correctly and that"
" it terminates at some point. "
"Note: In multi-agent environments, `rollout_fragment_length` "
"sets the batch size based on environment steps, not the "
"steps of "
@ -725,12 +794,12 @@ def _process_observations(
agent_done = bool(all_agents_done or dones[env_id].get(agent_id))
if not agent_done:
to_eval[policy_id].append(
PolicyEvalData(env_id, agent_id, filtered_obs,
infos[env_id].get(agent_id, {}),
episode.rnn_state_for(agent_id),
episode.last_action_for(agent_id),
rewards[env_id][agent_id] or 0.0))
item = PolicyEvalData(env_id, agent_id, filtered_obs,
infos[env_id].get(agent_id, {}),
episode.rnn_state_for(agent_id),
episode.last_action_for(agent_id),
rewards[env_id][agent_id] or 0.0)
to_eval[policy_id].append(item)
last_observation: EnvObsType = episode.last_observation_for(
agent_id)
@ -741,7 +810,7 @@ def _process_observations(
# Record transition info if applicable.
if (last_observation is not None and infos[env_id].get(
agent_id, {}).get("training_enabled", True)):
episode.batch_builder.add_values(
batch_builder.add_values(
agent_id,
policy_id,
t=episode.length - 1,
@ -767,26 +836,26 @@ def _process_observations(
# - all-agents-done and not packing multiple episodes into one
# (batch_mode="complete_episodes")
# - or if we've exceeded the rollout_fragment_length.
if episode.batch_builder.has_pending_agent_data():
if batch_builder.has_pending_agent_data():
# Sanity check, whether all agents have done=True, if done[__all__]
# is True.
if dones[env_id]["__all__"] and not no_done_at_end:
episode.batch_builder.check_missing_dones()
batch_builder.check_missing_dones()
# Reached end of episode and we are not allowed to pack the
# next episode into the same SampleBatch -> Build the SampleBatch
# and add it to "outputs".
if (all_agents_done and not pack_multiple_episodes_in_batch) or \
episode.batch_builder.count >= rollout_fragment_length:
outputs.append(episode.batch_builder.build_and_reset(episode))
# Make sure postprocessor stays within one episode.
elif all_agents_done:
episode.batch_builder.postprocess_batch_so_far(episode)
# Reached end of episode and we are not allowed to pack the
# next episode into the same SampleBatch -> Build the SampleBatch
# and add it to "outputs".
if (all_agents_done and not multiple_episodes_in_batch) or \
batch_builder.count >= rollout_fragment_length:
outputs.append(batch_builder.build_and_reset(episode))
# Make sure postprocessor stays within one episode.
elif all_agents_done:
batch_builder.postprocess_batch_so_far(episode)
# Episode is done.
if all_agents_done:
# Handle episode termination.
batch_builder_pool.append(episode.batch_builder)
# We can pass the BatchBuilder to recycling.
batch_builder_pool.append(batch_builder)
# Call each policy's Exploration.on_episode_end method.
for p in policies.values():
if getattr(p, "exploration", None) is not None:
@ -834,14 +903,262 @@ def _process_observations(
filtered_obs: EnvObsType = _get_or_raise(
obs_filters, policy_id)(prep_obs)
episode._set_last_observation(agent_id, filtered_obs)
to_eval[policy_id].append(
PolicyEvalData(
env_id, agent_id, filtered_obs,
episode.last_info_for(agent_id) or {},
episode.rnn_state_for(agent_id),
np.zeros_like(
flatten_to_single_ndarray(
policy.action_space.sample())), 0.0))
item = PolicyEvalData(
env_id, agent_id, filtered_obs,
episode.last_info_for(agent_id) or {},
episode.rnn_state_for(agent_id),
np.zeros_like(
flatten_to_single_ndarray(
policy.action_space.sample())), 0.0)
to_eval[policy_id].append(item)
return active_envs, to_eval, outputs
def _process_observations_w_trajectory_view_api(
*,
worker: "RolloutWorker",
base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
active_episodes: Dict[str, MultiAgentEpisode],
prev_policy_outputs: Dict[PolicyID, Tuple[TensorStructType, StateBatch,
dict]],
unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
rewards: Dict[EnvID, Dict[AgentID, float]],
dones: Dict[EnvID, Dict[AgentID, bool]],
infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]],
horizon: int,
preprocessors: Dict[PolicyID, Preprocessor],
obs_filters: Dict[PolicyID, Filter],
rollout_fragment_length: int,
multiple_episodes_in_batch: bool,
callbacks: "DefaultCallbacks",
soft_horizon: bool,
no_done_at_end: bool,
observation_fn: "ObservationFunction",
perf_stats: _PerfStats,
_sample_collector: _SampleCollector,
) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[
RolloutMetrics, SampleBatchType]]]:
"""Trajectory View API version of `_process_observations()`.
TODO: (sven) Move docstring here once original function is deprecated.
"""
# Output objects.
active_envs: Set[EnvID] = set()
to_eval: Set[PolicyID] = set()
outputs: List[Union[RolloutMetrics, SampleBatchType]] = []
large_batch_threshold: int = max(1000, rollout_fragment_length * 10) if \
rollout_fragment_length != float("inf") else 5000
# For each environment.
# type: EnvID, Dict[AgentID, EnvObsType]
for env_id, agent_obs in unfiltered_obs.items():
is_new_episode: bool = env_id not in active_episodes
episode: MultiAgentEpisode = active_episodes[env_id]
if not is_new_episode:
episode.length += 1
_sample_collector.count += 1
episode._add_agent_rewards(rewards[env_id])
if (_sample_collector.total_env_steps() > large_batch_threshold
and log_once("large_batch_warning")):
logger.warning(
"More than {} observations for {} env steps ".format(
_sample_collector.total_env_steps(),
_sample_collector.count) + "are buffered in "
"the sampler. If this is more than you expected, check that "
"that you set a horizon on your environment correctly and that"
" it terminates at some point. "
"Note: In multi-agent environments, `rollout_fragment_length` "
"sets the batch size based on environment steps, not the "
"steps of "
"individual agents, which can result in unexpectedly large "
"batches. Also, you may be in evaluation waiting for your Env "
"to terminate (batch_mode=`complete_episodes`). Make sure it "
"does at some point.")
# Check episode termination conditions.
if dones[env_id]["__all__"] or episode.length >= horizon:
hit_horizon = (episode.length >= horizon
and not dones[env_id]["__all__"])
all_agents_done = True
atari_metrics: List[RolloutMetrics] = _fetch_atari_metrics(
base_env)
if atari_metrics is not None:
for m in atari_metrics:
outputs.append(
m._replace(custom_metrics=episode.custom_metrics))
else:
outputs.append(
RolloutMetrics(episode.length, episode.total_reward,
dict(episode.agent_rewards),
episode.custom_metrics, {},
episode.hist_data))
else:
hit_horizon = False
all_agents_done = False
active_envs.add(env_id)
# Custom observation function is applied before preprocessing.
if observation_fn:
agent_obs: Dict[AgentID, EnvObsType] = observation_fn(
agent_obs=agent_obs,
worker=worker,
base_env=base_env,
policies=policies,
episode=episode)
if not isinstance(agent_obs, dict):
raise ValueError(
"observe() must return a dict of agent observations")
# For each agent in the environment.
# type: AgentID, EnvObsType
for agent_id, raw_obs in agent_obs.items():
assert agent_id != "__all__"
policy_id: PolicyID = episode.policy_for(agent_id)
prep_obs: EnvObsType = _get_or_raise(preprocessors,
policy_id).transform(raw_obs)
if log_once("prep_obs"):
logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))
filtered_obs: EnvObsType = _get_or_raise(obs_filters,
policy_id)(prep_obs)
if log_once("filtered_obs"):
logger.info("Filtered obs: {}".format(summarize(filtered_obs)))
agent_done = bool(all_agents_done or dones[env_id].get(agent_id))
last_observation: EnvObsType = episode.last_observation_for(
agent_id)
episode._set_last_observation(agent_id, filtered_obs)
episode._set_last_raw_obs(agent_id, raw_obs)
episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))
# Record transition info if applicable.
if last_observation is None:
_sample_collector.add_init_obs(episode.episode_id, agent_id,
env_id, policy_id, filtered_obs)
else:
rc = _sample_collector.policy_sample_collectors[policy_id]
eval_idx = rc.agent_key_to_forward_pass_index[(
agent_id, episode.episode_id)]
values_dict = {
"t": episode.length - 1,
"eps_id": episode.episode_id,
"agent_index": episode._agent_index(agent_id),
# Action (slot 0) taken at timestep t.
"actions": prev_policy_outputs[policy_id][0][eval_idx],
# Reward received after taking a at timestep t.
"rewards": rewards[env_id][agent_id],
# After taking a, did we reach terminal?
"dones": (False if (no_done_at_end
or (hit_horizon and soft_horizon)) else
agent_done),
# Next observation.
"new_obs": filtered_obs,
}
# TODO: (sven) add env infos to buffers as well.
for k, v in prev_policy_outputs[policy_id][2].items():
values_dict[k] = v[eval_idx]
for i, v in enumerate(prev_policy_outputs[policy_id][1]):
values_dict["state_out_{}".format(i)] = v[eval_idx]
_sample_collector.add_action_reward_next_obs(
episode.episode_id, agent_id, env_id, policy_id,
agent_done, values_dict)
if not agent_done:
to_eval.add(policy_id)
# Invoke the step callback after the step is logged to the episode
callbacks.on_episode_step(
worker=worker, base_env=base_env, episode=episode)
# Cut the batch if ...
# - all-agents-done and not packing multiple episodes into one
# (batch_mode="complete_episodes")
# - or if we've exceeded the rollout_fragment_length.
if _sample_collector.has_non_postprocessed_data():
# Sanity check, whether all agents have done=True, if done[__all__]
# is True.
if dones[env_id]["__all__"] and not no_done_at_end:
_sample_collector.check_missing_dones(
episode_id=episode.episode_id)
# Reached end of episode and we are not allowed to pack the
# next episode into the same SampleBatch -> Build the SampleBatch
# and add it to "outputs".
if (all_agents_done and not multiple_episodes_in_batch) or \
_sample_collector.count >= rollout_fragment_length:
# TODO: (sven) Case: rollout_fragment_length reached: Do not
# store any data in `episode` anymore
# (useless for get_view_requirements when t<<-1, e.g.
# attention), but keep last episode data around in
# SampleBatchBuilder
# to be able to still reference into it
# should a model require this.
outputs.append(_sample_collector.get_multi_agent_batch_and_reset())
# Make sure postprocessor stays within one episode.
elif all_agents_done:
_sample_collector.postprocess_trajectories_so_far(episode)
# Episode is done.
if all_agents_done:
# Call each policy's Exploration.on_episode_end method.
for p in policies.values():
if getattr(p, "exploration", None) is not None:
p.exploration.on_episode_end(
policy=p,
environment=base_env,
episode=episode,
tf_sess=getattr(p, "_sess", None))
# Call custom on_episode_end callback.
callbacks.on_episode_end(
worker=worker,
base_env=base_env,
policies=policies,
episode=episode)
if hit_horizon and soft_horizon:
episode.soft_reset()
resetted_obs: Dict[AgentID, EnvObsType] = agent_obs
else:
del active_episodes[env_id]
resetted_obs: Dict[AgentID, EnvObsType] = base_env.try_reset(
env_id)
if resetted_obs is None:
# Reset not supported, drop this env from the ready list.
if horizon != float("inf"):
raise ValueError(
"Setting episode horizon requires reset() support "
"from the environment.")
elif resetted_obs != ASYNC_RESET_RETURN:
# Creates a new episode if this is not async return.
# If reset is async, we will get its result in some future poll
episode: MultiAgentEpisode = active_episodes[env_id]
if observation_fn:
resetted_obs: Dict[AgentID, EnvObsType] = observation_fn(
agent_obs=resetted_obs,
worker=worker,
base_env=base_env,
policies=policies,
episode=episode)
# type: AgentID, EnvObsType
for agent_id, raw_obs in resetted_obs.items():
policy_id: PolicyID = episode.policy_for(agent_id)
prep_obs: EnvObsType = _get_or_raise(
preprocessors, policy_id).transform(raw_obs)
filtered_obs: EnvObsType = _get_or_raise(
obs_filters, policy_id)(prep_obs)
episode._set_last_observation(agent_id, filtered_obs)
# Add initial obs to buffer.
_sample_collector.add_init_obs(episode.episode_id,
agent_id, env_id, policy_id,
filtered_obs)
to_eval.add(policy_id)
return active_envs, to_eval, outputs
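
Taken together, the function above drives the per-step protocol of the sample collector: initial observations go in via `add_init_obs`, every later step goes in via `add_action_reward_next_obs` together with the previous forward-pass outputs, and finished fragments come back out as a multi-agent batch. A compressed sketch of that protocol against a toy stand-in collector (ToyCollector is not the real _MultiAgentSampleCollector, it only mimics the call signatures used here):

class ToyCollector:
    """Tiny stand-in for _MultiAgentSampleCollector, showing only the
    call protocol used by the sampler (not the real implementation)."""
    def __init__(self):
        self.rows = []
        self.count = 0
    def add_init_obs(self, episode_id, agent_id, env_id, policy_id, obs):
        self.rows.append({"eps_id": episode_id, "agent": agent_id,
                          "env": env_id, "policy": policy_id, "obs": obs})
    def add_action_reward_next_obs(self, episode_id, agent_id, env_id,
                                   policy_id, agent_done, values):
        self.rows.append(dict(values, eps_id=episode_id, agent=agent_id))
        self.count += 1
    def get_multi_agent_batch_and_reset(self):
        batch, self.rows, self.count = list(self.rows), [], 0
        return batch

collector = ToyCollector()
collector.add_init_obs(episode_id=0, agent_id="a0", env_id=0,
                       policy_id="pol0", obs=[0.0])
for t in range(3):
    collector.add_action_reward_next_obs(
        0, "a0", 0, "pol0", agent_done=(t == 2),
        values={"t": t, "actions": 1, "rewards": 1.0,
                "dones": t == 2, "new_obs": [float(t + 1)]})
batch = collector.get_multi_agent_batch_and_reset()
print(len(batch))  # 4 rows: 1 init obs + 3 env steps
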
@ -852,7 +1169,6 @@ def _do_policy_eval(
policies: Dict[PolicyID, Policy],
active_episodes: Dict[str, MultiAgentEpisode],
tf_sess=None,
_use_trajectory_view_api=False
) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
"""Call compute_actions on collected episode/model data to get next action.
@ -866,9 +1182,6 @@ def _do_policy_eval(
episode ID to currently ongoing MultiAgentEpisode object.
tf_sess (Optional[tf.Session]): Optional tensorflow session to use for
batching TF policy evaluations.
_use_trajectory_view_api (bool): Whether to use the (experimental)
`_use_trajectory_view_api` procedure to collect samples.
Default: False.
Returns:
eval_results: dict of policy to compute_action() outputs.
@ -888,15 +1201,15 @@ def _do_policy_eval(
# type: PolicyID, PolicyEvalData
for policy_id, eval_data in to_eval.items():
rnn_in: List[List[Any]] = [t.rnn_state for t in eval_data]
policy: Policy = _get_or_raise(policies, policy_id)
# If tf (non eager) AND TFPolicy's compute_action method has not been
# overridden -> Use `policy._build_compute_actions()`.
# If tf (non eager) AND TFPolicy's compute_action method has not
# been overridden -> Use `policy._build_compute_actions()`.
if builder and (policy.compute_actions.__code__ is
TFPolicy.compute_actions.__code__):
obs_batch: List[EnvObsType] = [t.obs for t in eval_data]
state_batches: StateBatch = _to_column_format(rnn_in)
state_batches: StateBatch = _to_column_format(
[t.rnn_state for t in eval_data])
# TODO(ekl): how can we make info batch available to TF code?
prev_action_batch = [t.prev_action for t in eval_data]
prev_reward_batch = [t.prev_reward for t in eval_data]
@ -909,6 +1222,7 @@ def _do_policy_eval(
prev_reward_batch=prev_reward_batch,
timestep=policy.global_timestep)
else:
rnn_in = [t.rnn_state for t in eval_data]
rnn_in_cols: StateBatch = [
np.stack([row[i] for row in rnn_in])
for i in range(len(rnn_in[0]))
@ -921,6 +1235,61 @@ def _do_policy_eval(
info_batch=[t.info for t in eval_data],
episodes=[active_episodes[t.env_id] for t in eval_data],
timestep=policy.global_timestep)
if builder:
# type: PolicyID, Tuple[TensorStructType, StateBatch, dict]
for pid, v in pending_fetches.items():
eval_results[pid] = builder.get(v)
if log_once("compute_actions_result"):
logger.info("Outputs of compute_actions():\n\n{}\n".format(
summarize(eval_results)))
return eval_results
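
In both eval paths the per-policy result is the usual compute_actions triple, which the trajectory-view observation code above indexes positionally (slot 0 = actions, slot 1 = RNN state-outs, slot 2 = extra action fetches). A minimal sketch of consuming one forward-pass row of such a result (toy values, hypothetical shapes):

import numpy as np

# Shape of one entry of `eval_results[policy_id]`:
# (action_batch, state_out_batches, extra_fetches)
eval_result = (
    np.array([1, 0, 1]),                      # actions for 3 rows
    [np.zeros((3, 4)), np.zeros((3, 4))],     # two LSTM state-out tensors
    {"vf_preds": np.array([0.1, 0.2, 0.3])},  # extra action fetches
)

eval_idx = 2  # row of one particular agent in the forward-pass batch
values = {"actions": eval_result[0][eval_idx]}
for key, batch in eval_result[2].items():
    values[key] = batch[eval_idx]
for i, state_batch in enumerate(eval_result[1]):
    values["state_out_{}".format(i)] = state_batch[eval_idx]
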
def _do_policy_eval_w_trajectory_view_api(
*,
to_eval: Dict[PolicyID, List[PolicyEvalData]],
policies: Dict[PolicyID, Policy],
_sample_collector,
tf_sess=None,
) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
"""Call compute_actions on collected episode/model data to get next action.
Args:
to_eval (Dict[PolicyID, List[PolicyEvalData]]): Mapping of policy
IDs to lists of PolicyEvalData objects (items in these lists will
be the batch's items for the model forward pass).
policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy
obj.
_sample_collector (SampleCollector): The SampleCollector object to use.
tf_sess (Optional[tf.Session]): Optional tensorflow session to use for
batching TF policy evaluations.
Returns:
eval_results: dict of policy to compute_action() outputs.
"""
eval_results: Dict[PolicyID, TensorStructType] = {}
if tf_sess:
builder = TFRunBuilder(tf_sess, "policy_eval")
pending_fetches: Dict[PolicyID, Any] = {}
else:
builder = None
if log_once("compute_actions_input"):
logger.info("Inputs to compute_actions():\n\n{}\n".format(
summarize(to_eval)))
for policy_id in to_eval:
policy: Policy = _get_or_raise(policies, policy_id)
input_dict = _sample_collector.get_inference_input_dict(policy_id)
eval_results[policy_id] = \
policy.compute_actions_from_input_dict(
input_dict, timestep=policy.global_timestep)
if builder:
# type: PolicyID, Tuple[TensorStructType, StateBatch, dict]
for pid, v in pending_fetches.items():
@ -943,7 +1312,8 @@ def _process_policy_eval_results(
off_policy_actions: MultiEnvDict,
policies: Dict[PolicyID, Policy],
clip_actions: bool,
_use_trajectory_view_api: bool = False
_use_trajectory_view_api: bool = False,
_sample_collector=None,
) -> Dict[EnvID, Dict[AgentID, EnvActionType]]:
"""Process the output of policy neural network evaluation.
@ -980,11 +1350,10 @@ def _process_policy_eval_results(
actions_to_send[env_id] = {} # at minimum send empty dict
# type: PolicyID, List[PolicyEvalData]
for policy_id, eval_data in to_eval.items():
rnn_in_cols: StateBatch = _to_column_format(
[t.rnn_state for t in eval_data])
for policy_id in to_eval:
actions: TensorStructType = eval_results[policy_id][0]
actions = convert_to_numpy(actions)
rnn_out_cols: StateBatch = eval_results[policy_id][1]
pi_info_cols: dict = eval_results[policy_id][2]
@ -993,40 +1362,58 @@ def _process_policy_eval_results(
if isinstance(actions, list):
actions = np.array(actions)
if len(rnn_in_cols) != len(rnn_out_cols):
raise ValueError("Length of RNN in did not match RNN out, got: "
"{} vs {}".format(rnn_in_cols, rnn_out_cols))
# Add RNN state info
for f_i, column in enumerate(rnn_in_cols):
pi_info_cols["state_in_{}".format(f_i)] = column
for f_i, column in enumerate(rnn_out_cols):
pi_info_cols["state_out_{}".format(f_i)] = column
# Add RNN state info.
eval_data = None
if not _use_trajectory_view_api:
eval_data = to_eval[policy_id]
rnn_in_cols: StateBatch = _to_column_format(
[t.rnn_state for t in eval_data])
if len(rnn_in_cols) != len(rnn_out_cols):
raise ValueError(
"Length of RNN in did not match RNN out, got: "
"{} vs {}".format(rnn_in_cols, rnn_out_cols))
for f_i, column in enumerate(rnn_in_cols):
pi_info_cols["state_in_{}".format(f_i)] = column
for f_i, column in enumerate(rnn_out_cols):
pi_info_cols["state_out_{}".format(f_i)] = column
policy: Policy = _get_or_raise(policies, policy_id)
# Split action-component batches into single action rows.
actions: List[EnvActionType] = unbatch(actions)
# type: int, EnvActionType
for i, action in enumerate(actions):
env_id: int = eval_data[i].env_id
agent_id: AgentID = eval_data[i].agent_id
# Clip if necessary.
if clip_actions:
clipped_action = clip_action(action,
policy.action_space_struct)
else:
clipped_action = action
actions_to_send[env_id][agent_id] = clipped_action
episode: MultiAgentEpisode = active_episodes[env_id]
episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
episode._set_last_pi_info(
agent_id, {k: v[i]
for k, v in pi_info_cols.items()})
if env_id in off_policy_actions and \
agent_id in off_policy_actions[env_id]:
episode._set_last_action(agent_id,
off_policy_actions[env_id][agent_id])
# Trajectory View API: Do not store data directly in episode
# (entire episode is stored in Trajectory and kept until
# end of episode).
if _use_trajectory_view_api:
agent_id, episode_id, env_id = \
_sample_collector.policy_sample_collectors[
policy_id].forward_pass_index_to_agent_info[i]
else:
episode._set_last_action(agent_id, action)
env_id: int = eval_data[i].env_id
agent_id: AgentID = eval_data[i].agent_id
episode: MultiAgentEpisode = active_episodes[env_id]
episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
episode._set_last_pi_info(
agent_id, {k: v[i]
for k, v in pi_info_cols.items()})
if env_id in off_policy_actions and \
agent_id in off_policy_actions[env_id]:
episode._set_last_action(
agent_id, off_policy_actions[env_id][agent_id])
else:
episode._set_last_action(agent_id, action)
assert agent_id not in actions_to_send[env_id]
actions_to_send[env_id][agent_id] = clipped_action
return actions_to_send
@ -1054,20 +1441,21 @@ def _to_column_format(rnn_state_rows: List[List[Any]]) -> StateBatch:
return [[row[i] for row in rnn_state_rows] for i in range(num_cols)]
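
_to_column_format simply transposes per-row RNN states into per-state-tensor columns; a worked toy example of the same list comprehension:

# Two eval rows, each carrying two RNN state tensors (h, c):
rnn_state_rows = [
    ["h_row0", "c_row0"],
    ["h_row1", "c_row1"],
]
num_cols = len(rnn_state_rows[0])
rnn_in_cols = [[row[i] for row in rnn_state_rows] for i in range(num_cols)]
# -> [["h_row0", "h_row1"], ["c_row0", "c_row1"]]
print(rnn_in_cols)
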
def _get_or_raise(mapping: Dict[PolicyID, Policy],
policy_id: PolicyID) -> Policy:
"""Returns a Policy object under key `policy_id` in `mapping`.
def _get_or_raise(mapping: Dict[PolicyID, Union[Policy, Preprocessor, Filter]],
policy_id: PolicyID) -> Union[Policy, Preprocessor, Filter]:
"""Returns an object under key `policy_id` in `mapping`.
Args:
mapping (dict): The mapping dict from policy id (str) to
actual Policy object.
mapping (Dict[PolicyID, Union[Policy, Preprocessor, Filter]]): The
mapping dict from policy id (str) to actual object (Policy,
Preprocessor, etc.).
policy_id (str): The policy ID to lookup.
Returns:
Policy: The found Policy object.
Union[Policy, Preprocessor, Filter]: The found object.
Throws:
ValueError: If `policy_id` cannot be found.
ValueError: If `policy_id` cannot be found in `mapping`.
"""
if policy_id not in mapping:
raise ValueError(

View file

@ -0,0 +1,277 @@
import copy
from gym.spaces import Box, Discrete
import time
import unittest
import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.examples.policy.episode_env_aware_policy import \
EpisodeEnvAwarePolicy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.test_utils import framework_iterator
class TestTrajectoryViewAPI(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init()
@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()
def test_traj_view_normal_case(self):
"""Tests, whether Model and Policy return the correct ViewRequirements.
"""
config = ppo.DEFAULT_CONFIG.copy()
for _ in framework_iterator(config, frameworks="torch"):
trainer = ppo.PPOTrainer(config, env="CartPole-v0")
policy = trainer.get_policy()
view_req_model = policy.model.inference_view_requirements
view_req_policy = policy.training_view_requirements
assert len(view_req_model) == 1
assert len(view_req_policy) == 10
for key in [
SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS,
SampleBatch.DONES, SampleBatch.NEXT_OBS,
SampleBatch.VF_PREDS, "advantages", "value_targets",
SampleBatch.ACTION_DIST_INPUTS, SampleBatch.ACTION_LOGP
]:
assert key in view_req_policy
# None of the view cols has a special underlying data_col,
# except next-obs.
if key != SampleBatch.NEXT_OBS:
assert view_req_policy[key].data_col is None
else:
assert view_req_policy[key].data_col == SampleBatch.OBS
assert view_req_policy[key].shift == 1
trainer.stop()
def test_traj_view_lstm_prev_actions_and_rewards(self):
"""Tests, whether Policy/Model return correct LSTM ViewRequirements.
"""
config = ppo.DEFAULT_CONFIG.copy()
config["model"] = config["model"].copy()
# Activate LSTM + prev-action + rewards.
config["model"]["use_lstm"] = True
config["model"]["lstm_use_prev_action_reward"] = True
for _ in framework_iterator(config, frameworks="torch"):
trainer = ppo.PPOTrainer(config, env="CartPole-v0")
policy = trainer.get_policy()
view_req_model = policy.model.inference_view_requirements
view_req_policy = policy.training_view_requirements
assert len(view_req_model) == 7 # obs, prev_a, prev_r, 4xstates
assert len(view_req_policy) == 16
for key in [
SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS,
SampleBatch.DONES, SampleBatch.NEXT_OBS,
SampleBatch.VF_PREDS, SampleBatch.PREV_ACTIONS,
SampleBatch.PREV_REWARDS, "advantages", "value_targets",
SampleBatch.ACTION_DIST_INPUTS, SampleBatch.ACTION_LOGP
]:
assert key in view_req_policy
if key == SampleBatch.PREV_ACTIONS:
assert view_req_policy[key].data_col == SampleBatch.ACTIONS
assert view_req_policy[key].shift == -1
elif key == SampleBatch.PREV_REWARDS:
assert view_req_policy[key].data_col == SampleBatch.REWARDS
assert view_req_policy[key].shift == -1
elif key not in [
SampleBatch.NEXT_OBS, SampleBatch.PREV_ACTIONS,
SampleBatch.PREV_REWARDS
]:
assert view_req_policy[key].data_col is None
else:
assert view_req_policy[key].data_col == SampleBatch.OBS
assert view_req_policy[key].shift == 1
trainer.stop()
def test_traj_view_lstm_performance(self):
"""Test whether PPOTrainer runs faster w/ `_use_trajectory_view_api`.
"""
config = copy.deepcopy(ppo.DEFAULT_CONFIG)
action_space = Discrete(2)
obs_space = Box(-1.0, 1.0, shape=(700, ))
from ray.rllib.examples.env.random_env import RandomMultiAgentEnv
from ray.tune import register_env
register_env("ma_env", lambda c: RandomMultiAgentEnv({
"num_agents": 2,
"p_done": 0.01,
"action_space": action_space,
"observation_space": obs_space
}))
config["num_workers"] = 3
config["num_envs_per_worker"] = 8
config["num_sgd_iter"] = 6
config["model"]["use_lstm"] = True
config["model"]["lstm_use_prev_action_reward"] = True
config["model"]["max_seq_len"] = 100
policies = {
"pol0": (None, obs_space, action_space, {}),
}
def policy_fn(agent_id):
return "pol0"
config["multiagent"] = {
"policies": policies,
"policy_mapping_fn": policy_fn,
}
num_iterations = 1
# Only works in torch so far.
for _ in framework_iterator(config, frameworks="torch"):
print("w/ traj. view API (and time-major)")
config["_use_trajectory_view_api"] = True
config["model"]["_time_major"] = True
trainer = ppo.PPOTrainer(config=config, env="ma_env")
learn_time_w = 0.0
sampler_perf = {}
start = time.time()
for i in range(num_iterations):
out = trainer.train()
sampler_perf_ = out["sampler_perf"]
sampler_perf = {
k: sampler_perf.get(k, 0.0) + sampler_perf_[k]
for k, v in sampler_perf_.items()
}
delta = out["timers"]["learn_time_ms"] / 1000
learn_time_w += delta
print("{}={}s".format(i, delta))
sampler_perf = {
k: sampler_perf[k] / (num_iterations if "mean_" in k else 1)
for k, v in sampler_perf.items()
}
duration_w = time.time() - start
print("Duration: {}s "
"sampler-perf.={} learn-time/iter={}s".format(
duration_w, sampler_perf, learn_time_w / num_iterations))
trainer.stop()
print("w/o traj. view API (and w/o time-major)")
config["_use_trajectory_view_api"] = False
config["model"]["_time_major"] = False
trainer = ppo.PPOTrainer(config=config, env="ma_env")
learn_time_wo = 0.0
sampler_perf = {}
start = time.time()
for i in range(num_iterations):
out = trainer.train()
sampler_perf_ = out["sampler_perf"]
sampler_perf = {
k: sampler_perf.get(k, 0.0) + sampler_perf_[k]
for k, v in sampler_perf_.items()
}
delta = out["timers"]["learn_time_ms"] / 1000
learn_time_wo += delta
print("{}={}s".format(i, delta))
sampler_perf = {
k: sampler_perf[k] / (num_iterations if "mean_" in k else 1)
for k, v in sampler_perf.items()
}
duration_wo = time.time() - start
print("Duration: {}s "
"sampler-perf.={} learn-time/iter={}s".format(
duration_wo, sampler_perf,
learn_time_wo / num_iterations))
trainer.stop()
# Assert `_use_trajectory_view_api` is much faster.
self.assertLess(duration_w, duration_wo)
self.assertLess(learn_time_w, learn_time_wo * 0.6)
def test_traj_view_lstm_functionality(self):
action_space = Box(-float("inf"), float("inf"), shape=(2, ))
obs_space = Box(float("-inf"), float("inf"), (4, ))
max_seq_len = 50
policies = {
"pol0": (EpisodeEnvAwarePolicy, obs_space, action_space, {}),
}
def policy_fn(agent_id):
return "pol0"
rollout_worker = RolloutWorker(
env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
policy_config={
"multiagent": {
"policies": policies,
"policy_mapping_fn": policy_fn,
},
"_use_trajectory_view_api": True,
"model": {
"use_lstm": True,
"_time_major": True,
"max_seq_len": max_seq_len,
},
},
policy=policies,
policy_mapping_fn=policy_fn,
num_envs=1,
)
for i in range(100):
pc = rollout_worker.sampler.sample_collector. \
policy_sample_collectors["pol0"]
sample_batch_offset_before = pc.sample_batch_offset
buffers = pc.buffers
result = rollout_worker.sample()
pol_batch = result.policy_batches["pol0"]
self.assertTrue(result.count == 100)
self.assertTrue(pol_batch.count >= 100)
self.assertFalse(0 in pol_batch.seq_lens)
# Check prev_reward/action, next_obs consistency.
for t in range(max_seq_len):
obs_t = pol_batch["obs"][t]
r_t = pol_batch["rewards"][t]
if t > 0:
next_obs_t_m_1 = pol_batch["new_obs"][t - 1]
self.assertTrue((obs_t == next_obs_t_m_1).all())
if t < max_seq_len - 1:
prev_rewards_t_p_1 = pol_batch["prev_rewards"][t + 1]
self.assertTrue((r_t == prev_rewards_t_p_1).all())
# Check the sanity of all the buffers in the underlying
# PerPolicy collector.
for sample_batch_slot, agent_slot in enumerate(
range(sample_batch_offset_before, pc.sample_batch_offset)):
t_buf = buffers["t"][:, agent_slot]
obs_buf = buffers["obs"][:, agent_slot]
# Skip empty seqs at end (these won't be part of the batch
# and have been copied to new agent-slots (even if seq-len=0)).
if sample_batch_slot < len(pol_batch.seq_lens):
seq_len = pol_batch.seq_lens[sample_batch_slot]
# Make sure timesteps are always increasing within the seq.
assert all(t_buf[1] + j == n + 1
for j, n in enumerate(t_buf)
if j < seq_len and j != 0)
# Make sure all obs within seq are non-0.0.
assert all(
any(obs_buf[j] != 0.0) for j in range(1, seq_len + 1))
# Check seq-lens.
for agent_slot, seq_len in enumerate(pol_batch.seq_lens):
if seq_len < max_seq_len - 1:
# At least in the beginning, the next slots should always
# be empty (once all agent slots have been used once, these
# may be filled with "old" values (from longer sequences)).
if i < 10:
self.assertTrue(
(pol_batch["obs"][seq_len +
1][agent_slot] == 0.0).all())
print(end="")
self.assertFalse(
(pol_batch["obs"][seq_len][agent_slot] == 0.0).all())
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))

View file

@ -1,4 +1,7 @@
import gym
import numpy as np
from ray.rllib.env.multi_agent_env import MultiAgentEnv
class DebugCounterEnv(gym.Env):
@ -21,3 +24,43 @@ class DebugCounterEnv(gym.Env):
def step(self, action):
self.i += 1
return [self.i], self.i % 3, self.i >= 15, {}
class MultiAgentDebugCounterEnv(MultiAgentEnv):
def __init__(self, config):
self.num_agents = config["num_agents"]
self.p_done = config.get("p_done", 0.02)
# Actions are always:
# (episodeID, envID) as floats.
self.action_space = \
gym.spaces.Box(-float("inf"), float("inf"), shape=(2, ))
# Observation dims:
# 0=agent ID.
# 1=episode ID (0.0 for obs after reset).
# 2=env ID (0.0 for obs after reset).
# 3=ts (of the agent).
self.observation_space = \
gym.spaces.Box(float("-inf"), float("inf"), (4, ))
self.timesteps = [0] * self.num_agents
self.dones = set()
def reset(self):
self.dones = set()
return {
i: np.array([i, 0.0, 0.0, 0.0], dtype=np.float32)
for i in range(self.num_agents)
}
def step(self, action_dict):
obs, rew, done = {}, {}, {}
for i, action in action_dict.items():
self.timesteps[i] += 1
obs[i] = np.array([i, action[0], action[1], self.timesteps[i]])
rew[i] = self.timesteps[i] % 3
done[i] = bool(
np.random.choice(
[True, False], p=[self.p_done, 1.0 - self.p_done]))
if done[i]:
self.dones.add(i)
done["__all__"] = len(self.dones) == self.num_agents
return obs, rew, done, {}
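# --- Illustrative usage sketch (not part of this commit's diff) -------------
# Shows how the debug env above behaves when stepped with the
# (episode ID, env ID)-style float actions that EpisodeEnvAwarePolicy emits.
# The import path and the action values are assumptions for illustration only.
import numpy as np
from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv

env = MultiAgentDebugCounterEnv({"num_agents": 2, "p_done": 0.0})
obs = env.reset()
assert set(obs.keys()) == {0, 1}  # one obs per agent, keyed by agent ID
obs, rew, done, _ = env.step({
    0: np.array([7.0, 0.0]),  # pretend (episode ID, env ID)
    1: np.array([7.0, 0.0]),
})
assert obs[0][3] == 1  # obs dim 3 is the agent's own timestep
assert done["__all__"] is False  # p_done=0.0 -> episode never terminates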

View file

@ -0,0 +1,66 @@
import numpy as np
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
class EpisodeEnvAwarePolicy(RandomPolicy):
"""A Policy that always knows the current EpisodeID and EnvID and
returns these in its actions."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.episode_id = None
self.env_id = None
class _fake_model:
pass
self.model = _fake_model()
self.model.time_major = True
self.model.inference_view_requirements = {
SampleBatch.EPS_ID: ViewRequirement(),
"env_id": ViewRequirement(),
SampleBatch.OBS: ViewRequirement(),
SampleBatch.PREV_ACTIONS: ViewRequirement(
SampleBatch.ACTIONS, space=self.action_space, shift=-1),
SampleBatch.PREV_REWARDS: ViewRequirement(
SampleBatch.REWARDS, shift=-1),
}
self.training_view_requirements = dict(
**{
SampleBatch.NEXT_OBS: ViewRequirement(
SampleBatch.OBS, shift=1),
SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
SampleBatch.REWARDS: ViewRequirement(),
SampleBatch.DONES: ViewRequirement(),
},
**self.model.inference_view_requirements)
@override(Policy)
def is_recurrent(self):
return True
@override(Policy)
def compute_actions_from_input_dict(self,
input_dict,
explore=None,
timestep=None,
**kwargs):
self.episode_id = input_dict[SampleBatch.EPS_ID][0]
self.env_id = input_dict["env_id"][0]
# Always return (episodeID, envID)
return [
np.array([self.episode_id, self.env_id]) for _ in input_dict["obs"]
], [], {}
@override(Policy)
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
sample_batch["postprocessed_column"] = sample_batch["obs"] + 1.0
return sample_batch

View file

@ -65,6 +65,9 @@ MODEL_DEFAULTS: ModelConfigDict = {
"lstm_cell_size": 256,
# Whether to feed a_{t-1}, r_{t-1} to LSTM.
"lstm_use_prev_action_reward": False,
# Experimental (only works with `_use_trajectory_view_api`=True):
# Whether the LSTM is time-major (TxBx..) or batch-major (BxTx..).
"_time_major": False,
# When using modelv1 models with a modelv2 algorithm, you may have to
# define the state shape here (e.g., [256, 256]).
"state_shape": None,

View file

@ -58,6 +58,10 @@ class ModelV2:
self.name: str = name or "default_model"
self.framework: str = framework
self._last_output = None
self.time_major = self.model_config.get("_time_major")
self.inference_view_requirements = {
SampleBatch.OBS: ViewRequirement(shift=0),
}
@PublicAPI
def get_initial_state(self) -> List[np.ndarray]:
@ -246,26 +250,6 @@ class ModelV2:
i += 1
return self.__call__(input_dict, states, train_batch.get("seq_lens"))
def inference_view_requirements(self) -> Dict[str, ViewRequirement]:
"""Returns a dict of ViewRequirements for this Model.
Note: This is an experimental API method.
The view requirements dict is used to generate input_dicts and
train batches for 1) action computations, 2) postprocessing, and 3)
generating training batches.
Returns:
Dict[str, ViewRequirement]: The view requirements dict, mapping
each view key (which will be available in input_dicts) to
an underlying requirement (actual data, timestep shift, etc..).
"""
# Default implementation for simple RL model:
# Single requirement: Pass current obs as input.
return {
SampleBatch.OBS: ViewRequirement(shift=0),
}
def import_from_h5(self, h5_file: str) -> None:
"""Imports weights from an h5 file.
@ -322,6 +306,16 @@ class ModelV2:
"""
raise NotImplementedError
@PublicAPI
def is_time_major(self) -> bool:
"""If True, data for calling this ModelV2 must be in time-major format.
Returns:
bool: Whether this ModelV2 requires a time-major (TxBx...) data
format.
"""
return self.time_major is True
class NullContextManager:
"""No-op context manager"""

View file

@ -54,10 +54,11 @@ class RecurrentNetwork(TFModelV2):
You should implement forward_rnn() in your subclass."""
assert seq_lens is not None
padded_inputs = input_dict["obs_flat"]
max_seq_len = tf.shape(padded_inputs)[0] // tf.shape(seq_lens)[0]
output, new_state = self.forward_rnn(
add_time_dimension(
input_dict["obs_flat"], seq_lens, framework="tf"), state,
padded_inputs, max_seq_len=max_seq_len, framework="tf"), state,
seq_lens)
return tf.reshape(output, [-1, self.num_outputs]), new_state

View file

@ -1,5 +1,5 @@
from gym.spaces import Box
import numpy as np
from typing import Dict
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.misc import SlimFC
@ -63,13 +63,20 @@ class RecurrentNetwork(TorchModelV2):
"""Adds time dimension to batch before sending inputs to forward_rnn().
You should implement forward_rnn() in your subclass."""
flat_inputs = input_dict["obs_flat"].float()
if isinstance(seq_lens, np.ndarray):
seq_lens = torch.Tensor(seq_lens).int()
output, new_state = self.forward_rnn(
add_time_dimension(
input_dict["obs_flat"].float(), seq_lens, framework="torch"),
state, seq_lens)
return torch.reshape(output, [-1, self.num_outputs]), new_state
max_seq_len = flat_inputs.shape[0] // seq_lens.shape[0]
self.time_major = self.model_config.get("_time_major", False)
inputs = add_time_dimension(
flat_inputs,
max_seq_len=max_seq_len,
framework="torch",
time_major=self.time_major,
)
output, new_state = self.forward_rnn(inputs, state, seq_lens)
output = torch.reshape(output, [-1, self.num_outputs])
return output, new_state
def forward_rnn(self, inputs, state, seq_lens):
"""Call the model with the given input tensors and state.
@ -104,13 +111,15 @@ class LSTMWrapper(RecurrentNetwork, nn.Module):
super().__init__(obs_space, action_space, None, model_config, name)
self.cell_size = model_config["lstm_cell_size"]
self.time_major = model_config.get("_time_major", False)
self.use_prev_action_reward = model_config[
"lstm_use_prev_action_reward"]
self.action_dim = int(np.product(action_space.shape))
# Add prev-action/reward nodes to input to LSTM.
if self.use_prev_action_reward:
self.num_outputs += 1 + self.action_dim
self.lstm = nn.LSTM(self.num_outputs, self.cell_size, batch_first=True)
self.lstm = nn.LSTM(
self.num_outputs, self.cell_size, batch_first=not self.time_major)
self.num_outputs = num_outputs
@ -126,6 +135,26 @@ class LSTMWrapper(RecurrentNetwork, nn.Module):
activation_fn=None,
initializer=torch.nn.init.xavier_uniform_)
self.inference_view_requirements.update(
dict(
**{
SampleBatch.OBS: ViewRequirement(shift=0),
SampleBatch.PREV_REWARDS: ViewRequirement(
SampleBatch.REWARDS, shift=-1),
SampleBatch.PREV_ACTIONS: ViewRequirement(
SampleBatch.ACTIONS, space=self.action_space,
shift=-1),
}))
for i in range(2):
self.inference_view_requirements["state_in_{}".format(i)] = \
ViewRequirement(
"state_out_{}".format(i),
shift=-1,
space=Box(-1.0, 1.0, shape=(self.cell_size,)))
self.inference_view_requirements["state_out_{}".format(i)] = \
ViewRequirement(
space=Box(-1.0, 1.0, shape=(self.cell_size,)))
@override(RecurrentNetwork)
def forward(self, input_dict, state, seq_lens):
assert seq_lens is not None
@ -150,10 +179,24 @@ class LSTMWrapper(RecurrentNetwork, nn.Module):
@override(RecurrentNetwork)
def forward_rnn(self, inputs, state, seq_lens):
# Don't show paddings to RNN(?)
# TODO: (sven) For now, only allow this iff time_major=True, to not break
# anything retrospectively (time_major was not supported previously).
# max_seq_len = inputs.shape[0]
# time_major = self.model_config["_time_major"]
# if time_major and max_seq_len > 1:
# inputs = torch.nn.utils.rnn.pack_padded_sequence(
# inputs, seq_lens,
# batch_first=not time_major, enforce_sorted=False)
self._features, [h, c] = self.lstm(
inputs,
[torch.unsqueeze(state[0], 0),
torch.unsqueeze(state[1], 0)])
# Re-apply paddings.
# if time_major and max_seq_len > 1:
# self._features, _ = torch.nn.utils.rnn.pad_packed_sequence(
# self._features,
# batch_first=not time_major)
model_out = self._logits_branch(self._features)
return model_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]
@ -171,16 +214,3 @@ class LSTMWrapper(RecurrentNetwork, nn.Module):
def value_function(self):
assert self._features is not None, "must call forward() first"
return torch.reshape(self._value_branch(self._features), [-1])
@override(ModelV2)
def inference_view_requirements(self) -> Dict[str, ViewRequirement]:
req = super().inference_view_requirements()
# Optional: prev-actions/rewards for forward pass.
if self.model_config["lstm_use_prev_action_reward"]:
req.update({
SampleBatch.PREV_REWARDS: ViewRequirement(
SampleBatch.REWARDS, shift=-1),
SampleBatch.PREV_ACTIONS: ViewRequirement(
SampleBatch.ACTIONS, space=self.action_space, shift=-1),
})
return req

View file

@ -70,6 +70,11 @@ class Policy(metaclass=ABCMeta):
# The action distribution class to use for action sampling, if any.
# Child classes may set this.
self.dist_class = None
# View requirements dict for a `learn_on_batch()` call.
# Child classes need to add their specific requirements here (usually
# a combination of the Model's inference view requirements and the
# Policy's loss-function requirements).
self.training_view_requirements = {}
@abstractmethod
@DeveloperAPI
@ -283,25 +288,6 @@ class Policy(metaclass=ABCMeta):
"""
raise NotImplementedError
@DeveloperAPI
def training_view_requirements(self):
"""Returns a dict of view requirements for operating on this Policy.
Note: This is an experimental API method.
The view requirements dict is used to generate input_dicts and
SampleBatches for 1) action computations, 2) postprocessing, and 3)
generating training batches.
The Policy may ask its Model(s) as well for possible additional
requirements (e.g. prev-action/reward in an LSTM).
Returns:
Dict[str, ViewRequirement]: The view requirements dict, mapping
each view key (which will be available in input_dicts) to
an underlying requirement (actual data, timestep shift, etc..).
"""
return {}
@DeveloperAPI
def postprocess_trajectory(
self,

View file

@ -13,12 +13,14 @@ current algorithms: https://github.com/ray-project/ray/issues/2992
import logging
import numpy as np
from typing import List, Optional
from ray.util import log_once
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.typing import TensorType
from ray.util import log_once
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
@ -27,11 +29,14 @@ logger = logging.getLogger(__name__)
@DeveloperAPI
def pad_batch_to_sequences_of_same_size(batch,
max_seq_len,
shuffle=False,
batch_divisibility_req=1,
feature_keys=None):
def pad_batch_to_sequences_of_same_size(
batch: SampleBatch,
max_seq_len: int,
shuffle: bool = False,
batch_divisibility_req: int = 1,
feature_keys: Optional[List[str]] = None,
_use_trajectory_view_api: bool = False,
):
"""Applies padding to `batch` so it's choppable into same-size sequences.
Shuffles `batch` (if desired), makes sure divisibility requirement is met,
@ -51,7 +56,26 @@ def pad_batch_to_sequences_of_same_size(batch,
feature_keys (Optional[List[str]]): An optional list of keys to apply
sequence-chopping to. If None, use all keys in batch that are not
"state_in/out_"-type keys.
_use_trajectory_view_api (bool): Whether we are using the Trajectory
View API to collect and process samples.
"""
if _use_trajectory_view_api:
if batch.time_major is not None:
batch["seq_lens"] = torch.tensor(batch.seq_lens)
t = 0 if batch.time_major else 1
for col in batch.data.keys():
# Cut time-dim from states.
if "state_" in col[:6]:
batch[col] = batch[col][t]
# Flatten all other data.
else:
# Cut time-dim at `max_seq_len`.
if batch.time_major:
batch[col] = batch[col][:batch.max_seq_len]
batch[col] = batch[col].reshape((-1, ) +
batch[col].shape[2:])
return
if batch_divisibility_req > 1:
meets_divisibility_reqs = (
len(batch[SampleBatch.CUR_OBS]) % batch_divisibility_req == 0
@ -61,7 +85,7 @@ def pad_batch_to_sequences_of_same_size(batch,
meets_divisibility_reqs = True
# RNN-case.
if "state_in_0" in batch:
if "state_in_0" in batch or "state_out_0" in batch:
dynamic_max = True
# Multi-agent case.
elif not meets_divisibility_reqs:
@ -109,31 +133,32 @@ def pad_batch_to_sequences_of_same_size(batch,
@DeveloperAPI
def add_time_dimension(padded_inputs,
seq_lens,
framework="tf",
time_major=False):
def add_time_dimension(padded_inputs: TensorType,
*,
max_seq_len: int,
framework: str = "tf",
time_major: bool = False):
"""Adds a time dimension to padded inputs.
Arguments:
padded_inputs (Tensor): a padded batch of sequences. That is,
Args:
padded_inputs (TensorType): a padded batch of sequences. That is,
for seq_lens=[1, 2, 2], then inputs=[A, *, B, B, C, C], where
A, B, C are sequence elements and * denotes padding.
seq_lens (Tensor): the sequence lengths within the input batch,
suitable for passing to tf.nn.dynamic_rnn().
max_seq_len (int): The max. sequence length in padded_inputs.
framework (str): The framework string ("tf2", "tf", "tfe", "torch").
time_major (bool): Whether data should be returned in time-major (TxB)
format or not (BxT).
Returns:
Reshaped tensor of shape [NUM_SEQUENCES, MAX_SEQ_LEN, ...].
TensorType: Reshaped tensor of shape [B, T, ...] or [T, B, ...].
"""
# Sequence lengths have to be specified for LSTM batch inputs. The
# input batch must be padded to the max seq length given here. That is,
# batch_size == len(seq_lens) * max(seq_lens)
if framework == "tf":
if framework in ["tf2", "tf", "tfe"]:
assert time_major is False, "time-major not supported yet for tf!"
padded_batch_size = tf.shape(padded_inputs)[0]
max_seq_len = padded_batch_size // tf.shape(seq_lens)[0]
# Dynamically reshape the padded batch to introduce a time dimension.
new_batch_size = padded_batch_size // max_seq_len
new_shape = ([new_batch_size, max_seq_len] +
@ -142,7 +167,6 @@ def add_time_dimension(padded_inputs,
else:
assert framework == "torch", "`framework` must be either tf or torch!"
padded_batch_size = padded_inputs.shape[0]
max_seq_len = padded_batch_size // seq_lens.shape[0]
# Dynamically reshape the padded batch to introduce a time dimension.
new_batch_size = padded_batch_size // max_seq_len
@ -153,6 +177,9 @@ def add_time_dimension(padded_inputs,
return torch.reshape(padded_inputs, new_shape)
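# --- Illustrative usage sketch (not part of this commit's diff) -------------
# Demonstrates the new keyword-only `add_time_dimension` signature with an
# explicit `max_seq_len`. The module path is an assumption; the tensor shapes
# are made up for illustration.
import torch
from ray.rllib.policy.rnn_sequencing import add_time_dimension

# Flat, padded batch: 3 sequences, each padded to max_seq_len=2, obs-dim=4.
flat_inputs = torch.zeros(6, 4)
out = add_time_dimension(
    flat_inputs, max_seq_len=2, framework="torch", time_major=False)
assert out.shape == (3, 2, 4)  # [B, T, ...] (batch-major)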
# NOTE: This function will be deprecated once chunks already come padded and
# correctly chopped from the _SampleCollector object (in time-major fashion
# or not). It is already no longer used if `_use_trajectory_view_api` = True.
@DeveloperAPI
def chop_into_sequences(episode_ids,
unroll_ids,
@ -166,11 +193,11 @@ def chop_into_sequences(episode_ids,
"""Truncate and pad experiences into fixed-length sequences.
Args:
episode_ids (list): List of episode ids for each step.
unroll_ids (list): List of identifiers for the sample batch. This is
used to make sure sequences are cut between sample batches.
agent_indices (list): List of agent ids for each step. Note that this
has to be combined with episode_ids for uniqueness.
episode_ids (List[EpisodeID]): List of episode ids for each step.
unroll_ids (List[UnrollID]): List of identifiers for the sample batch.
This is used to make sure sequences are cut between sample batches.
agent_indices (List[AgentID]): List of agent ids for each step. Note
that this has to be combined with episode_ids for uniqueness.
feature_columns (list): List of arrays containing features.
state_columns (list): List of arrays containing LSTM state values.
max_seq_len (int): Max length of sequences before truncation.
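# --- Worked example (not part of this commit's diff) ------------------------
# Illustrates the chopping/padding behavior documented in the Args above: two
# episodes (lengths 2 and 4) are cut into sequences of at most 3 steps, with
# features zero-padded per sequence. The import path and keyword-style call
# are assumptions about this (soon-to-be-deprecated) function's signature.
from ray.rllib.policy.rnn_sequencing import chop_into_sequences

f_pad, s_init, seq_lens = chop_into_sequences(
    episode_ids=[1, 1, 5, 5, 5, 5],
    unroll_ids=[4, 4, 4, 4, 4, 4],
    agent_indices=[0, 0, 0, 0, 0, 0],
    feature_columns=[[4, 4, 8, 8, 8, 8]],
    state_columns=[[4, 5, 4, 5, 5, 5]],
    max_seq_len=3)
# seq_lens -> [2, 3, 1]
# f_pad    -> [[4, 4, 0, 8, 8, 8, 8, 0, 0]]  (3 padded slots per sequence)
# s_init   -> [[4, 4, 5]]                    (state at the start of each seq)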

View file

@ -59,19 +59,33 @@ class SampleBatch:
def __init__(self, *args, **kwargs):
"""Constructs a sample batch (same params as dict constructor)."""
self._initial_inputs = kwargs.pop("_initial_inputs", {})
# Possible seq_lens (TxB or BxT) setup.
self.time_major = kwargs.pop("_time_major", None)
self.seq_lens = kwargs.pop("_seq_lens", None)
self.max_seq_len = None
if self.seq_lens is not None and len(self.seq_lens) > 0:
self.max_seq_len = max(self.seq_lens)
# The actual data, accessible by column name (str).
self.data = dict(*args, **kwargs)
lengths = []
for k, v in self.data.copy().items():
assert isinstance(k, str), self
lengths.append(len(v))
self.data[k] = np.array(v, copy=False)
if isinstance(v, list):
self.data[k] = np.array(v)
if not lengths:
raise ValueError("Empty sample batch")
assert len(set(lengths)) == 1, ("data columns must be same length",
self.data, lengths)
self.count = lengths[0]
assert len(set(lengths)) == 1, \
"Data columns must be same length, but lens are {}".format(lengths)
if self.seq_lens is not None and len(self.seq_lens) > 0:
self.count = sum(self.seq_lens)
else:
self.count = len(self.data[k])
# Keeps track of new columns added after initial ones.
self.new_columns = []
@staticmethod
@PublicAPI
@ -88,11 +102,21 @@ class SampleBatch:
"""
if isinstance(samples[0], MultiAgentBatch):
return MultiAgentBatch.concat_samples(samples)
seq_lens = []
concat_samples = []
for s in samples:
if s.count > 0:
concat_samples.append(s)
if s.seq_lens is not None:
seq_lens.extend(s.seq_lens)
out = {}
samples = [s for s in samples if s.count > 0]
for k in samples[0].keys():
out[k] = concat_aligned([s[k] for s in samples])
return SampleBatch(out)
for k in concat_samples[0].keys():
out[k] = concat_aligned(
[s[k] for s in concat_samples],
time_major=concat_samples[0].time_major)
return SampleBatch(
out, _seq_lens=seq_lens, _time_major=concat_samples[0].time_major)
@PublicAPI
def concat(self, other: "SampleBatch") -> "SampleBatch":
@ -222,8 +246,18 @@ class SampleBatch:
SampleBatch: A new SampleBatch, which has a slice of this batch's
data.
"""
return SampleBatch({k: v[start:end] for k, v in self.data.items()})
if self.time_major is not None:
return SampleBatch(
{k: v[:, start:end]
for k, v in self.data.items()},
_seq_lens=self.seq_lens[start:end],
_time_major=self.time_major)
else:
return SampleBatch(
{k: v[start:end]
for k, v in self.data.items()},
_seq_lens=None,
_time_major=self.time_major)
@PublicAPI
def timeslices(self, k: int) -> List["SampleBatch"]:
@ -290,7 +324,7 @@ class SampleBatch:
key (str): The key (column name) to return.
Returns:
TensorType]: The data under the given key.
TensorType: The data under the given key.
"""
return self.data[key]
@ -302,6 +336,8 @@ class SampleBatch:
key (str): The column name to set a value for.
item (TensorType): The data to insert.
"""
if key not in self.data:
self.new_columns.append(key)
self.data[key] = item
@DeveloperAPI
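# --- Illustrative usage sketch (not part of this commit's diff) -------------
# Builds a time-major SampleBatch with explicit sequence lengths, exercising
# the new `_seq_lens`/`_time_major` constructor kwargs added above. The shapes
# are made up for illustration: T=10, B=2, obs-dim=4.
import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

batch = SampleBatch(
    {"obs": np.zeros((10, 2, 4), dtype=np.float32)},
    _seq_lens=[10, 7],
    _time_major=True)
assert batch.max_seq_len == 10
assert batch.count == 17  # sum of seq_lens, not len(batch["obs"])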

View file

@ -1,84 +0,0 @@
import unittest
import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.test_utils import framework_iterator
class TestTrajectoryViewAPI(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init()
@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()
def test_plain(self):
config = ppo.DEFAULT_CONFIG.copy()
for _ in framework_iterator(config, frameworks="torch"):
trainer = ppo.PPOTrainer(config, env="CartPole-v0")
policy = trainer.get_policy()
view_req_model = policy.model.inference_view_requirements()
view_req_policy = policy.training_view_requirements()
assert len(view_req_model) == 1
assert len(view_req_policy) == 6
for key in [
SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS,
SampleBatch.DONES, SampleBatch.NEXT_OBS,
SampleBatch.VF_PREDS
]:
assert key in view_req_policy
# None of the view cols has a special underlying data_col,
# except next-obs.
if key != SampleBatch.NEXT_OBS:
assert view_req_policy[key].data_col is None
else:
assert view_req_policy[key].data_col == SampleBatch.OBS
assert view_req_policy[key].shift == 1
trainer.stop()
def test_lstm_prev_actions_and_rewards(self):
config = ppo.DEFAULT_CONFIG.copy()
config["model"] = config["model"].copy()
# Activate LSTM + prev-action + rewards.
config["model"]["use_lstm"] = True
config["model"]["lstm_use_prev_action_reward"] = True
for _ in framework_iterator(config, frameworks="torch"):
trainer = ppo.PPOTrainer(config, env="CartPole-v0")
policy = trainer.get_policy()
view_req_model = policy.model.inference_view_requirements()
view_req_policy = policy.training_view_requirements()
assert len(view_req_model) == 3 # obs, prev_a, prev_r
assert len(view_req_policy) == 8
for key in [
SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS,
SampleBatch.DONES, SampleBatch.NEXT_OBS,
SampleBatch.VF_PREDS, SampleBatch.PREV_ACTIONS,
SampleBatch.PREV_REWARDS
]:
assert key in view_req_policy
if key == SampleBatch.PREV_ACTIONS:
assert view_req_policy[key].data_col == SampleBatch.ACTIONS
assert view_req_policy[key].shift == -1
elif key == SampleBatch.PREV_REWARDS:
assert view_req_policy[key].data_col == SampleBatch.REWARDS
assert view_req_policy[key].shift == -1
elif key not in [
SampleBatch.NEXT_OBS, SampleBatch.PREV_ACTIONS,
SampleBatch.PREV_REWARDS
]:
assert view_req_policy[key].data_col is None
else:
assert view_req_policy[key].data_col == SampleBatch.OBS
assert view_req_policy[key].shift == 1
trainer.stop()
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))

View file

@ -104,16 +104,21 @@ class TorchPolicy(Policy):
"""
self.framework = "torch"
super().__init__(observation_space, action_space, config)
if torch.cuda.is_available() and ray.get_gpu_ids(as_str=True):
if torch.cuda.is_available() and ray.get_gpu_ids():
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")
self.model = model.to(self.device)
# Combine view_requirements for Model and Policy.
self.view_requirements = {
**self.model.inference_view_requirements(),
**self.training_view_requirements(),
}
self.training_view_requirements = dict(
**{
SampleBatch.ACTIONS: ViewRequirement(
space=self.action_space, shift=0),
SampleBatch.REWARDS: ViewRequirement(shift=0),
SampleBatch.DONES: ViewRequirement(shift=0),
},
**self.model.inference_view_requirements)
self.exploration = self._create_exploration()
self.unwrapped_model = model # used to support DistributedDataParallel
self._loss = loss
@ -131,17 +136,6 @@ class TorchPolicy(Policy):
callable(get_batch_divisibility_req) else \
(get_batch_divisibility_req or 1)
@override(Policy)
def training_view_requirements(self):
if hasattr(self, "view_requirements"):
return self.view_requirements
return {
SampleBatch.ACTIONS: ViewRequirement(
space=self.action_space, shift=0),
SampleBatch.REWARDS: ViewRequirement(shift=0),
SampleBatch.DONES: ViewRequirement(shift=0),
}
@override(Policy)
@DeveloperAPI
def compute_actions(
@ -204,9 +198,11 @@ class TorchPolicy(Policy):
with torch.no_grad():
# Pass lazy (torch) tensor dict to Model as `input_dict`.
input_dict = self._lazy_tensor_dict(input_dict)
# TODO: (sven) support RNNs w/ fast sampling.
state_batches = []
seq_lens = None
state_batches = [
input_dict[k] for k in input_dict.keys() if "state_" in k[:6]
]
seq_lens = np.array([1] * len(input_dict["obs"])) \
if state_batches else None
actions, state_out, extra_fetches, logp = \
self._compute_action_helper(
@ -340,7 +336,9 @@ class TorchPolicy(Policy):
postprocessed_batch,
max_seq_len=self.max_seq_len,
shuffle=False,
batch_divisibility_req=self.batch_divisibility_req)
batch_divisibility_req=self.batch_divisibility_req,
_use_trajectory_view_api=self.config["_use_trajectory_view_api"],
)
train_batch = self._lazy_tensor_dict(postprocessed_batch)
@ -359,6 +357,7 @@ class TorchPolicy(Policy):
loss_out, train_batch)
assert len(loss_out) == len(self._optimizers)
# assert not any(torch.isnan(l) for l in loss_out)
fetches = self.extra_compute_grad_fetches()

View file

@ -224,16 +224,13 @@ def build_torch_policy(
get_batch_divisibility_req=get_batch_divisibility_req,
)
if callable(training_view_requirements_fn):
self.training_view_requirements.update(
training_view_requirements_fn(self))
if after_init:
after_init(self, obs_space, action_space, config)
@override(TorchPolicy)
def training_view_requirements(self):
req = super().training_view_requirements()
if callable(training_view_requirements_fn):
req.update(training_view_requirements_fn(self))
return req
@override(Policy)
def postprocess_trajectory(self,
sample_batch,

View file

@ -29,8 +29,7 @@ class ViewRequirement:
def __init__(self,
data_col: Optional[str] = None,
space: gym.Space = None,
shift: Union[int, List[int]] = 0,
created_during_postprocessing: bool = False):
shift: Union[int, List[int]] = 0):
"""Initializes a ViewRequirement object.
Args:
@ -47,11 +46,8 @@ class ViewRequirement:
Example: For a view column "obs" in an Atari framestacking
fashion, you can set `data_col="obs"` and
`shift=[-3, -2, -1, 0]`.
created_during_postprocessing (bool): Whether this column only gets
created during postprocessing.
"""
self.data_col = data_col
self.space = space or gym.spaces.Box(
float("-inf"), float("inf"), shape=())
self.shift = shift
self.created_during_postprocessing = created_during_postprocessing
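# --- Illustrative usage sketch (not part of this commit's diff) -------------
# Typical ViewRequirements after the `created_during_postprocessing` removal:
# a prev-action view (shifted "actions" column) and the framestacking example
# from the docstring above. The Discrete(2) space is an arbitrary assumption.
import gym
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement

prev_actions = ViewRequirement(
    data_col=SampleBatch.ACTIONS,  # read from the "actions" column ...
    space=gym.spaces.Discrete(2),
    shift=-1)                      # ... shifted back by one timestep
framestack_obs = ViewRequirement(
    data_col=SampleBatch.OBS, shift=[-3, -2, -1, 0])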

View file

@ -63,7 +63,8 @@ def minibatches(samples, sgd_minibatch_size):
raise NotImplementedError(
"Minibatching not implemented for multi-agent in simple mode")
if "state_in_0" in samples.data:
# Replace with `if samples.seq_lens` check.
if "state_in_0" in samples.data or "state_out_0" in samples.data:
if log_once("not_shuffling_rnn_data_in_simple_mode"):
logger.warning("Not shuffling RNN data for SGD in simple mode")
else:
@ -71,9 +72,22 @@ def minibatches(samples, sgd_minibatch_size):
i = 0
slices = []
while i < samples.count:
slices.append((i, i + sgd_minibatch_size))
i += sgd_minibatch_size
if samples.seq_lens:
seq_no = 0
while i < samples.count:
seq_no_end = seq_no
actual_count = 0
while actual_count < sgd_minibatch_size and len(
samples.seq_lens) > seq_no_end:
actual_count += samples.seq_lens[seq_no_end]
seq_no_end += 1
slices.append((seq_no, seq_no_end))
i += actual_count
seq_no = seq_no_end
else:
while i < samples.count:
slices.append((i, i + sgd_minibatch_size))
i += sgd_minibatch_size
random.shuffle(slices)
for i, j in slices:
@ -100,7 +114,7 @@ def do_minibatch_sgd(samples, policies, local_worker, num_sgd_iter,
samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples}, samples.count)
fetches = {}
for policy_id, policy in policies.items():
for policy_id in policies.keys():
if policy_id not in samples.policy_batches:
continue

View file

@ -43,6 +43,9 @@ EnvID = int
# Represents an episode id.
EpisodeID = int
# Represents an "unroll" (maybe across different sub-envs in a vector env).
UnrollID = int
# A dict keyed by agent ids, e.g. {"agent-1": value}.
MultiAgentDict = Dict[AgentID, Any]