[RLlib] Trajectory view API - 03 Fast LSTM + prev actions/rewards (#9950)
This commit is contained in:
  parent 92664249e8
  commit e968b52cb7
25 changed files with 1230 additions and 413 deletions
34  rllib/BUILD

@@ -22,13 +22,14 @@
 # (problems: 10min timeout, not respecting ray/ci/keep_alive.sh, or even
 # `travis_wait n`, etc..).
 
-# Our travis.yml file executes all these tests in 6 different jobs, which are:
+# Our travis.yml file executes all these tests in 7 different jobs, which are:
 # 1) everything in a) using tf2.x
 # 2) everything in a) using tf1.x
-# 3) everything in b) c) d) and e)
-# 4) everything in g)
-# 5) f), BUT only those tagged `tests_dir_A` to `tests_dir_L`
-# 6) f), BUT only those tagged `tests_dir_M` to `tests_dir_Z`
+# 3) everything in a) using torch
+# 4) everything in b) c) d) and e)
+# 5) everything in g)
+# 6) f), BUT only those tagged `tests_dir_A` to `tests_dir_[some letter]`
+# 7) f), BUT only those tagged `tests_dir_[some letter]` to `tests_dir_Z`
 
 
 # --------------------------------------------------------------------

@@ -1024,6 +1025,22 @@ py_test(
     srcs = ["models/tests/test_attention_nets.py"]
 )
 
+
+# --------------------------------------------------------------------
+# Evaluation components
+# rllib/evaluation/
+#
+# Tag: evaluation
+# --------------------------------------------------------------------
+# mysteriously times out on travis.
+#py_test(
+#    name = "evaluation/tests/test_trajectory_view_api",
+#    tags = ["evaluation"],
+#    size = "medium",
+#    srcs = ["evaluation/tests/test_trajectory_view_api.py"]
+#)
+
+
 # --------------------------------------------------------------------
 # Optimizers and Memories
 # rllib/execution/

@@ -1059,13 +1076,6 @@ py_test(
     srcs = ["policy/tests/test_compute_log_likelihoods.py"]
 )
 
-py_test(
-    name = "policy/tests/test_trajectory_view_api",
-    tags = ["policy"],
-    size = "small",
-    srcs = ["policy/tests/test_trajectory_view_api.py"]
-)
-
 # --------------------------------------------------------------------
 # Utils:
 # rllib/utils/
@@ -1,13 +1,16 @@
-from typing import Dict
+from typing import Dict, TYPE_CHECKING
 
 from ray.rllib.env import BaseEnv
 from ray.rllib.policy import Policy
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
+from ray.rllib.evaluation import MultiAgentEpisode
 from ray.rllib.utils.annotations import PublicAPI
 from ray.rllib.utils.deprecation import deprecation_warning
 from ray.rllib.utils.typing import AgentID, PolicyID
 
+if TYPE_CHECKING:
+    from ray.rllib.evaluation import RolloutWorker
+
 
 @PublicAPI
 class DefaultCallbacks:

@@ -27,7 +30,7 @@ class DefaultCallbacks:
                 "a class extending rllib.agents.callbacks.DefaultCallbacks")
         self.legacy_callbacks = legacy_callbacks_dict or {}
 
-    def on_episode_start(self, worker: RolloutWorker, base_env: BaseEnv,
+    def on_episode_start(self, worker: "RolloutWorker", base_env: BaseEnv,
                          policies: Dict[PolicyID, Policy],
                          episode: MultiAgentEpisode, **kwargs):
        """Callback run on the rollout worker before each episode starts.

@@ -52,7 +55,7 @@ class DefaultCallbacks:
                 "episode": episode,
             })
 
-    def on_episode_step(self, worker: RolloutWorker, base_env: BaseEnv,
+    def on_episode_step(self, worker: "RolloutWorker", base_env: BaseEnv,
                         episode: MultiAgentEpisode, **kwargs):
        """Runs on each episode step.
 

@@ -73,7 +76,7 @@ class DefaultCallbacks:
                 "episode": episode
             })
 
-    def on_episode_end(self, worker: RolloutWorker, base_env: BaseEnv,
+    def on_episode_end(self, worker: "RolloutWorker", base_env: BaseEnv,
                        policies: Dict[PolicyID, Policy],
                        episode: MultiAgentEpisode, **kwargs):
        """Runs when an episode is done.

@@ -99,7 +102,7 @@ class DefaultCallbacks:
             })
 
     def on_postprocess_trajectory(
-            self, worker: RolloutWorker, episode: MultiAgentEpisode,
+            self, worker: "RolloutWorker", episode: MultiAgentEpisode,
             agent_id: AgentID, policy_id: PolicyID,
             policies: Dict[PolicyID, Policy], postprocessed_batch: SampleBatch,
             original_batches: Dict[AgentID, SampleBatch], **kwargs):

@@ -133,7 +136,7 @@ class DefaultCallbacks:
             "all_pre_batches": original_batches,
         })
 
-    def on_sample_end(self, worker: RolloutWorker, samples: SampleBatch,
+    def on_sample_end(self, worker: "RolloutWorker", samples: SampleBatch,
                       **kwargs):
        """Called at the end of RolloutWorker.sample().
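The hunks above move the `RolloutWorker` import behind a `TYPE_CHECKING` guard and quote the annotations, which breaks the import cycle between the callbacks and evaluation modules while keeping the type hints. A minimal sketch of that pattern, with made-up module and class names rather than RLlib's:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers; never executed at runtime,
    # so it cannot create an import cycle.
    from mypackage.worker import Worker  # hypothetical module


class Callbacks:
    def on_sample_end(self, worker: "Worker") -> None:
        # The string ("forward reference") annotation is resolved lazily.
        print("sample ended on", worker)
```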
@@ -117,7 +117,10 @@ def ppo_surrogate_loss(policy, model, dist_class, train_batch):
     mask = None
     if state:
         max_seq_len = torch.max(train_batch["seq_lens"])
-        mask = sequence_mask(train_batch["seq_lens"], max_seq_len)
+        mask = sequence_mask(
+            train_batch["seq_lens"],
+            max_seq_len,
+            time_major=model.is_time_major())
         mask = torch.reshape(mask, [-1])
 
     policy.loss_obj = PPOLoss(

@@ -221,6 +224,12 @@ def training_view_requirements_fn(policy):
         SampleBatch.NEXT_OBS: ViewRequirement(SampleBatch.OBS, shift=1),
+        # VF preds are needed for the loss.
+        SampleBatch.VF_PREDS: ViewRequirement(shift=0),
+        # Needed for postprocessing.
+        SampleBatch.ACTION_DIST_INPUTS: ViewRequirement(shift=0),
+        SampleBatch.ACTION_LOGP: ViewRequirement(shift=0),
+        # Created during postprocessing.
         Postprocessing.ADVANTAGES: ViewRequirement(shift=0),
         Postprocessing.VALUE_TARGETS: ViewRequirement(shift=0),
     }
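The first hunk threads `time_major` through `sequence_mask` so the padding mask matches a [T, B] batch layout before it is flattened for the loss. A rough sketch of what such a mask looks like; `sequence_mask_sketch` is a stand-in, not RLlib's `sequence_mask`:

```python
import torch


def sequence_mask_sketch(seq_lens: torch.Tensor, max_len: int,
                         time_major: bool = False) -> torch.Tensor:
    """Boolean mask, shape [B, T] (or [T, B] when time_major=True)."""
    t = torch.arange(max_len)
    mask = t[None, :] < seq_lens[:, None]  # [B, T]
    return mask.t() if time_major else mask


seq_lens = torch.tensor([3, 1, 2])
mask = sequence_mask_sketch(seq_lens, int(seq_lens.max()), time_major=True)
# Flatten to weight per-timestep loss terms, as in the PPO loss above.
flat_mask = torch.reshape(mask, [-1])
```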
@@ -1082,6 +1082,11 @@ class Trainer(Trainable):
             raise ValueError(
                 "`_use_trajectory_view_api` only supported for PyTorch so "
                 "far!")
+        elif not config.get("_use_trajectory_view_api") and \
+                config.get("model", {}).get("_time_major"):
+            raise ValueError("`model._time_major` only supported "
+                             "iff `_use_trajectory_view_api` is True!")
 
         if "policy_graphs" in config["multiagent"]:
             deprecation_warning("policy_graphs", "policies")
             # Backwards compatibility.
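The new check rejects `model._time_major` whenever the trajectory view API is off; combined with the existing torch-only check it implies a config roughly like the one below (keys taken from the hunk, other values illustrative):

```python
# Hypothetical Trainer config that passes the validation above.
config = {
    "framework": "torch",               # trajectory view API is torch-only here
    "_use_trajectory_view_api": True,   # required for _time_major
    "model": {
        "_time_major": True,            # only legal with the flag above
        "use_lstm": True,
    },
}
```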
2  rllib/env/policy_client.py (vendored)

@@ -10,7 +10,6 @@ import time
 from typing import Union, Optional
 
 import ray.cloudpickle as pickle
-from ray.rllib.evaluation.rollout_worker import RolloutWorker
 from ray.rllib.env import ExternalEnv, MultiAgentEnv, ExternalMultiAgentEnv
 from ray.rllib.policy.sample_batch import MultiAgentBatch
 from ray.rllib.utils.annotations import PublicAPI

@@ -337,6 +336,7 @@ def _create_embedded_rollout_worker(kwargs, send_fn):
     real_env_creator = kwargs["env_creator"]
     kwargs["env_creator"] = _auto_wrap_external(real_env_creator)
 
+    from ray.rllib.evaluation.rollout_worker import RolloutWorker
     rollout_worker = RolloutWorker(**kwargs)
     inference_thread = _LocalInferenceThread(rollout_worker, send_fn)
     inference_thread.start()
@@ -1,7 +1,6 @@
 import logging
-from typing import Dict, Optional
+from typing import Dict, Optional, TYPE_CHECKING
 
-from ray.rllib.agents.callbacks import DefaultCallbacks
 from ray.rllib.env.base_env import _DUMMY_AGENT_ID
 from ray.rllib.evaluation.episode import MultiAgentEpisode
 from ray.rllib.evaluation.per_policy_sample_collector import \

@@ -16,6 +15,9 @@ from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, \
     TensorType
 from ray.util.debug import log_once
 
+if TYPE_CHECKING:
+    from ray.rllib.agents.callbacks import DefaultCallbacks
+
 logger = logging.getLogger(__name__)
 
 

@@ -38,7 +40,7 @@ class _MultiAgentSampleCollector(_SampleCollector):
     def __init__(
             self,
             policy_map: Dict[PolicyID, Policy],
-            callbacks: DefaultCallbacks,
+            callbacks: "DefaultCallbacks",
             # TODO: (sven) make `num_agents` flexibly grow in size.
             num_agents: int = 100,
             num_timesteps=None,

@@ -64,8 +66,8 @@ class _MultiAgentSampleCollector(_SampleCollector):
             num_agents = 1000
         self.num_agents = int(num_agents)
 
-        # Collect SampleBatches per-policy in PolicyTrajectories objects.
-        self.rollout_sample_collectors = {}
+        # Collect SampleBatches per-policy in _PerPolicySampleCollectors.
+        self.policy_sample_collectors = {}
         for pid, policy in policy_map.items():
             # Figure out max-shifts (before and after).
             view_reqs = policy.training_view_requirements

@@ -86,7 +88,7 @@ class _MultiAgentSampleCollector(_SampleCollector):
             elif num_timesteps is not None:
                 kwargs["num_timesteps"] = num_timesteps
 
-            self.rollout_sample_collectors[pid] = _PerPolicySampleCollector(
+            self.policy_sample_collectors[pid] = _PerPolicySampleCollector(
                 num_agents=self.num_agents,
                 shift_before=-max_shift_before,
                 shift_after=max_shift_after,

@@ -109,7 +111,7 @@ class _MultiAgentSampleCollector(_SampleCollector):
             assert self.agent_to_policy[agent_id] == policy_id
 
         # Add initial obs to Trajectory.
-        self.rollout_sample_collectors[policy_id].add_init_obs(
+        self.policy_sample_collectors[policy_id].add_init_obs(
             episode_id, agent_id, env_id, chunk_num=0, init_obs=obs)
 
     @override(_SampleCollector)

@@ -117,7 +119,7 @@ class _MultiAgentSampleCollector(_SampleCollector):
             agent_id: AgentID, env_id: EnvID,
             policy_id: PolicyID, agent_done: bool,
             values: Dict[str, TensorType]) -> None:
-        assert policy_id in self.rollout_sample_collectors
+        assert policy_id in self.policy_sample_collectors
 
         # Make sure our mappings are up to date.
         if agent_id not in self.agent_to_policy:

@@ -130,13 +132,13 @@ class _MultiAgentSampleCollector(_SampleCollector):
         values["agent_id"] = agent_id
 
         # Add action/reward/next-obs (and other data) to Trajectory.
-        self.rollout_sample_collectors[policy_id].add_action_reward_next_obs(
+        self.policy_sample_collectors[policy_id].add_action_reward_next_obs(
             episode_id, agent_id, env_id, agent_done, values)
 
     @override(_SampleCollector)
     def total_env_steps(self) -> int:
         return sum(a.timesteps_since_last_reset
-                   for a in self.rollout_sample_collectors.values())
+                   for a in self.policy_sample_collectors.values())
 
     def total(self):
         # TODO: (sven) deprecate; use `self.total_env_steps`, instead.

@@ -148,7 +150,7 @@ class _MultiAgentSampleCollector(_SampleCollector):
             Dict[str, TensorType]:
         policy = self.policy_map[policy_id]
         view_reqs = policy.model.inference_view_requirements
-        return self.rollout_sample_collectors[
+        return self.policy_sample_collectors[
             policy_id].get_inference_input_dict(view_reqs)
 
     @override(_SampleCollector)

@@ -161,7 +163,7 @@ class _MultiAgentSampleCollector(_SampleCollector):
         # Loop through each per-policy collector and create a view (for each
         # agent as SampleBatch) from its buffers for post-processing
         all_agent_batches = {}
-        for pid, rc in self.rollout_sample_collectors.items():
+        for pid, rc in self.policy_sample_collectors.items():
             policy = self.policy_map[pid]
             view_reqs = policy.training_view_requirements
             agent_batches = rc.get_postprocessing_sample_batches(

@@ -211,7 +213,7 @@ class _MultiAgentSampleCollector(_SampleCollector):
 
     @override(_SampleCollector)
     def check_missing_dones(self, episode_id: EpisodeID) -> None:
-        for pid, rc in self.rollout_sample_collectors.items():
+        for pid, rc in self.policy_sample_collectors.items():
             for agent_key in rc.agent_key_to_slot.keys():
                 # Only check for given episode and only for last chunk
                 # (all previous chunks for that agent in the episode are

@@ -235,7 +237,7 @@ class _MultiAgentSampleCollector(_SampleCollector):
     def get_multi_agent_batch_and_reset(self):
         self.postprocess_trajectories_so_far()
         policy_batches = {}
-        for pid, rc in self.rollout_sample_collectors.items():
+        for pid, rc in self.policy_sample_collectors.items():
             policy = self.policy_map[pid]
             view_reqs = policy.training_view_requirements
             policy_batches[pid] = rc.get_train_sample_batch_and_reset(
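The renamed `policy_sample_collectors` dict keeps exactly one per-policy collector, and `total_env_steps()` just sums their counters. A toy sketch of that fan-out, with a stand-in collector class instead of `_PerPolicySampleCollector`:

```python
# Stand-in for a per-policy collector (only the counter matters here).
class _FakePerPolicyCollector:
    def __init__(self):
        self.timesteps_since_last_reset = 0


policy_sample_collectors = {
    pid: _FakePerPolicyCollector() for pid in ("pol0", "pol1")
}
policy_sample_collectors["pol0"].timesteps_since_last_reset = 5

# Mirrors total_env_steps(): sum env steps across all per-policy collectors.
total_env_steps = sum(c.timesteps_since_last_reset
                      for c in policy_sample_collectors.values())
assert total_env_steps == 5
```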
@@ -103,7 +103,13 @@ class _PerPolicySampleCollector:
         self._next_agent_slot()
 
         if SampleBatch.OBS not in self.buffers:
-            self._build_buffers(single_row={SampleBatch.OBS: init_obs})
+            self._build_buffers(
+                single_row={
+                    SampleBatch.OBS: init_obs,
+                    SampleBatch.EPS_ID: episode_id,
+                    SampleBatch.AGENT_INDEX: agent_id,
+                    "env_id": env_id,
+                })
         if self.time_major:
             self.buffers[SampleBatch.OBS][self.shift_before-1, agent_slot] = \
                 init_obs

@@ -262,12 +268,12 @@ class _PerPolicySampleCollector:
             batch = sample_batch_data[agent_key]
 
             for view_col, view_req in view_reqs.items():
+                data_col = view_req.data_col or view_col
                 # Skip columns that will only get added through postprocessing
                 # (these may not even exist yet).
-                if view_req.created_during_postprocessing:
+                if data_col not in self.buffers:
                     continue
 
-                data_col = view_req.data_col or view_col
                 shift = view_req.shift
                 if data_col == SampleBatch.OBS:
                     shift -= 1

@@ -289,20 +295,22 @@ class _PerPolicySampleCollector:
             SampleBatch: Returns the accumulated sample batch for this
                 policy.
         """
-        seq_lens = [
+        seq_lens_w_0s = [
             self.agent_key_to_timestep[k] - self.shift_before
             for k in self.slot_to_agent_key if k is not None
         ]
-        first_zero_len = len(seq_lens)
-        if seq_lens[-1] == 0:
-            first_zero_len = seq_lens.index(0)
+        # We have an agent-axis buffer "rollover" (new SampleBatch will be
+        # built from last n agent records plus first m agent records in
+        # buffer).
+        if self.agent_slot_cursor < self.sample_batch_offset:
+            rollover = -(self.num_agents - self.sample_batch_offset)
+            seq_lens_w_0s = seq_lens_w_0s[rollover:] + seq_lens_w_0s[:rollover]
+        first_zero_len = len(seq_lens_w_0s)
+        if seq_lens_w_0s[-1] == 0:
+            first_zero_len = seq_lens_w_0s.index(0)
             # Assert that all zeros lie at the end of the seq_lens array.
-            try:
-                assert all(seq_lens[i] == 0
-                           for i in range(first_zero_len, len(seq_lens)))
-            except AssertionError as e:
-                print()
-                raise e
+            assert all(seq_lens_w_0s[i] == 0
+                       for i in range(first_zero_len, len(seq_lens_w_0s)))
 
         t_start = self.shift_before
         t_end = t_start + self.num_timesteps

@@ -311,8 +319,8 @@ class _PerPolicySampleCollector:
         # actually already has at least 1 timestep of data (thus it excludes
         # just-rolled over chunks (which only have the initial obs in them)).
         valid_agent_cursor = \
-            (self.agent_slot_cursor - (len(seq_lens) - first_zero_len)) % \
-            self.num_agents
+            (self.agent_slot_cursor -
+             (len(seq_lens_w_0s) - first_zero_len)) % self.num_agents
 
         # Construct the view dict.
         view = {}

@@ -320,12 +328,13 @@ class _PerPolicySampleCollector:
             data_col = view_req.data_col or view_col
             assert data_col in self.buffers
             # For OBS, indices must be shifted by -1.
-            extra_shift = 0 if data_col != SampleBatch.OBS else -1
+            shift = view_req.shift
+            shift += 0 if data_col != SampleBatch.OBS else -1
             # If agent_slot has been rolled-over to beginning, we have to copy
             # here.
             if valid_agent_cursor < self.sample_batch_offset:
-                time_slice = self.buffers[data_col][t_start + extra_shift:
-                                                    t_end + extra_shift]
+                time_slice = self.buffers[data_col][t_start + shift:t_end +
+                                                    shift]
                 one_ = time_slice[:, self.sample_batch_offset:]
                 two_ = time_slice[:, :valid_agent_cursor]
                 if torch and isinstance(time_slice, torch.Tensor):

@@ -335,17 +344,15 @@ class _PerPolicySampleCollector:
             else:
                 view[view_col] = \
                     self.buffers[data_col][
-                        t_start + extra_shift:t_end + extra_shift,
+                        t_start + shift:t_end + shift,
                         self.sample_batch_offset:valid_agent_cursor]
 
         # Copy all still ongoing trajectories to new agent slots
         # (including the ones that just started (are seq_len=0)).
         new_chunk_args = []
-        for i, seq_len in enumerate(seq_lens):
+        for i, seq_len in enumerate(seq_lens_w_0s):
             if seq_len < self.num_timesteps:
-                agent_slot = self.sample_batch_offset + i
-                if agent_slot >= self.num_agents:
-                    agent_slot = agent_slot % self.num_agents
+                agent_slot = (self.sample_batch_offset + i) % self.num_agents
                 if not self.buffers[SampleBatch.
                                     DONES][seq_len - 1 +
                                            self.shift_before][agent_slot]:

@@ -354,9 +361,9 @@ class _PerPolicySampleCollector:
                     (agent_slot, agent_key,
                      self.agent_key_to_timestep[agent_key]))
         # Cut out all 0 seq-lens.
-        seq_lens = seq_lens[:first_zero_len]
+        seq_lens = seq_lens_w_0s[:first_zero_len]
         batch = SampleBatch(
-            view, _seq_lens=np.array(seq_lens), _time_major=True)
+            view, _seq_lens=np.array(seq_lens), _time_major=self.time_major)
 
         # Reset everything for new data.
         self.postprocessed_agents = [False] * self.num_agents

@@ -376,9 +383,14 @@ class _PerPolicySampleCollector:
     def _build_buffers(self, single_row: Dict[str, TensorType]) -> None:
        """Builds the internal data buffers based on a single given row.
 
+        This may be called several times in the lifetime of this instance
+        to add new columns to the buffer. Columns in `single_row` that already
+        exist in the buffer will be ignored.
+
         Args:
             single_row (Dict[str, TensorType]): A single datarow with one or
-                more columns (str as key, np.ndarray|tensor as data).
+                more columns (str as key, np.ndarray|tensor as data) to be used
+                as template to build the pre-allocated buffer.
        """
         time_size = self.num_timesteps + self.shift_before + self.shift_after
         for col, data in single_row.items():
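`_build_buffers` pre-allocates one time-major array per column from a single template row, and per the new docstring it may be called again later to add columns such as `EPS_ID` or `AGENT_INDEX`. A simplified sketch of that layout (names, shapes, and defaults are illustrative, not the collector's real internals):

```python
import numpy as np


def build_buffers(single_row, num_timesteps, num_agents,
                  shift_before=1, shift_after=0):
    """One pre-allocated array per column, shaped [time, num_agents, ...]."""
    time_size = num_timesteps + shift_before + shift_after
    buffers = {}
    for col, data in single_row.items():
        template = np.asarray(data)
        buffers[col] = np.zeros((time_size, num_agents) + template.shape,
                                dtype=template.dtype)
    return buffers


buffers = build_buffers({"obs": np.zeros(4, dtype=np.float32)},
                        num_timesteps=8, num_agents=3)
# The initial observation lands at time slot `shift_before - 1` for agent 0.
buffers["obs"][0, 0] = np.ones(4, dtype=np.float32)
```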
@@ -515,7 +515,7 @@ class RolloutWorker(ParallelIteratorWorker):
                 rollout_fragment_length=rollout_fragment_length,
                 callbacks=self.callbacks,
                 horizon=episode_horizon,
-                pack_multiple_episodes_in_batch=pack,
+                multiple_episodes_in_batch=pack,
                 tf_sess=self.tf_sess,
                 clip_actions=clip_actions,
                 blackhole_outputs="simulation" in input_evaluation,

@@ -538,7 +538,7 @@ class RolloutWorker(ParallelIteratorWorker):
                 rollout_fragment_length=rollout_fragment_length,
                 callbacks=self.callbacks,
                 horizon=episode_horizon,
-                pack_multiple_episodes_in_batch=pack,
+                multiple_episodes_in_batch=pack,
                 tf_sess=self.tf_sess,
                 clip_actions=clip_actions,
                 soft_horizon=soft_horizon,
@@ -5,14 +5,17 @@ import numpy as np
 import queue
 import threading
 import time
-from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, \
+from typing import Any, Callable, Dict, List, Iterable, Optional, Set, Tuple,\
     TYPE_CHECKING, Union
 
 from ray.util.debug import log_once
 from ray.rllib.evaluation.episode import MultiAgentEpisode
+from ray.rllib.evaluation.multi_agent_sample_collector import \
+    _MultiAgentSampleCollector
 from ray.rllib.evaluation.rollout_metrics import RolloutMetrics
 from ray.rllib.evaluation.sample_batch_builder import \
     MultiAgentSampleBatchBuilder
+from ray.rllib.evaluation.sample_collector import _SampleCollector
 from ray.rllib.policy.policy import clip_action, Policy
 from ray.rllib.policy.tf_policy import TFPolicy
 from ray.rllib.models.preprocessors import Preprocessor

@@ -22,6 +25,7 @@ from ray.rllib.env.atari_wrappers import get_wrapper_by_cls, MonitorEnv
 from ray.rllib.offline import InputReader
 from ray.rllib.utils.annotations import override, DeveloperAPI
 from ray.rllib.utils.debug import summarize
+from ray.rllib.utils.numpy import convert_to_numpy
 from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray, \
     unbatch
 from ray.rllib.utils.tf_run_builder import TFRunBuilder
@@ -51,14 +55,23 @@ class _PerfStats:
     def __init__(self):
         self.iters = 0
         self.env_wait_time = 0.0
-        self.processing_time = 0.0
+        self.raw_obs_processing_time = 0.0
         self.inference_time = 0.0
+        self.action_processing_time = 0.0
 
     def get(self):
+        # Mean multiplicator (1000 = sec -> ms).
+        factor = 1000 / self.iters
         return {
-            "mean_env_wait_ms": self.env_wait_time * 1000 / self.iters,
-            "mean_processing_ms": self.processing_time * 1000 / self.iters,
-            "mean_inference_ms": self.inference_time * 1000 / self.iters
+            # Waiting for environment (during poll).
+            "mean_env_wait_ms": self.env_wait_time * factor,
+            # Raw observation preprocessing.
+            "mean_raw_obs_processing_ms": self.raw_obs_processing_time *
+            factor,
+            # Computing actions through policy.
+            "mean_inference_ms": self.inference_time * factor,
+            # Processing actions (to be sent to env, e.g. clipping).
+            "mean_action_processing_ms": self.action_processing_time * factor,
         }
 
 
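The reworked `_PerfStats` splits the old `processing_time` into raw-observation and action-processing phases and reports per-iteration means in milliseconds via `factor = 1000 / iters`. A small usage sketch with a stand-in class (the real class lives in the sampler module shown above):

```python
import time


class PerfStatsSketch:
    """Accumulate seconds per phase, report per-iteration means in ms."""

    def __init__(self):
        self.iters = 0
        self.raw_obs_processing_time = 0.0
        self.inference_time = 0.0
        self.action_processing_time = 0.0

    def get(self):
        factor = 1000 / self.iters  # seconds -> mean milliseconds
        return {
            "mean_raw_obs_processing_ms":
                self.raw_obs_processing_time * factor,
            "mean_inference_ms": self.inference_time * factor,
            "mean_action_processing_ms":
                self.action_processing_time * factor,
        }


stats = PerfStatsSketch()
stats.iters += 1
t0 = time.time()
time.sleep(0.01)  # stand-in for the inference phase
stats.inference_time += time.time() - t0
print(stats.get())
```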
@@ -108,7 +121,7 @@ class SyncSampler(SamplerInput):
                  rollout_fragment_length: int,
                  callbacks: "DefaultCallbacks",
                  horizon: int = None,
-                 pack_multiple_episodes_in_batch: bool = False,
+                 multiple_episodes_in_batch: bool = False,
                  tf_sess=None,
                  clip_actions: bool = True,
                  soft_horizon: bool = False,

@@ -136,7 +149,7 @@ class SyncSampler(SamplerInput):
             callbacks (Callbacks): The Callbacks object to use when episode
                 events happen during rollout.
             horizon (Optional[int]): Hard-reset the Env
-            pack_multiple_episodes_in_batch (bool): Whether to pack multiple
+            multiple_episodes_in_batch (bool): Whether to pack multiple
                 episodes into each batch. This guarantees batches will be
                 exactly `rollout_fragment_length` in size.
             tf_sess (Optional[tf.Session]): A tf.Session object to use (only if

@@ -165,14 +178,20 @@ class SyncSampler(SamplerInput):
         self.obs_filters = obs_filters
         self.extra_batches = queue.Queue()
         self.perf_stats = _PerfStats()
+        if _use_trajectory_view_api:
+            self.sample_collector = _MultiAgentSampleCollector(
+                policies, callbacks)
+        else:
+            self.sample_collector = None
+
         # Create the rollout generator to use for calls to `get_data()`.
         self.rollout_provider = _env_runner(
             worker, self.base_env, self.extra_batches.put, self.policies,
             self.policy_mapping_fn, self.rollout_fragment_length, self.horizon,
             self.preprocessors, self.obs_filters, clip_rewards, clip_actions,
-            pack_multiple_episodes_in_batch, callbacks, tf_sess,
-            self.perf_stats, soft_horizon, no_done_at_end, observation_fn,
-            _use_trajectory_view_api)
+            multiple_episodes_in_batch, callbacks, tf_sess, self.perf_stats,
+            soft_horizon, no_done_at_end, observation_fn,
+            _use_trajectory_view_api, self.sample_collector)
         self.metrics_queue = queue.Queue()
 
     @override(SamplerInput)
@@ -226,7 +245,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
                  rollout_fragment_length: int,
                  callbacks: "DefaultCallbacks",
                  horizon: int = None,
-                 pack_multiple_episodes_in_batch: bool = False,
+                 multiple_episodes_in_batch: bool = False,
                  tf_sess=None,
                  clip_actions: bool = True,
                  blackhole_outputs: bool = False,

@@ -255,7 +274,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
             callbacks (Callbacks): The Callbacks object to use when episode
                 events happen during rollout.
             horizon (Optional[int]): Hard-reset the Env
-            pack_multiple_episodes_in_batch (bool): Whether to pack multiple
+            multiple_episodes_in_batch (bool): Whether to pack multiple
                 episodes into each batch. This guarantees batches will be
                 exactly `rollout_fragment_length` in size.
             tf_sess (Optional[tf.Session]): A tf.Session object to use (only if

@@ -293,7 +312,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
         self.obs_filters = obs_filters
         self.clip_rewards = clip_rewards
         self.daemon = True
-        self.pack_multiple_episodes_in_batch = pack_multiple_episodes_in_batch
+        self.multiple_episodes_in_batch = multiple_episodes_in_batch
         self.tf_sess = tf_sess
         self.callbacks = callbacks
         self.clip_actions = clip_actions

@@ -304,6 +323,11 @@ class AsyncSampler(threading.Thread, SamplerInput):
         self.shutdown = False
         self.observation_fn = observation_fn
         self._use_trajectory_view_api = _use_trajectory_view_api
+        if _use_trajectory_view_api:
+            self.sample_collector = _MultiAgentSampleCollector(
+                policies, callbacks)
+        else:
+            self.sample_collector = None
 
     @override(threading.Thread)
     def run(self):

@@ -325,8 +349,8 @@ class AsyncSampler(threading.Thread, SamplerInput):
             self.worker, self.base_env, extra_batches_putter, self.policies,
             self.policy_mapping_fn, self.rollout_fragment_length, self.horizon,
             self.preprocessors, self.obs_filters, self.clip_rewards,
-            self.clip_actions, self.pack_multiple_episodes_in_batch,
-            self.callbacks, self.tf_sess, self.perf_stats, self.soft_horizon,
+            self.clip_actions, self.multiple_episodes_in_batch, self.callbacks,
+            self.tf_sess, self.perf_stats, self.soft_horizon,
             self.no_done_at_end, self.observation_fn,
             self._use_trajectory_view_api)
         while not self.shutdown:
@@ -385,14 +409,16 @@ def _env_runner(
         obs_filters: Dict[PolicyID, Filter],
         clip_rewards: bool,
         clip_actions: bool,
-        pack_multiple_episodes_in_batch: bool,
+        multiple_episodes_in_batch: bool,
         callbacks: "DefaultCallbacks",
         tf_sess: Optional["tf.Session"],
         perf_stats: _PerfStats,
         soft_horizon: bool,
         no_done_at_end: bool,
         observation_fn: "ObservationFunction",
-        _use_trajectory_view_api: bool = False) -> Iterable[SampleBatchType]:
+        _use_trajectory_view_api: bool = False,
+        _sample_collector: Optional[_SampleCollector] = None,
+) -> Iterable[SampleBatchType]:
    """This implements the common experience collection logic.
 
    Args:

@@ -413,7 +439,7 @@ def _env_runner(
         obs_filters (dict): Map of policy id to filter used to process
             observations for the policy.
         clip_rewards (bool): Whether to clip rewards before postprocessing.
-        pack_multiple_episodes_in_batch (bool): Whether to pack multiple
+        multiple_episodes_in_batch (bool): Whether to pack multiple
             episodes into each batch. This guarantees batches will be exactly
             `rollout_fragment_length` in size.
         clip_actions (bool): Whether to clip actions to the space range.

@@ -430,6 +456,8 @@ def _env_runner(
         _use_trajectory_view_api (bool): Whether to use the (experimental)
             `_use_trajectory_view_api` to make generic trajectory views
             available to Models. Default: False.
+        _sample_collector (Optional[_SampleCollector]): An optional
+            _SampleCollector object to use
 
     Yields:
         rollout (SampleBatch): Object containing state, action, reward,

@@ -471,6 +499,8 @@ def _env_runner(
     def get_batch_builder():
         if batch_builder_pool:
             return batch_builder_pool.pop()
+        elif _use_trajectory_view_api:
+            return None
         else:
             return MultiAgentSampleBatchBuilder(policies, clip_rewards,
                                                 callbacks)

@@ -495,6 +525,7 @@ def _env_runner(
         return episode
 
     active_episodes: Dict[str, MultiAgentEpisode] = defaultdict(new_episode)
+    eval_results = None
 
     while True:
         perf_stats.iters += 1
@@ -514,39 +545,73 @@ def _env_runner(
         t1 = time.time()
         # type: Set[EnvID], Dict[PolicyID, List[PolicyEvalData]],
         # List[Union[RolloutMetrics, SampleBatchType]]
-        active_envs, to_eval, outputs = _process_observations(
-            worker=worker,
-            base_env=base_env,
-            policies=policies,
-            batch_builder_pool=batch_builder_pool,
-            active_episodes=active_episodes,
-            unfiltered_obs=unfiltered_obs,
-            rewards=rewards,
-            dones=dones,
-            infos=infos,
-            horizon=horizon,
-            preprocessors=preprocessors,
-            obs_filters=obs_filters,
-            rollout_fragment_length=rollout_fragment_length,
-            pack_multiple_episodes_in_batch=pack_multiple_episodes_in_batch,
-            callbacks=callbacks,
-            soft_horizon=soft_horizon,
-            no_done_at_end=no_done_at_end,
-            observation_fn=observation_fn,
-            _use_trajectory_view_api=_use_trajectory_view_api)
-        perf_stats.processing_time += time.time() - t1
+        if _use_trajectory_view_api:
+            active_envs, to_eval, outputs = \
+                _process_observations_w_trajectory_view_api(
+                    worker=worker,
+                    base_env=base_env,
+                    policies=policies,
+                    active_episodes=active_episodes,
+                    prev_policy_outputs=eval_results,
+                    unfiltered_obs=unfiltered_obs,
+                    rewards=rewards,
+                    dones=dones,
+                    infos=infos,
+                    horizon=horizon,
+                    preprocessors=preprocessors,
+                    obs_filters=obs_filters,
+                    rollout_fragment_length=rollout_fragment_length,
+                    multiple_episodes_in_batch=multiple_episodes_in_batch,
+                    callbacks=callbacks,
+                    soft_horizon=soft_horizon,
+                    no_done_at_end=no_done_at_end,
+                    observation_fn=observation_fn,
+                    perf_stats=perf_stats,
+                    _sample_collector=_sample_collector,
+                )
+        else:
+            active_envs, to_eval, outputs = _process_observations(
+                worker=worker,
+                base_env=base_env,
+                policies=policies,
+                batch_builder_pool=batch_builder_pool,
+                active_episodes=active_episodes,
+                unfiltered_obs=unfiltered_obs,
+                rewards=rewards,
+                dones=dones,
+                infos=infos,
+                horizon=horizon,
+                preprocessors=preprocessors,
+                obs_filters=obs_filters,
+                rollout_fragment_length=rollout_fragment_length,
+                multiple_episodes_in_batch=multiple_episodes_in_batch,
+                callbacks=callbacks,
+                soft_horizon=soft_horizon,
+                no_done_at_end=no_done_at_end,
+                observation_fn=observation_fn,
+                perf_stats=perf_stats,
+            )
+        perf_stats.raw_obs_processing_time += time.time() - t1
         for o in outputs:
             yield o
 
         # Do batched policy eval (across vectorized envs).
         t2 = time.time()
         # type: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]
-        eval_results = _do_policy_eval(
-            to_eval=to_eval,
-            policies=policies,
-            active_episodes=active_episodes,
-            tf_sess=tf_sess,
-            _use_trajectory_view_api=_use_trajectory_view_api)
+        if _use_trajectory_view_api:
+            eval_results = _do_policy_eval_w_trajectory_view_api(
+                to_eval=to_eval,
+                policies=policies,
+                _sample_collector=_sample_collector,
+                tf_sess=tf_sess,
+            )
+        else:
+            eval_results = _do_policy_eval(
+                to_eval=to_eval,
+                policies=policies,
+                active_episodes=active_episodes,
+                tf_sess=tf_sess,
+            )
         perf_stats.inference_time += time.time() - t2
 
         # Process results and update episode state.

@@ -560,8 +625,10 @@ def _env_runner(
             off_policy_actions=off_policy_actions,
             policies=policies,
             clip_actions=clip_actions,
-            _use_trajectory_view_api=_use_trajectory_view_api)
-        perf_stats.processing_time += time.time() - t3
+            _use_trajectory_view_api=_use_trajectory_view_api,
+            _sample_collector=_sample_collector,
+        )
+        perf_stats.action_processing_time += time.time() - t3
 
         # Return computed actions to ready envs. We also send to envs that have
         # taken off-policy actions; those envs are free to ignore the action.
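The branch above also changes the data flow: under the trajectory view API, the policy outputs computed in one loop iteration are fed back into observation processing on the next iteration as `prev_policy_outputs`, instead of being written into the episode right away. A stripped-down sketch of that loop shape, with stand-in functions in place of the RLlib internals shown in the hunk:

```python
def env_runner_sketch(poll_env, process_obs, policy_eval, send_actions):
    """Simplified generator mirroring the _env_runner control flow."""
    eval_results = None  # no policy outputs exist before the first step
    while True:
        obs = poll_env()
        # Last iteration's policy outputs are recorded alongside this
        # iteration's observations/rewards/dones.
        to_eval = process_obs(obs, prev_policy_outputs=eval_results)
        eval_results = policy_eval(to_eval)
        send_actions(eval_results)
        yield eval_results  # placeholder for the yielded sample batches
```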
@@ -571,6 +638,7 @@ def _env_runner(
 
 
 def _process_observations(
+        *,
         worker: "RolloutWorker",
         base_env: BaseEnv,
         policies: Dict[PolicyID, Policy],

@@ -584,12 +652,12 @@ def _process_observations(
         preprocessors: Dict[PolicyID, Preprocessor],
         obs_filters: Dict[PolicyID, Filter],
         rollout_fragment_length: int,
-        pack_multiple_episodes_in_batch: bool,
+        multiple_episodes_in_batch: bool,
         callbacks: "DefaultCallbacks",
         soft_horizon: bool,
         no_done_at_end: bool,
         observation_fn: "ObservationFunction",
-        _use_trajectory_view_api: bool = False
+        perf_stats: _PerfStats,
 ) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[
         RolloutMetrics, SampleBatchType]]]:
    """Record new data from the environment and prepare for policy evaluation.

@@ -602,8 +670,11 @@ def _process_observations(
             SampleBatchBuilder object for recycling.
         active_episodes (Dict[str, MultiAgentEpisode]): Mapping from
             episode ID to currently ongoing MultiAgentEpisode object.
-        unfiltered_obs (dict): Doubly keyed dict of env-ids -> agent ids ->
-            unfiltered observation tensor, returned by a `BaseEnv.poll()` call.
+        prev_policy_outputs (Dict[str,List]): The prev policy output dict
+            (by policy-id -> List[action, state outs, extra fetches]).
+        unfiltered_obs (dict): Doubly keyed dict of env-ids -> agent ids
+            -> unfiltered observation tensor, returned by a `BaseEnv.poll()`
+            call.
         rewards (dict): Doubly keyed dict of env-ids -> agent ids ->
             rewards tensor, returned by a `BaseEnv.poll()` call.
         dones (dict): Doubly keyed dict of env-ids -> agent ids ->

@@ -618,7 +689,7 @@ def _process_observations(
         rollout_fragment_length (int): Number of episode steps before
             `SampleBatch` is yielded. Set to infinity to yield complete
             episodes.
-        pack_multiple_episodes_in_batch (bool): Whether to pack multiple
+        multiple_episodes_in_batch (bool): Whether to pack multiple
             episodes into each batch. This guarantees batches will be exactly
             `rollout_fragment_length` in size.
         callbacks (DefaultCallbacks): User callbacks to run on episode events.

@@ -628,9 +699,6 @@ def _process_observations(
             and instead record done=False.
         observation_fn (ObservationFunction): Optional multi-agent
             observation func to use for preprocessing observations.
-        _use_trajectory_view_api (bool): Whether to use the (experimental)
-            `_use_trajectory_view_api` to make generic trajectory views
-            available to Models. Default: False.
 
     Returns:
         Tuple:
@@ -652,20 +720,21 @@ def _process_observations(
     for env_id, agent_obs in unfiltered_obs.items():
         is_new_episode: bool = env_id not in active_episodes
         episode: MultiAgentEpisode = active_episodes[env_id]
+        batch_builder = episode.batch_builder
         if not is_new_episode:
             episode.length += 1
-            episode.batch_builder.count += 1
+            batch_builder.count += 1
             episode._add_agent_rewards(rewards[env_id])
 
-        if (episode.batch_builder.total() > large_batch_threshold
+        if (batch_builder.total() > large_batch_threshold
                 and log_once("large_batch_warning")):
             logger.warning(
                 "More than {} observations for {} env steps ".format(
-                    episode.batch_builder.total(),
-                    episode.batch_builder.count) + "are buffered in "
-                "the sampler. If this is more than you expected, check "
-                "that you set a horizon on your environment correctly and "
-                "that it terminates at some point. "
+                    batch_builder.total(), batch_builder.count) +
+                "are buffered in "
+                "the sampler. If this is more than you expected, check that "
+                "you set a horizon on your environment correctly and that"
+                " it terminates at some point. "
                 "Note: In multi-agent environments, `rollout_fragment_length` "
                 "sets the batch size based on environment steps, not the "
                 "steps of "

@@ -725,12 +794,12 @@ def _process_observations(
 
         agent_done = bool(all_agents_done or dones[env_id].get(agent_id))
         if not agent_done:
-            to_eval[policy_id].append(
-                PolicyEvalData(env_id, agent_id, filtered_obs,
-                               infos[env_id].get(agent_id, {}),
-                               episode.rnn_state_for(agent_id),
-                               episode.last_action_for(agent_id),
-                               rewards[env_id][agent_id] or 0.0))
+            item = PolicyEvalData(env_id, agent_id, filtered_obs,
+                                  infos[env_id].get(agent_id, {}),
+                                  episode.rnn_state_for(agent_id),
+                                  episode.last_action_for(agent_id),
+                                  rewards[env_id][agent_id] or 0.0)
+            to_eval[policy_id].append(item)
 
         last_observation: EnvObsType = episode.last_observation_for(
             agent_id)

@@ -741,7 +810,7 @@ def _process_observations(
         # Record transition info if applicable.
         if (last_observation is not None and infos[env_id].get(
                 agent_id, {}).get("training_enabled", True)):
-            episode.batch_builder.add_values(
+            batch_builder.add_values(
                 agent_id,
                 policy_id,
                 t=episode.length - 1,
@@ -767,26 +836,26 @@ def _process_observations(
         # - all-agents-done and not packing multiple episodes into one
         # (batch_mode="complete_episodes")
         # - or if we've exceeded the rollout_fragment_length.
-        if episode.batch_builder.has_pending_agent_data():
+        if batch_builder.has_pending_agent_data():
             # Sanity check, whether all agents have done=True, if done[__all__]
             # is True.
             if dones[env_id]["__all__"] and not no_done_at_end:
-                episode.batch_builder.check_missing_dones()
+                batch_builder.check_missing_dones()
 
-            # Reached end of episode and we are not allowed to pack the
-            # next episode into the same SampleBatch -> Build the SampleBatch
-            # and add it to "outputs".
-            if (all_agents_done and not pack_multiple_episodes_in_batch) or \
-                    episode.batch_builder.count >= rollout_fragment_length:
-                outputs.append(episode.batch_builder.build_and_reset(episode))
-            # Make sure postprocessor stays within one episode.
-            elif all_agents_done:
-                episode.batch_builder.postprocess_batch_so_far(episode)
+            # Reached end of episode and we are not allowed to pack the
+            # next episode into the same SampleBatch -> Build the SampleBatch
+            # and add it to "outputs".
+            if (all_agents_done and not multiple_episodes_in_batch) or \
+                    batch_builder.count >= rollout_fragment_length:
+                outputs.append(batch_builder.build_and_reset(episode))
+            # Make sure postprocessor stays within one episode.
+            elif all_agents_done:
+                batch_builder.postprocess_batch_so_far(episode)
 
         # Episode is done.
         if all_agents_done:
-            # Handle episode termination.
-            batch_builder_pool.append(episode.batch_builder)
+            # We can pass the BatchBuilder to recycling.
+            batch_builder_pool.append(batch_builder)
             # Call each policy's Exploration.on_episode_end method.
             for p in policies.values():
                 if getattr(p, "exploration", None) is not None:
@@ -834,14 +903,262 @@ def _process_observations(
             filtered_obs: EnvObsType = _get_or_raise(
                 obs_filters, policy_id)(prep_obs)
             episode._set_last_observation(agent_id, filtered_obs)
-            to_eval[policy_id].append(
-                PolicyEvalData(
-                    env_id, agent_id, filtered_obs,
-                    episode.last_info_for(agent_id) or {},
-                    episode.rnn_state_for(agent_id),
-                    np.zeros_like(
-                        flatten_to_single_ndarray(
-                            policy.action_space.sample())), 0.0))
-
+            item = PolicyEvalData(
+                env_id, agent_id, filtered_obs,
+                episode.last_info_for(agent_id) or {},
+                episode.rnn_state_for(agent_id),
+                np.zeros_like(
+                    flatten_to_single_ndarray(
+                        policy.action_space.sample())), 0.0)
+            to_eval[policy_id].append(item)
+
+    return active_envs, to_eval, outputs
+
+
+def _process_observations_w_trajectory_view_api(
+        *,
+        worker: "RolloutWorker",
+        base_env: BaseEnv,
+        policies: Dict[PolicyID, Policy],
+        active_episodes: Dict[str, MultiAgentEpisode],
+        prev_policy_outputs: Dict[PolicyID, Tuple[TensorStructType, StateBatch,
+                                                  dict]],
+        unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
+        rewards: Dict[EnvID, Dict[AgentID, float]],
+        dones: Dict[EnvID, Dict[AgentID, bool]],
+        infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]],
+        horizon: int,
+        preprocessors: Dict[PolicyID, Preprocessor],
+        obs_filters: Dict[PolicyID, Filter],
+        rollout_fragment_length: int,
+        multiple_episodes_in_batch: bool,
+        callbacks: "DefaultCallbacks",
+        soft_horizon: bool,
+        no_done_at_end: bool,
+        observation_fn: "ObservationFunction",
+        perf_stats: _PerfStats,
+        _sample_collector: _SampleCollector,
+) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[
+        RolloutMetrics, SampleBatchType]]]:
+    """Trajectory View API version of `_process_observations()`.
+    TODO: (sven) Move docstring here once original function is deprecated.
+    """
+
+    # Output objects.
+    active_envs: Set[EnvID] = set()
+    to_eval: Set[PolicyID] = set()
+    outputs: List[Union[RolloutMetrics, SampleBatchType]] = []
+
+    large_batch_threshold: int = max(1000, rollout_fragment_length * 10) if \
+        rollout_fragment_length != float("inf") else 5000
+
+    # For each environment.
+    # type: EnvID, Dict[AgentID, EnvObsType]
+    for env_id, agent_obs in unfiltered_obs.items():
+        is_new_episode: bool = env_id not in active_episodes
+        episode: MultiAgentEpisode = active_episodes[env_id]
+
+        if not is_new_episode:
+            episode.length += 1
+            _sample_collector.count += 1
+            episode._add_agent_rewards(rewards[env_id])
+
+        if (_sample_collector.total_env_steps() > large_batch_threshold
+                and log_once("large_batch_warning")):
+            logger.warning(
+                "More than {} observations for {} env steps ".format(
+                    _sample_collector.total_env_steps(),
+                    _sample_collector.count) + "are buffered in "
+                "the sampler. If this is more than you expected, check that "
+                "you set a horizon on your environment correctly and that"
+                " it terminates at some point. "
+                "Note: In multi-agent environments, `rollout_fragment_length` "
+                "sets the batch size based on environment steps, not the "
+                "steps of "
+                "individual agents, which can result in unexpectedly large "
+                "batches. Also, you may be in evaluation waiting for your Env "
+                "to terminate (batch_mode=`complete_episodes`). Make sure it "
+                "does at some point.")
+
+        # Check episode termination conditions.
+        if dones[env_id]["__all__"] or episode.length >= horizon:
+            hit_horizon = (episode.length >= horizon
+                           and not dones[env_id]["__all__"])
+            all_agents_done = True
+            atari_metrics: List[RolloutMetrics] = _fetch_atari_metrics(
+                base_env)
+            if atari_metrics is not None:
+                for m in atari_metrics:
+                    outputs.append(
+                        m._replace(custom_metrics=episode.custom_metrics))
+            else:
+                outputs.append(
+                    RolloutMetrics(episode.length, episode.total_reward,
+                                   dict(episode.agent_rewards),
+                                   episode.custom_metrics, {},
+                                   episode.hist_data))
+        else:
+            hit_horizon = False
+            all_agents_done = False
+            active_envs.add(env_id)
+
+        # Custom observation function is applied before preprocessing.
+        if observation_fn:
+            agent_obs: Dict[AgentID, EnvObsType] = observation_fn(
+                agent_obs=agent_obs,
+                worker=worker,
+                base_env=base_env,
+                policies=policies,
+                episode=episode)
+            if not isinstance(agent_obs, dict):
+                raise ValueError(
+                    "observe() must return a dict of agent observations")
+
+        # For each agent in the environment.
+        # type: AgentID, EnvObsType
+        for agent_id, raw_obs in agent_obs.items():
+            assert agent_id != "__all__"
+            policy_id: PolicyID = episode.policy_for(agent_id)
+            prep_obs: EnvObsType = _get_or_raise(preprocessors,
+                                                 policy_id).transform(raw_obs)
+            if log_once("prep_obs"):
+                logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))
+
+            filtered_obs: EnvObsType = _get_or_raise(obs_filters,
+                                                     policy_id)(prep_obs)
+            if log_once("filtered_obs"):
+                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))
+
+            agent_done = bool(all_agents_done or dones[env_id].get(agent_id))
+
+            last_observation: EnvObsType = episode.last_observation_for(
+                agent_id)
+            episode._set_last_observation(agent_id, filtered_obs)
+            episode._set_last_raw_obs(agent_id, raw_obs)
+            episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))
+
+            # Record transition info if applicable.
+            if last_observation is None:
+                _sample_collector.add_init_obs(episode.episode_id, agent_id,
+                                               env_id, policy_id, filtered_obs)
+            else:
+                rc = _sample_collector.policy_sample_collectors[policy_id]
+                eval_idx = rc.agent_key_to_forward_pass_index[(
+                    agent_id, episode.episode_id)]
+                values_dict = {
+                    "t": episode.length - 1,
+                    "eps_id": episode.episode_id,
+                    "agent_index": episode._agent_index(agent_id),
+                    # Action (slot 0) taken at timestep t.
+                    "actions": prev_policy_outputs[policy_id][0][eval_idx],
+                    # Reward received after taking a at timestep t.
+                    "rewards": rewards[env_id][agent_id],
+                    # After taking a, did we reach terminal?
+                    "dones": (False if (no_done_at_end
+                                        or (hit_horizon and soft_horizon)) else
+                              agent_done),
+                    # Next observation.
+                    "new_obs": filtered_obs,
+                }
+                # TODO: (sven) add env infos to buffers as well.
+                for k, v in prev_policy_outputs[policy_id][2].items():
+                    values_dict[k] = v[eval_idx]
+                for i, v in enumerate(prev_policy_outputs[policy_id][1]):
+                    values_dict["state_out_{}".format(i)] = v[eval_idx]
+                _sample_collector.add_action_reward_next_obs(
+                    episode.episode_id, agent_id, env_id, policy_id,
+                    agent_done, values_dict)
+
+            if not agent_done:
+                to_eval.add(policy_id)
+
+        # Invoke the step callback after the step is logged to the episode
+        callbacks.on_episode_step(
+            worker=worker, base_env=base_env, episode=episode)
+
+        # Cut the batch if ...
+        # - all-agents-done and not packing multiple episodes into one
+        # (batch_mode="complete_episodes")
+        # - or if we've exceeded the rollout_fragment_length.
+        if _sample_collector.has_non_postprocessed_data():
+            # Sanity check, whether all agents have done=True, if done[__all__]
+            # is True.
+            if dones[env_id]["__all__"] and not no_done_at_end:
+                _sample_collector.check_missing_dones(
+                    episode_id=episode.episode_id)
+
+            # Reached end of episode and we are not allowed to pack the
+            # next episode into the same SampleBatch -> Build the SampleBatch
+            # and add it to "outputs".
+            if (all_agents_done and not multiple_episodes_in_batch) or \
+                    _sample_collector.count >= rollout_fragment_length:
+                # TODO: (sven) Case: rollout_fragment_length reached: Do not
+                #  store any data in `episode` anymore
+                #  (useless for get_view_requirements when t<<-1, e.g.
+                #  attention), but keep last episode data around in
+                #  SampleBatchBuilder
+                #  to be able to still reference into it
+                #  should a model require this.
+                outputs.append(
+                    _sample_collector.get_multi_agent_batch_and_reset())
+            # Make sure postprocessor stays within one episode.
+            elif all_agents_done:
+                _sample_collector.postprocess_trajectories_so_far(episode)
+
+        # Episode is done.
+        if all_agents_done:
+            # Call each policy's Exploration.on_episode_end method.
+            for p in policies.values():
+                if getattr(p, "exploration", None) is not None:
+                    p.exploration.on_episode_end(
+                        policy=p,
+                        environment=base_env,
+                        episode=episode,
+                        tf_sess=getattr(p, "_sess", None))
+            # Call custom on_episode_end callback.
+            callbacks.on_episode_end(
+                worker=worker,
+                base_env=base_env,
+                policies=policies,
+                episode=episode)
+            if hit_horizon and soft_horizon:
+                episode.soft_reset()
+                resetted_obs: Dict[AgentID, EnvObsType] = agent_obs
+            else:
+                del active_episodes[env_id]
+                resetted_obs: Dict[AgentID, EnvObsType] = base_env.try_reset(
+                    env_id)
+            if resetted_obs is None:
+                # Reset not supported, drop this env from the ready list.
+                if horizon != float("inf"):
+                    raise ValueError(
+                        "Setting episode horizon requires reset() support "
+                        "from the environment.")
+            elif resetted_obs != ASYNC_RESET_RETURN:
+                # Creates a new episode if this is not async return.
+                # If reset is async, we will get its result in some future poll
+                episode: MultiAgentEpisode = active_episodes[env_id]
+                if observation_fn:
+                    resetted_obs: Dict[AgentID, EnvObsType] = observation_fn(
+                        agent_obs=resetted_obs,
+                        worker=worker,
+                        base_env=base_env,
+                        policies=policies,
+                        episode=episode)
+                # type: AgentID, EnvObsType
+                for agent_id, raw_obs in resetted_obs.items():
+                    policy_id: PolicyID = episode.policy_for(agent_id)
+                    prep_obs: EnvObsType = _get_or_raise(
+                        preprocessors, policy_id).transform(raw_obs)
+                    filtered_obs: EnvObsType = _get_or_raise(
+                        obs_filters, policy_id)(prep_obs)
+                    episode._set_last_observation(agent_id, filtered_obs)
+
+                    # Add initial obs to buffer.
+                    _sample_collector.add_init_obs(episode.episode_id,
+                                                   agent_id, env_id, policy_id,
+                                                   filtered_obs)
+                    to_eval.add(policy_id)
+
     return active_envs, to_eval, outputs
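Inside the new `_process_observations_w_trajectory_view_api`, each non-initial step is recorded as a flat dict that combines the previous forward pass's outputs (indexed by the agent's forward-pass index) with the env data just polled. A sketch of such a record with dummy values; the keys follow the hunk, everything else is illustrative:

```python
# Stand-ins for prev_policy_outputs[policy_id] = (actions, state_outs, fetches)
prev_actions = [1, 0]
prev_state_outs = [[0.0, 0.0]]            # one state output tensor (batched)
prev_extra_fetches = {"vf_preds": [0.5, 0.1]}
eval_idx = 0                              # this agent's forward-pass index

values_dict = {
    "t": 41,                              # timestep within the episode
    "eps_id": 123456,
    "agent_index": 0,
    "actions": prev_actions[eval_idx],    # action taken at t (from prev pass)
    "rewards": 1.0,                       # reward received after that action
    "dones": False,
    "new_obs": [0.1, 0.2, 0.3, 0.4],      # observation polled this iteration
}
# Extra fetches and RNN state outputs are copied per forward-pass index.
for k, v in prev_extra_fetches.items():
    values_dict[k] = v[eval_idx]
for i, v in enumerate(prev_state_outs):
    values_dict["state_out_{}".format(i)] = v[eval_idx]
```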
@@ -852,7 +1169,6 @@ def _do_policy_eval(
         policies: Dict[PolicyID, Policy],
         active_episodes: Dict[str, MultiAgentEpisode],
         tf_sess=None,
-        _use_trajectory_view_api=False
 ) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
    """Call compute_actions on collected episode/model data to get next action.
 

@@ -866,9 +1182,6 @@ def _do_policy_eval(
             episode ID to currently ongoing MultiAgentEpisode object.
         tf_sess (Optional[tf.Session]): Optional tensorflow session to use for
             batching TF policy evaluations.
-        _use_trajectory_view_api (bool): Whether to use the (experimental)
-            `_use_trajectory_view_api` procedure to collect samples.
-            Default: False.
 
     Returns:
         eval_results: dict of policy to compute_action() outputs.

@@ -888,15 +1201,15 @@ def _do_policy_eval(
 
     # type: PolicyID, PolicyEvalData
     for policy_id, eval_data in to_eval.items():
-        rnn_in: List[List[Any]] = [t.rnn_state for t in eval_data]
         policy: Policy = _get_or_raise(policies, policy_id)
-        # If tf (non eager) AND TFPolicy's compute_action method has not been
-        # overridden -> Use `policy._build_compute_actions()`.
+        # If tf (non eager) AND TFPolicy's compute_action method has not
+        # been overridden -> Use `policy._build_compute_actions()`.
         if builder and (policy.compute_actions.__code__ is
                         TFPolicy.compute_actions.__code__):
 
             obs_batch: List[EnvObsType] = [t.obs for t in eval_data]
-            state_batches: StateBatch = _to_column_format(rnn_in)
+            state_batches: StateBatch = _to_column_format(
+                [t.rnn_state for t in eval_data])
             # TODO(ekl): how can we make info batch available to TF code?
             prev_action_batch = [t.prev_action for t in eval_data]
             prev_reward_batch = [t.prev_reward for t in eval_data]

@@ -909,6 +1222,7 @@ def _do_policy_eval(
                 prev_reward_batch=prev_reward_batch,
                 timestep=policy.global_timestep)
         else:
+            rnn_in = [t.rnn_state for t in eval_data]
             rnn_in_cols: StateBatch = [
                 np.stack([row[i] for row in rnn_in])
                 for i in range(len(rnn_in[0]))
@@ -921,6 +1235,61 @@ def _do_policy_eval(
             info_batch=[t.info for t in eval_data],
             episodes=[active_episodes[t.env_id] for t in eval_data],
             timestep=policy.global_timestep)
+
+    if builder:
+        # type: PolicyID, Tuple[TensorStructType, StateBatch, dict]
+        for pid, v in pending_fetches.items():
+            eval_results[pid] = builder.get(v)
+
+    if log_once("compute_actions_result"):
+        logger.info("Outputs of compute_actions():\n\n{}\n".format(
+            summarize(eval_results)))
+
+    return eval_results
+
+
+def _do_policy_eval_w_trajectory_view_api(
+        *,
+        to_eval: Dict[PolicyID, List[PolicyEvalData]],
+        policies: Dict[PolicyID, Policy],
+        _sample_collector,
+        tf_sess=None,
+) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
+    """Call compute_actions on collected episode/model data to get next action.
+
+    Args:
+        to_eval (Dict[PolicyID, List[PolicyEvalData]]): Mapping of policy
+            IDs to lists of PolicyEvalData objects (items in these lists will
+            be the batch's items for the model forward pass).
+        policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy
+            obj.
+        _sample_collector (SampleCollector): The SampleCollector object to use.
+        tf_sess (Optional[tf.Session]): Optional tensorflow session to use for
+            batching TF policy evaluations.
+
+    Returns:
+        eval_results: dict of policy to compute_action() outputs.
+    """
+
+    eval_results: Dict[PolicyID, TensorStructType] = {}
+
+    if tf_sess:
+        builder = TFRunBuilder(tf_sess, "policy_eval")
+        pending_fetches: Dict[PolicyID, Any] = {}
+    else:
+        builder = None
+
+    if log_once("compute_actions_input"):
+        logger.info("Inputs to compute_actions():\n\n{}\n".format(
+            summarize(to_eval)))
+
+    for policy_id in to_eval:
+        policy: Policy = _get_or_raise(policies, policy_id)
+        input_dict = _sample_collector.get_inference_input_dict(policy_id)
+        eval_results[policy_id] = \
+            policy.compute_actions_from_input_dict(
+                input_dict, timestep=policy.global_timestep)
+
     if builder:
         # type: PolicyID, Tuple[TensorStructType, StateBatch, dict]
         for pid, v in pending_fetches.items():
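`_do_policy_eval_w_trajectory_view_api` asks the collector for a ready-made input dict per policy and calls `compute_actions_from_input_dict` on it. A sketch of that call shape with a dummy policy; the dict keys and shapes are illustrative of what a recurrent model might request, not the collector's actual output:

```python
import numpy as np

# Hypothetical inference input dict for a batch of 2 agents.
input_dict = {
    "obs": np.zeros((2, 4), dtype=np.float32),
    "prev_actions": np.zeros((2,), dtype=np.int64),
    "prev_rewards": np.zeros((2,), dtype=np.float32),
    "state_in_0": np.zeros((2, 8), dtype=np.float32),
    "seq_lens": np.ones((2,), dtype=np.int32),
}


class _DummyPolicy:
    """Stand-in exposing the same 3-tuple return as a real policy."""
    global_timestep = 0

    def compute_actions_from_input_dict(self, input_dict, timestep=None):
        batch = input_dict["obs"].shape[0]
        actions = np.zeros((batch,), dtype=np.int64)
        state_outs = [input_dict["state_in_0"]]  # pass RNN state through
        extra_fetches = {"vf_preds": np.zeros((batch,), dtype=np.float32)}
        return actions, state_outs, extra_fetches


policy = _DummyPolicy()
actions, state_outs, extras = policy.compute_actions_from_input_dict(
    input_dict, timestep=policy.global_timestep)
```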
@@ -943,7 +1312,8 @@ def _process_policy_eval_results(
         off_policy_actions: MultiEnvDict,
         policies: Dict[PolicyID, Policy],
         clip_actions: bool,
-        _use_trajectory_view_api: bool = False
+        _use_trajectory_view_api: bool = False,
+        _sample_collector=None,
 ) -> Dict[EnvID, Dict[AgentID, EnvActionType]]:
    """Process the output of policy neural network evaluation.
 

@@ -980,11 +1350,10 @@ def _process_policy_eval_results(
         actions_to_send[env_id] = {}  # at minimum send empty dict
 
     # type: PolicyID, List[PolicyEvalData]
-    for policy_id, eval_data in to_eval.items():
-        rnn_in_cols: StateBatch = _to_column_format(
-            [t.rnn_state for t in eval_data])
-
+    for policy_id in to_eval:
         actions: TensorStructType = eval_results[policy_id][0]
+        actions = convert_to_numpy(actions)
 
         rnn_out_cols: StateBatch = eval_results[policy_id][1]
         pi_info_cols: dict = eval_results[policy_id][2]

@@ -993,40 +1362,58 @@ def _process_policy_eval_results(
         if isinstance(actions, list):
             actions = np.array(actions)
 
-        if len(rnn_in_cols) != len(rnn_out_cols):
-            raise ValueError("Length of RNN in did not match RNN out, got: "
-                             "{} vs {}".format(rnn_in_cols, rnn_out_cols))
-        # Add RNN state info
-        for f_i, column in enumerate(rnn_in_cols):
-            pi_info_cols["state_in_{}".format(f_i)] = column
-        for f_i, column in enumerate(rnn_out_cols):
-            pi_info_cols["state_out_{}".format(f_i)] = column
+        # Add RNN state info.
+        eval_data = None
+        if not _use_trajectory_view_api:
+            eval_data = to_eval[policy_id]
+            rnn_in_cols: StateBatch = _to_column_format(
+                [t.rnn_state for t in eval_data])
+
+            if len(rnn_in_cols) != len(rnn_out_cols):
+                raise ValueError(
+                    "Length of RNN in did not match RNN out, got: "
+                    "{} vs {}".format(rnn_in_cols, rnn_out_cols))
+            for f_i, column in enumerate(rnn_in_cols):
+                pi_info_cols["state_in_{}".format(f_i)] = column
+            for f_i, column in enumerate(rnn_out_cols):
+                pi_info_cols["state_out_{}".format(f_i)] = column
 
         policy: Policy = _get_or_raise(policies, policy_id)
         # Split action-component batches into single action rows.
         actions: List[EnvActionType] = unbatch(actions)
         # type: int, EnvActionType
         for i, action in enumerate(actions):
-            env_id: int = eval_data[i].env_id
-            agent_id: AgentID = eval_data[i].agent_id
             # Clip if necessary.
             if clip_actions:
                 clipped_action = clip_action(action,
                                              policy.action_space_struct)
             else:
                 clipped_action = action
             actions_to_send[env_id][agent_id] = clipped_action
-            episode: MultiAgentEpisode = active_episodes[env_id]
-            episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
-            episode._set_last_pi_info(
-                agent_id, {k: v[i]
-                           for k, v in pi_info_cols.items()})
-            if env_id in off_policy_actions and \
-                    agent_id in off_policy_actions[env_id]:
-                episode._set_last_action(agent_id,
-                                         off_policy_actions[env_id][agent_id])
-            else:
-                episode._set_last_action(agent_id, action)
+
+            # Trajectory View API: Do not store data directly in episode
+            # (entire episode is stored in Trajectory and kept until
+            # end of episode).
+            if _use_trajectory_view_api:
+                agent_id, episode_id, env_id = \
+                    _sample_collector.policy_sample_collectors[
+                        policy_id].forward_pass_index_to_agent_info[i]
+            else:
+                env_id: int = eval_data[i].env_id
+                agent_id: AgentID = eval_data[i].agent_id
+                episode: MultiAgentEpisode = active_episodes[env_id]
+                episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
+                episode._set_last_pi_info(
+                    agent_id, {k: v[i]
+                               for k, v in pi_info_cols.items()})
+                if env_id in off_policy_actions and \
+                        agent_id in off_policy_actions[env_id]:
+                    episode._set_last_action(
+                        agent_id, off_policy_actions[env_id][agent_id])
+                else:
+                    episode._set_last_action(agent_id, action)
|
||||
|
||||
assert agent_id not in actions_to_send[env_id]
|
||||
actions_to_send[env_id][agent_id] = clipped_action
|
||||
|
||||
return actions_to_send
|
||||
|
||||
|
@ -1054,20 +1441,21 @@ def _to_column_format(rnn_state_rows: List[List[Any]]) -> StateBatch:
|
|||
return [[row[i] for row in rnn_state_rows] for i in range(num_cols)]
|
||||
|
||||
|
||||
def _get_or_raise(mapping: Dict[PolicyID, Policy],
|
||||
policy_id: PolicyID) -> Policy:
|
||||
"""Returns a Policy object under key `policy_id` in `mapping`.
|
||||
def _get_or_raise(mapping: Dict[PolicyID, Union[Policy, Preprocessor, Filter]],
|
||||
policy_id: PolicyID) -> Union[Policy, Preprocessor, Filter]:
|
||||
"""Returns an object under key `policy_id` in `mapping`.
|
||||
|
||||
Args:
|
||||
mapping (dict): The mapping dict from policy id (str) to
|
||||
actual Policy object.
|
||||
mapping (Dict[PolicyID, Union[Policy, Preprocessor, Filter]]): The
|
||||
mapping dict from policy id (str) to actual object (Policy,
|
||||
Preprocessor, etc.).
|
||||
policy_id (str): The policy ID to lookup.
|
||||
|
||||
Returns:
|
||||
Policy: The found Policy object.
|
||||
Union[Policy, Preprocessor, Filter]: The found object.
|
||||
|
||||
Throws:
|
||||
ValueError: If `policy_id` cannot be found.
|
||||
ValueError: If `policy_id` cannot be found in `mapping`.
|
||||
"""
|
||||
if policy_id not in mapping:
|
||||
raise ValueError(
|
||||
|
|
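(Note: the hunk above is cut off inside `_get_or_raise`'s body. A minimal, hypothetical sketch of such a lookup-or-raise helper, based only on the docstring shown above — the real function's exact error message may differ:)

from typing import Dict, TypeVar

T = TypeVar("T")


def _get_or_raise_sketch(mapping: Dict[str, T], policy_id: str) -> T:
    # Return the object stored under `policy_id`, or raise a ValueError
    # naming the missing key (mirrors the docstring above).
    if policy_id not in mapping:
        raise ValueError(
            "Could not find policy for policy_id={}!".format(policy_id))
    return mapping[policy_id]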
277 rllib/evaluation/tests/test_trajectory_view_api.py (new file)
@@ -0,0 +1,277 @@
import copy
from gym.spaces import Box, Discrete
import time
import unittest

import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.examples.policy.episode_env_aware_policy import \
    EpisodeEnvAwarePolicy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.test_utils import framework_iterator


class TestTrajectoryViewAPI(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_traj_view_normal_case(self):
        """Tests whether Model and Policy return the correct ViewRequirements.
        """
        config = ppo.DEFAULT_CONFIG.copy()
        for _ in framework_iterator(config, frameworks="torch"):
            trainer = ppo.PPOTrainer(config, env="CartPole-v0")
            policy = trainer.get_policy()
            view_req_model = policy.model.inference_view_requirements
            view_req_policy = policy.training_view_requirements
            assert len(view_req_model) == 1
            assert len(view_req_policy) == 10
            for key in [
                    SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS,
                    SampleBatch.DONES, SampleBatch.NEXT_OBS,
                    SampleBatch.VF_PREDS, "advantages", "value_targets",
                    SampleBatch.ACTION_DIST_INPUTS, SampleBatch.ACTION_LOGP
            ]:
                assert key in view_req_policy
                # None of the view cols has a special underlying data_col,
                # except next-obs.
                if key != SampleBatch.NEXT_OBS:
                    assert view_req_policy[key].data_col is None
                else:
                    assert view_req_policy[key].data_col == SampleBatch.OBS
                    assert view_req_policy[key].shift == 1
            trainer.stop()

    def test_traj_view_lstm_prev_actions_and_rewards(self):
        """Tests whether Policy/Model return correct LSTM ViewRequirements.
        """
        config = ppo.DEFAULT_CONFIG.copy()
        config["model"] = config["model"].copy()
        # Activate LSTM + prev-action + rewards.
        config["model"]["use_lstm"] = True
        config["model"]["lstm_use_prev_action_reward"] = True

        for _ in framework_iterator(config, frameworks="torch"):
            trainer = ppo.PPOTrainer(config, env="CartPole-v0")
            policy = trainer.get_policy()
            view_req_model = policy.model.inference_view_requirements
            view_req_policy = policy.training_view_requirements
            assert len(view_req_model) == 7  # obs, prev_a, prev_r, 4x states
            assert len(view_req_policy) == 16
            for key in [
                    SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS,
                    SampleBatch.DONES, SampleBatch.NEXT_OBS,
                    SampleBatch.VF_PREDS, SampleBatch.PREV_ACTIONS,
                    SampleBatch.PREV_REWARDS, "advantages", "value_targets",
                    SampleBatch.ACTION_DIST_INPUTS, SampleBatch.ACTION_LOGP
            ]:
                assert key in view_req_policy

                if key == SampleBatch.PREV_ACTIONS:
                    assert view_req_policy[key].data_col == SampleBatch.ACTIONS
                    assert view_req_policy[key].shift == -1
                elif key == SampleBatch.PREV_REWARDS:
                    assert view_req_policy[key].data_col == SampleBatch.REWARDS
                    assert view_req_policy[key].shift == -1
                elif key not in [
                        SampleBatch.NEXT_OBS, SampleBatch.PREV_ACTIONS,
                        SampleBatch.PREV_REWARDS
                ]:
                    assert view_req_policy[key].data_col is None
                else:
                    assert view_req_policy[key].data_col == SampleBatch.OBS
                    assert view_req_policy[key].shift == 1
            trainer.stop()

    def test_traj_view_lstm_performance(self):
        """Test whether PPOTrainer runs faster w/ `_use_trajectory_view_api`.
        """
        config = copy.deepcopy(ppo.DEFAULT_CONFIG)
        action_space = Discrete(2)
        obs_space = Box(-1.0, 1.0, shape=(700, ))

        from ray.rllib.examples.env.random_env import RandomMultiAgentEnv

        from ray.tune import register_env
        register_env("ma_env", lambda c: RandomMultiAgentEnv({
            "num_agents": 2,
            "p_done": 0.01,
            "action_space": action_space,
            "observation_space": obs_space
        }))

        config["num_workers"] = 3
        config["num_envs_per_worker"] = 8
        config["num_sgd_iter"] = 6
        config["model"]["use_lstm"] = True
        config["model"]["lstm_use_prev_action_reward"] = True
        config["model"]["max_seq_len"] = 100

        policies = {
            "pol0": (None, obs_space, action_space, {}),
        }

        def policy_fn(agent_id):
            return "pol0"

        config["multiagent"] = {
            "policies": policies,
            "policy_mapping_fn": policy_fn,
        }
        num_iterations = 1
        # Only works in torch so far.
        for _ in framework_iterator(config, frameworks="torch"):
            print("w/ traj. view API (and time-major)")
            config["_use_trajectory_view_api"] = True
            config["model"]["_time_major"] = True
            trainer = ppo.PPOTrainer(config=config, env="ma_env")
            learn_time_w = 0.0
            sampler_perf = {}
            start = time.time()
            for i in range(num_iterations):
                out = trainer.train()
                sampler_perf_ = out["sampler_perf"]
                sampler_perf = {
                    k: sampler_perf.get(k, 0.0) + sampler_perf_[k]
                    for k, v in sampler_perf_.items()
                }
                delta = out["timers"]["learn_time_ms"] / 1000
                learn_time_w += delta
                print("{}={}s".format(i, delta))
            sampler_perf = {
                k: sampler_perf[k] / (num_iterations if "mean_" in k else 1)
                for k, v in sampler_perf.items()
            }
            duration_w = time.time() - start
            print("Duration: {}s "
                  "sampler-perf.={} learn-time/iter={}s".format(
                      duration_w, sampler_perf, learn_time_w / num_iterations))
            trainer.stop()

            print("w/o traj. view API (and w/o time-major)")
            config["_use_trajectory_view_api"] = False
            config["model"]["_time_major"] = False
            trainer = ppo.PPOTrainer(config=config, env="ma_env")
            learn_time_wo = 0.0
            sampler_perf = {}
            start = time.time()
            for i in range(num_iterations):
                out = trainer.train()
                sampler_perf_ = out["sampler_perf"]
                sampler_perf = {
                    k: sampler_perf.get(k, 0.0) + sampler_perf_[k]
                    for k, v in sampler_perf_.items()
                }
                delta = out["timers"]["learn_time_ms"] / 1000
                learn_time_wo += delta
                print("{}={}s".format(i, delta))
            sampler_perf = {
                k: sampler_perf[k] / (num_iterations if "mean_" in k else 1)
                for k, v in sampler_perf.items()
            }
            duration_wo = time.time() - start
            print("Duration: {}s "
                  "sampler-perf.={} learn-time/iter={}s".format(
                      duration_wo, sampler_perf,
                      learn_time_wo / num_iterations))
            trainer.stop()

            # Assert `_use_trajectory_view_api` is much faster.
            self.assertLess(duration_w, duration_wo)
            self.assertLess(learn_time_w, learn_time_wo * 0.6)

    def test_traj_view_lstm_functionality(self):
        action_space = Box(-float("inf"), float("inf"), shape=(2, ))
        obs_space = Box(float("-inf"), float("inf"), (4, ))
        max_seq_len = 50
        policies = {
            "pol0": (EpisodeEnvAwarePolicy, obs_space, action_space, {}),
        }

        def policy_fn(agent_id):
            return "pol0"

        rollout_worker = RolloutWorker(
            env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
            policy_config={
                "multiagent": {
                    "policies": policies,
                    "policy_mapping_fn": policy_fn,
                },
                "_use_trajectory_view_api": True,
                "model": {
                    "use_lstm": True,
                    "_time_major": True,
                    "max_seq_len": max_seq_len,
                },
            },
            policy=policies,
            policy_mapping_fn=policy_fn,
            num_envs=1,
        )
        for i in range(100):
            pc = rollout_worker.sampler.sample_collector. \
                policy_sample_collectors["pol0"]
            sample_batch_offset_before = pc.sample_batch_offset
            buffers = pc.buffers
            result = rollout_worker.sample()
            pol_batch = result.policy_batches["pol0"]

            self.assertTrue(result.count == 100)
            self.assertTrue(pol_batch.count >= 100)
            self.assertFalse(0 in pol_batch.seq_lens)
            # Check prev_reward/action, next_obs consistency.
            for t in range(max_seq_len):
                obs_t = pol_batch["obs"][t]
                r_t = pol_batch["rewards"][t]
                if t > 0:
                    next_obs_t_m_1 = pol_batch["new_obs"][t - 1]
                    self.assertTrue((obs_t == next_obs_t_m_1).all())
                if t < max_seq_len - 1:
                    prev_rewards_t_p_1 = pol_batch["prev_rewards"][t + 1]
                    self.assertTrue((r_t == prev_rewards_t_p_1).all())

            # Check the sanity of all the buffers in the underlying
            # PerPolicy collector.
            for sample_batch_slot, agent_slot in enumerate(
                    range(sample_batch_offset_before, pc.sample_batch_offset)):
                t_buf = buffers["t"][:, agent_slot]
                obs_buf = buffers["obs"][:, agent_slot]
                # Skip empty seqs at end (these won't be part of the batch
                # and have been copied to new agent-slots (even if seq-len=0)).
                if sample_batch_slot < len(pol_batch.seq_lens):
                    seq_len = pol_batch.seq_lens[sample_batch_slot]
                    # Make sure timesteps are always increasing within the seq.
                    assert all(t_buf[1] + j == n + 1
                               for j, n in enumerate(t_buf)
                               if j < seq_len and j != 0)
                    # Make sure all obs within seq are non-0.0.
                    assert all(
                        any(obs_buf[j] != 0.0) for j in range(1, seq_len + 1))

            # Check seq-lens.
            for agent_slot, seq_len in enumerate(pol_batch.seq_lens):
                if seq_len < max_seq_len - 1:
                    # At least in the beginning, the next slots should always
                    # be empty (once all agent slots have been used once, these
                    # may be filled with "old" values (from longer sequences)).
                    if i < 10:
                        self.assertTrue(
                            (pol_batch["obs"][seq_len +
                                              1][agent_slot] == 0.0).all())
                        print(end="")
                    self.assertFalse(
                        (pol_batch["obs"][seq_len][agent_slot] == 0.0).all())


if __name__ == "__main__":
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))
43 rllib/examples/env/debug_counter_env.py (vendored)
@@ -1,4 +1,7 @@
import gym
import numpy as np

from ray.rllib.env.multi_agent_env import MultiAgentEnv


class DebugCounterEnv(gym.Env):
@@ -21,3 +24,43 @@ class DebugCounterEnv(gym.Env):
    def step(self, action):
        self.i += 1
        return [self.i], self.i % 3, self.i >= 15, {}


class MultiAgentDebugCounterEnv(MultiAgentEnv):
    def __init__(self, config):
        self.num_agents = config["num_agents"]
        self.p_done = config.get("p_done", 0.02)
        # Actions are always:
        # (episodeID, envID) as floats.
        self.action_space = \
            gym.spaces.Box(-float("inf"), float("inf"), shape=(2, ))
        # Observation dims:
        # 0=agent ID.
        # 1=episode ID (0.0 for obs after reset).
        # 2=env ID (0.0 for obs after reset).
        # 3=ts (of the agent).
        self.observation_space = \
            gym.spaces.Box(float("-inf"), float("inf"), (4, ))
        self.timesteps = [0] * self.num_agents
        self.dones = set()

    def reset(self):
        self.dones = set()
        return {
            i: np.array([i, 0.0, 0.0, 0.0], dtype=np.float32)
            for i in range(self.num_agents)
        }

    def step(self, action_dict):
        obs, rew, done = {}, {}, {}
        for i, action in action_dict.items():
            self.timesteps[i] += 1
            obs[i] = np.array([i, action[0], action[1], self.timesteps[i]])
            rew[i] = self.timesteps[i] % 3
            done[i] = bool(
                np.random.choice(
                    [True, False], p=[self.p_done, 1.0 - self.p_done]))
            if done[i]:
                self.dones.add(i)
        done["__all__"] = len(self.dones) == self.num_agents
        return obs, rew, done, {}
66 rllib/examples/policy/episode_env_aware_policy.py (new file)
@@ -0,0 +1,66 @@
import numpy as np

from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override


class EpisodeEnvAwarePolicy(RandomPolicy):
    """A Policy that always knows the current EpisodeID and EnvID and
    returns these in its actions."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.episode_id = None
        self.env_id = None

        class _fake_model:
            pass

        self.model = _fake_model()
        self.model.time_major = True
        self.model.inference_view_requirements = {
            SampleBatch.EPS_ID: ViewRequirement(),
            "env_id": ViewRequirement(),
            SampleBatch.OBS: ViewRequirement(),
            SampleBatch.PREV_ACTIONS: ViewRequirement(
                SampleBatch.ACTIONS, space=self.action_space, shift=-1),
            SampleBatch.PREV_REWARDS: ViewRequirement(
                SampleBatch.REWARDS, shift=-1),
        }
        self.training_view_requirements = dict(
            **{
                SampleBatch.NEXT_OBS: ViewRequirement(
                    SampleBatch.OBS, shift=1),
                SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
                SampleBatch.REWARDS: ViewRequirement(),
                SampleBatch.DONES: ViewRequirement(),
            },
            **self.model.inference_view_requirements)

    @override(Policy)
    def is_recurrent(self):
        return True

    @override(Policy)
    def compute_actions_from_input_dict(self,
                                        input_dict,
                                        explore=None,
                                        timestep=None,
                                        **kwargs):
        self.episode_id = input_dict[SampleBatch.EPS_ID][0]
        self.env_id = input_dict["env_id"][0]
        # Always return (episodeID, envID)
        return [
            np.array([self.episode_id, self.env_id]) for _ in input_dict["obs"]
        ], [], {}

    @override(Policy)
    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        sample_batch["postprocessed_column"] = sample_batch["obs"] + 1.0
        return sample_batch
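(The next hunk adds an experimental `_time_major` flag to MODEL_DEFAULTS. As a hedged sketch — assembled from the config keys the new tests above already use, not an authoritative recipe — this is how the pieces are meant to be combined:)

import ray.rllib.agents.ppo as ppo

# Combine the trajectory view API with the time-major LSTM path.
# All key names below appear in this diff; defaults may differ per version.
config = ppo.DEFAULT_CONFIG.copy()
config["model"] = config["model"].copy()
config["_use_trajectory_view_api"] = True
config["model"]["use_lstm"] = True
config["model"]["lstm_use_prev_action_reward"] = True
config["model"]["_time_major"] = True  # experimental, torch-only so far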
@@ -65,6 +65,9 @@ MODEL_DEFAULTS: ModelConfigDict = {
    "lstm_cell_size": 256,
    # Whether to feed a_{t-1}, r_{t-1} to LSTM.
    "lstm_use_prev_action_reward": False,
    # Experimental (only works with `_use_trajectory_view_api`=True):
    # Whether the LSTM is time-major (TxBx..) or batch-major (BxTx..).
    "_time_major": False,
    # When using modelv1 models with a modelv2 algorithm, you may have to
    # define the state shape here (e.g., [256, 256]).
    "state_shape": None,

@@ -58,6 +58,10 @@ class ModelV2:
        self.name: str = name or "default_model"
        self.framework: str = framework
        self._last_output = None
        self.time_major = self.model_config.get("_time_major")
        self.inference_view_requirements = {
            SampleBatch.OBS: ViewRequirement(shift=0),
        }

    @PublicAPI
    def get_initial_state(self) -> List[np.ndarray]:
@@ -246,26 +250,6 @@ class ModelV2:
            i += 1
        return self.__call__(input_dict, states, train_batch.get("seq_lens"))

    def inference_view_requirements(self) -> Dict[str, ViewRequirement]:
        """Returns a dict of ViewRequirements for this Model.

        Note: This is an experimental API method.

        The view requirements dict is used to generate input_dicts and
        train batches for 1) action computations, 2) postprocessing, and 3)
        generating training batches.

        Returns:
            Dict[str, ViewRequirement]: The view requirements dict, mapping
                each view key (which will be available in input_dicts) to
                an underlying requirement (actual data, timestep shift, etc..).
        """
        # Default implementation for simple RL model:
        # Single requirement: Pass current obs as input.
        return {
            SampleBatch.OBS: ViewRequirement(shift=0),
        }

    def import_from_h5(self, h5_file: str) -> None:
        """Imports weights from an h5 file.

@@ -322,6 +306,16 @@ class ModelV2:
        """
        raise NotImplementedError

    @PublicAPI
    def is_time_major(self) -> bool:
        """If True, data for calling this ModelV2 must be in time-major format.

        Returns:
            bool: Whether this ModelV2 requires a time-major (TxBx...) data
                format.
        """
        return self.time_major is True


class NullContextManager:
    """No-op context manager"""

@@ -54,10 +54,11 @@ class RecurrentNetwork(TFModelV2):

        You should implement forward_rnn() in your subclass."""
        assert seq_lens is not None

        padded_inputs = input_dict["obs_flat"]
        max_seq_len = tf.shape(padded_inputs)[0] // tf.shape(seq_lens)[0]
        output, new_state = self.forward_rnn(
            add_time_dimension(
                input_dict["obs_flat"], seq_lens, framework="tf"), state,
                padded_inputs, max_seq_len=max_seq_len, framework="tf"), state,
            seq_lens)
        return tf.reshape(output, [-1, self.num_outputs]), new_state


@@ -1,5 +1,5 @@
from gym.spaces import Box
import numpy as np
from typing import Dict

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.misc import SlimFC
@@ -63,13 +63,20 @@ class RecurrentNetwork(TorchModelV2):
        """Adds time dimension to batch before sending inputs to forward_rnn().

        You should implement forward_rnn() in your subclass."""
        flat_inputs = input_dict["obs_flat"].float()
        if isinstance(seq_lens, np.ndarray):
            seq_lens = torch.Tensor(seq_lens).int()
        output, new_state = self.forward_rnn(
            add_time_dimension(
                input_dict["obs_flat"].float(), seq_lens, framework="torch"),
            state, seq_lens)
        return torch.reshape(output, [-1, self.num_outputs]), new_state
        max_seq_len = flat_inputs.shape[0] // seq_lens.shape[0]
        self.time_major = self.model_config.get("_time_major", False)
        inputs = add_time_dimension(
            flat_inputs,
            max_seq_len=max_seq_len,
            framework="torch",
            time_major=self.time_major,
        )
        output, new_state = self.forward_rnn(inputs, state, seq_lens)
        output = torch.reshape(output, [-1, self.num_outputs])
        return output, new_state

    def forward_rnn(self, inputs, state, seq_lens):
        """Call the model with the given input tensors and state.

@@ -104,13 +111,15 @@ class LSTMWrapper(RecurrentNetwork, nn.Module):
        super().__init__(obs_space, action_space, None, model_config, name)

        self.cell_size = model_config["lstm_cell_size"]
        self.time_major = model_config.get("_time_major", False)
        self.use_prev_action_reward = model_config[
            "lstm_use_prev_action_reward"]
        self.action_dim = int(np.product(action_space.shape))
        # Add prev-action/reward nodes to input to LSTM.
        if self.use_prev_action_reward:
            self.num_outputs += 1 + self.action_dim
        self.lstm = nn.LSTM(self.num_outputs, self.cell_size, batch_first=True)
        self.lstm = nn.LSTM(
            self.num_outputs, self.cell_size, batch_first=not self.time_major)

        self.num_outputs = num_outputs

@@ -126,6 +135,26 @@ class LSTMWrapper(RecurrentNetwork, nn.Module):
            activation_fn=None,
            initializer=torch.nn.init.xavier_uniform_)

        self.inference_view_requirements.update(
            dict(
                **{
                    SampleBatch.OBS: ViewRequirement(shift=0),
                    SampleBatch.PREV_REWARDS: ViewRequirement(
                        SampleBatch.REWARDS, shift=-1),
                    SampleBatch.PREV_ACTIONS: ViewRequirement(
                        SampleBatch.ACTIONS, space=self.action_space,
                        shift=-1),
                }))
        for i in range(2):
            self.inference_view_requirements["state_in_{}".format(i)] = \
                ViewRequirement(
                    "state_out_{}".format(i),
                    shift=-1,
                    space=Box(-1.0, 1.0, shape=(self.cell_size,)))
            self.inference_view_requirements["state_out_{}".format(i)] = \
                ViewRequirement(
                    space=Box(-1.0, 1.0, shape=(self.cell_size,)))

    @override(RecurrentNetwork)
    def forward(self, input_dict, state, seq_lens):
        assert seq_lens is not None
@@ -150,10 +179,24 @@ class LSTMWrapper(RecurrentNetwork, nn.Module):

    @override(RecurrentNetwork)
    def forward_rnn(self, inputs, state, seq_lens):
        # Don't show paddings to RNN(?)
        # TODO: (sven) For now, only allow, iff time_major=True to not break
        # anything retrospectively (time_major not supported previously).
        # max_seq_len = inputs.shape[0]
        # time_major = self.model_config["_time_major"]
        # if time_major and max_seq_len > 1:
        #     inputs = torch.nn.utils.rnn.pack_padded_sequence(
        #         inputs, seq_lens,
        #         batch_first=not time_major, enforce_sorted=False)
        self._features, [h, c] = self.lstm(
            inputs,
            [torch.unsqueeze(state[0], 0),
             torch.unsqueeze(state[1], 0)])
        # Re-apply paddings.
        # if time_major and max_seq_len > 1:
        #     self._features, _ = torch.nn.utils.rnn.pad_packed_sequence(
        #         self._features,
        #         batch_first=not time_major)
        model_out = self._logits_branch(self._features)
        return model_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]

@@ -171,16 +214,3 @@ class LSTMWrapper(RecurrentNetwork, nn.Module):
    def value_function(self):
        assert self._features is not None, "must call forward() first"
        return torch.reshape(self._value_branch(self._features), [-1])

    @override(ModelV2)
    def inference_view_requirements(self) -> Dict[str, ViewRequirement]:
        req = super().inference_view_requirements()
        # Optional: prev-actions/rewards for forward pass.
        if self.model_config["lstm_use_prev_action_reward"]:
            req.update({
                SampleBatch.PREV_REWARDS: ViewRequirement(
                    SampleBatch.REWARDS, shift=-1),
                SampleBatch.PREV_ACTIONS: ViewRequirement(
                    SampleBatch.ACTIONS, space=self.action_space, shift=-1),
            })
        return req

@@ -70,6 +70,11 @@ class Policy(metaclass=ABCMeta):
        # The action distribution class to use for action sampling, if any.
        # Child classes may set this.
        self.dist_class = None
        # View requirements dict for a `learn_on_batch()` call.
        # Child classes need to add their specific requirements here (usually
        # a combination of a Model's inference_view_- and the
        # Policy's loss function-requirements).
        self.training_view_requirements = {}

    @abstractmethod
    @DeveloperAPI
@@ -283,25 +288,6 @@ class Policy(metaclass=ABCMeta):
        """
        raise NotImplementedError

    @DeveloperAPI
    def training_view_requirements(self):
        """Returns a dict of view requirements for operating on this Policy.

        Note: This is an experimental API method.

        The view requirements dict is used to generate input_dicts and
        SampleBatches for 1) action computations, 2) postprocessing, and 3)
        generating training batches.
        The Policy may ask its Model(s) as well for possible additional
        requirements (e.g. prev-action/reward in an LSTM).

        Returns:
            Dict[str, ViewRequirement]: The view requirements dict, mapping
                each view key (which will be available in input_dicts) to
                an underlying requirement (actual data, timestep shift, etc..).
        """
        return {}

    @DeveloperAPI
    def postprocess_trajectory(
            self,
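(With the sections above, `inference_view_requirements` and `training_view_requirements` become plain dict attributes filled in the constructors, instead of overridable methods. A minimal, hypothetical sketch of a custom model opting into an extra view column under that new pattern — class name and the added requirement are illustrative, and forward()/value_function() are omitted:)

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement


class MyPrevRewardAwareModel(TorchModelV2):
    """Sketch only: real models also subclass nn.Module and implement
    forward(); here we only show the view-requirement declaration."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        # Ask the sampler for the previous reward in addition to the
        # default current observation (added by ModelV2.__init__ above).
        self.inference_view_requirements.update({
            SampleBatch.PREV_REWARDS: ViewRequirement(
                SampleBatch.REWARDS, shift=-1),
        })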
@@ -13,12 +13,14 @@ current algorithms: https://github.com/ray-project/ray/issues/2992

import logging
import numpy as np
from typing import List, Optional

from ray.util import log_once
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.typing import TensorType
from ray.util import log_once

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
@@ -27,11 +29,14 @@ logger = logging.getLogger(__name__)


@DeveloperAPI
def pad_batch_to_sequences_of_same_size(batch,
                                        max_seq_len,
                                        shuffle=False,
                                        batch_divisibility_req=1,
                                        feature_keys=None):
def pad_batch_to_sequences_of_same_size(
        batch: SampleBatch,
        max_seq_len: int,
        shuffle: bool = False,
        batch_divisibility_req: int = 1,
        feature_keys: Optional[List[str]] = None,
        _use_trajectory_view_api: bool = False,
):
    """Applies padding to `batch` so it's choppable into same-size sequences.

    Shuffles `batch` (if desired), makes sure divisibility requirement is met,
@@ -51,7 +56,26 @@ def pad_batch_to_sequences_of_same_size(batch,
        feature_keys (Optional[List[str]]): An optional list of keys to apply
            sequence-chopping to. If None, use all keys in batch that are not
            "state_in/out_"-type keys.
        _use_trajectory_view_api (bool): Whether we are using the Trajectory
            View API to collect and process samples.
    """
    if _use_trajectory_view_api:
        if batch.time_major is not None:
            batch["seq_lens"] = torch.tensor(batch.seq_lens)
            t = 0 if batch.time_major else 1
            for col in batch.data.keys():
                # Cut time-dim from states.
                if "state_" in col[:6]:
                    batch[col] = batch[col][t]
                # Flatten all other data.
                else:
                    # Cut time-dim at `max_seq_len`.
                    if batch.time_major:
                        batch[col] = batch[col][:batch.max_seq_len]
                    batch[col] = batch[col].reshape((-1, ) +
                                                    batch[col].shape[2:])
        return

    if batch_divisibility_req > 1:
        meets_divisibility_reqs = (
            len(batch[SampleBatch.CUR_OBS]) % batch_divisibility_req == 0
@@ -61,7 +85,7 @@ def pad_batch_to_sequences_of_same_size(batch,
        meets_divisibility_reqs = True

    # RNN-case.
    if "state_in_0" in batch:
    if "state_in_0" in batch or "state_out_0" in batch:
        dynamic_max = True
    # Multi-agent case.
    elif not meets_divisibility_reqs:
@@ -109,31 +133,32 @@ def pad_batch_to_sequences_of_same_size(batch,


@DeveloperAPI
def add_time_dimension(padded_inputs,
                       seq_lens,
                       framework="tf",
                       time_major=False):
def add_time_dimension(padded_inputs: TensorType,
                       *,
                       max_seq_len: int,
                       framework: str = "tf",
                       time_major: bool = False):
    """Adds a time dimension to padded inputs.

    Arguments:
        padded_inputs (Tensor): a padded batch of sequences. That is,
    Args:
        padded_inputs (TensorType): a padded batch of sequences. That is,
            for seq_lens=[1, 2, 2], then inputs=[A, *, B, B, C, C], where
            A, B, C are sequence elements and * denotes padding.
        seq_lens (Tensor): the sequence lengths within the input batch,
            suitable for passing to tf.nn.dynamic_rnn().
        max_seq_len (int): The max. sequence length in padded_inputs.
        framework (str): The framework string ("tf2", "tf", "tfe", "torch").
        time_major (bool): Whether data should be returned in time-major (TxB)
            format or not (BxT).

    Returns:
        Reshaped tensor of shape [NUM_SEQUENCES, MAX_SEQ_LEN, ...].
        TensorType: Reshaped tensor of shape [B, T, ...] or [T, B, ...].
    """

    # Sequence lengths have to be specified for LSTM batch inputs. The
    # input batch must be padded to the max seq length given here. That is,
    # batch_size == len(seq_lens) * max(seq_lens)
    if framework == "tf":
    if framework in ["tf2", "tf", "tfe"]:
        assert time_major is False, "time-major not supported yet for tf!"
        padded_batch_size = tf.shape(padded_inputs)[0]
        max_seq_len = padded_batch_size // tf.shape(seq_lens)[0]

        # Dynamically reshape the padded batch to introduce a time dimension.
        new_batch_size = padded_batch_size // max_seq_len
        new_shape = ([new_batch_size, max_seq_len] +
@@ -142,7 +167,6 @@ def add_time_dimension(padded_inputs,
    else:
        assert framework == "torch", "`framework` must be either tf or torch!"
        padded_batch_size = padded_inputs.shape[0]
        max_seq_len = padded_batch_size // seq_lens.shape[0]

        # Dynamically reshape the padded batch to introduce a time dimension.
        new_batch_size = padded_batch_size // max_seq_len
@@ -153,6 +177,9 @@ def add_time_dimension(padded_inputs,
        return torch.reshape(padded_inputs, new_shape)


# NOTE: This function will be deprecated once chunks already come padded and
# correctly chopped from the _SampleCollector object (in time-major fashion
# or not). It is already no longer used if `_use_trajectory_view_api` = True.
@DeveloperAPI
def chop_into_sequences(episode_ids,
                        unroll_ids,
@@ -166,11 +193,11 @@ def chop_into_sequences(episode_ids,
    """Truncate and pad experiences into fixed-length sequences.

    Args:
        episode_ids (list): List of episode ids for each step.
        unroll_ids (list): List of identifiers for the sample batch. This is
            used to make sure sequences are cut between sample batches.
        agent_indices (list): List of agent ids for each step. Note that this
            has to be combined with episode_ids for uniqueness.
        episode_ids (List[EpisodeID]): List of episode ids for each step.
        unroll_ids (List[UnrollID]): List of identifiers for the sample batch.
            This is used to make sure sequences are cut between sample batches.
        agent_indices (List[AgentID]): List of agent ids for each step. Note
            that this has to be combined with episode_ids for uniqueness.
        feature_columns (list): List of arrays containing features.
        state_columns (list): List of arrays containing LSTM state values.
        max_seq_len (int): Max length of sequences before truncation.

@@ -59,19 +59,33 @@ class SampleBatch:
    def __init__(self, *args, **kwargs):
        """Constructs a sample batch (same params as dict constructor)."""

        self._initial_inputs = kwargs.pop("_initial_inputs", {})
        # Possible seq_lens (TxB or BxT) setup.
        self.time_major = kwargs.pop("_time_major", None)
        self.seq_lens = kwargs.pop("_seq_lens", None)
        self.max_seq_len = None
        if self.seq_lens is not None and len(self.seq_lens) > 0:
            self.max_seq_len = max(self.seq_lens)

        # The actual data, accessible by column name (str).
        self.data = dict(*args, **kwargs)

        lengths = []
        for k, v in self.data.copy().items():
            assert isinstance(k, str), self
            lengths.append(len(v))
            self.data[k] = np.array(v, copy=False)
            if isinstance(v, list):
                self.data[k] = np.array(v)
        if not lengths:
            raise ValueError("Empty sample batch")
        assert len(set(lengths)) == 1, ("data columns must be same length",
                                        self.data, lengths)
        self.count = lengths[0]
        assert len(set(lengths)) == 1, \
            "Data columns must be same length, but lens are {}".format(lengths)
        if self.seq_lens is not None and len(self.seq_lens) > 0:
            self.count = sum(self.seq_lens)
        else:
            self.count = len(self.data[k])

        # Keeps track of new columns added after initial ones.
        self.new_columns = []

    @staticmethod
    @PublicAPI
@@ -88,11 +102,21 @@ class SampleBatch:
        """
        if isinstance(samples[0], MultiAgentBatch):
            return MultiAgentBatch.concat_samples(samples)
        seq_lens = []
        concat_samples = []
        for s in samples:
            if s.count > 0:
                concat_samples.append(s)
                if s.seq_lens is not None:
                    seq_lens.extend(s.seq_lens)

        out = {}
        samples = [s for s in samples if s.count > 0]
        for k in samples[0].keys():
            out[k] = concat_aligned([s[k] for s in samples])
        return SampleBatch(out)
        for k in concat_samples[0].keys():
            out[k] = concat_aligned(
                [s[k] for s in concat_samples],
                time_major=concat_samples[0].time_major)
        return SampleBatch(
            out, _seq_lens=seq_lens, _time_major=concat_samples[0].time_major)

    @PublicAPI
    def concat(self, other: "SampleBatch") -> "SampleBatch":
@@ -222,8 +246,18 @@ class SampleBatch:
            SampleBatch: A new SampleBatch, which has a slice of this batch's
                data.
        """

        return SampleBatch({k: v[start:end] for k, v in self.data.items()})
        if self.time_major is not None:
            return SampleBatch(
                {k: v[:, start:end]
                 for k, v in self.data.items()},
                _seq_lens=self.seq_lens[start:end],
                _time_major=self.time_major)
        else:
            return SampleBatch(
                {k: v[start:end]
                 for k, v in self.data.items()},
                _seq_lens=None,
                _time_major=self.time_major)

    @PublicAPI
    def timeslices(self, k: int) -> List["SampleBatch"]:
@@ -290,7 +324,7 @@ class SampleBatch:
            key (str): The key (column name) to return.

        Returns:
            TensorType]: The data under the given key.
            TensorType: The data under the given key.
        """
        return self.data[key]

@@ -302,6 +336,8 @@ class SampleBatch:
            key (str): The column name to set a value for.
            item (TensorType): The data to insert.
        """
        if key not in self.data:
            self.new_columns.append(key)
        self.data[key] = item

    @DeveloperAPI
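(A small usage sketch tying together the two changes above: the now keyword-only `add_time_dimension()` signature and the new `_seq_lens`/`_time_major` SampleBatch kwargs. Shapes and values are illustrative, and the printed shapes are what I would expect from the reshape logic shown in the hunks above:)

import numpy as np
import torch

from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.policy.sample_batch import SampleBatch

# 3 sequences, each padded to max_seq_len=4, 5 features per timestep.
flat = torch.zeros((3 * 4, 5))
batch_major = add_time_dimension(
    flat, max_seq_len=4, framework="torch", time_major=False)
print(batch_major.shape)  # expected: torch.Size([3, 4, 5])
time_major = add_time_dimension(
    flat, max_seq_len=4, framework="torch", time_major=True)
print(time_major.shape)   # expected: torch.Size([4, 3, 5])

# With `_seq_lens`, SampleBatch.count becomes the sum of the sequence
# lengths (actual timesteps), not the length of the (time-major) columns.
batch = SampleBatch(
    {"obs": np.zeros((4, 3, 2), dtype=np.float32)},  # [T=4, B=3, feat]
    _seq_lens=[4, 2, 3],
    _time_major=True)
print(batch.count)        # expected: 9
print(batch.max_seq_len)  # expected: 4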
@@ -1,84 +0,0 @@
import unittest

import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.test_utils import framework_iterator


class TestTrajectoryViewAPI(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_plain(self):
        config = ppo.DEFAULT_CONFIG.copy()
        for _ in framework_iterator(config, frameworks="torch"):
            trainer = ppo.PPOTrainer(config, env="CartPole-v0")
            policy = trainer.get_policy()
            view_req_model = policy.model.inference_view_requirements()
            view_req_policy = policy.training_view_requirements()
            assert len(view_req_model) == 1
            assert len(view_req_policy) == 6
            for key in [
                    SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS,
                    SampleBatch.DONES, SampleBatch.NEXT_OBS,
                    SampleBatch.VF_PREDS
            ]:
                assert key in view_req_policy
                # None of the view cols has a special underlying data_col,
                # except next-obs.
                if key != SampleBatch.NEXT_OBS:
                    assert view_req_policy[key].data_col is None
                else:
                    assert view_req_policy[key].data_col == SampleBatch.OBS
                    assert view_req_policy[key].shift == 1
            trainer.stop()

    def test_lstm_prev_actions_and_rewards(self):
        config = ppo.DEFAULT_CONFIG.copy()
        config["model"] = config["model"].copy()
        # Activate LSTM + prev-action + rewards.
        config["model"]["use_lstm"] = True
        config["model"]["lstm_use_prev_action_reward"] = True

        for _ in framework_iterator(config, frameworks="torch"):
            trainer = ppo.PPOTrainer(config, env="CartPole-v0")
            policy = trainer.get_policy()
            view_req_model = policy.model.inference_view_requirements()
            view_req_policy = policy.training_view_requirements()
            assert len(view_req_model) == 3  # obs, prev_a, prev_r
            assert len(view_req_policy) == 8
            for key in [
                    SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS,
                    SampleBatch.DONES, SampleBatch.NEXT_OBS,
                    SampleBatch.VF_PREDS, SampleBatch.PREV_ACTIONS,
                    SampleBatch.PREV_REWARDS
            ]:
                assert key in view_req_policy

                if key == SampleBatch.PREV_ACTIONS:
                    assert view_req_policy[key].data_col == SampleBatch.ACTIONS
                    assert view_req_policy[key].shift == -1
                elif key == SampleBatch.PREV_REWARDS:
                    assert view_req_policy[key].data_col == SampleBatch.REWARDS
                    assert view_req_policy[key].shift == -1
                elif key not in [
                        SampleBatch.NEXT_OBS, SampleBatch.PREV_ACTIONS,
                        SampleBatch.PREV_REWARDS
                ]:
                    assert view_req_policy[key].data_col is None
                else:
                    assert view_req_policy[key].data_col == SampleBatch.OBS
                    assert view_req_policy[key].shift == 1
            trainer.stop()


if __name__ == "__main__":
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))
@@ -104,16 +104,21 @@ class TorchPolicy(Policy):
        """
        self.framework = "torch"
        super().__init__(observation_space, action_space, config)
        if torch.cuda.is_available() and ray.get_gpu_ids(as_str=True):
        if torch.cuda.is_available() and ray.get_gpu_ids():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        self.model = model.to(self.device)
        # Combine view_requirements for Model and Policy.
        self.view_requirements = {
            **self.model.inference_view_requirements(),
            **self.training_view_requirements(),
        }
        self.training_view_requirements = dict(
            **{
                SampleBatch.ACTIONS: ViewRequirement(
                    space=self.action_space, shift=0),
                SampleBatch.REWARDS: ViewRequirement(shift=0),
                SampleBatch.DONES: ViewRequirement(shift=0),
            },
            **self.model.inference_view_requirements)

        self.exploration = self._create_exploration()
        self.unwrapped_model = model  # used to support DistributedDataParallel
        self._loss = loss
@@ -131,17 +136,6 @@ class TorchPolicy(Policy):
            callable(get_batch_divisibility_req) else \
            (get_batch_divisibility_req or 1)

    @override(Policy)
    def training_view_requirements(self):
        if hasattr(self, "view_requirements"):
            return self.view_requirements
        return {
            SampleBatch.ACTIONS: ViewRequirement(
                space=self.action_space, shift=0),
            SampleBatch.REWARDS: ViewRequirement(shift=0),
            SampleBatch.DONES: ViewRequirement(shift=0),
        }

    @override(Policy)
    @DeveloperAPI
    def compute_actions(
@@ -204,9 +198,11 @@ class TorchPolicy(Policy):
        with torch.no_grad():
            # Pass lazy (torch) tensor dict to Model as `input_dict`.
            input_dict = self._lazy_tensor_dict(input_dict)
            # TODO: (sven) support RNNs w/ fast sampling.
            state_batches = []
            seq_lens = None
            state_batches = [
                input_dict[k] for k in input_dict.keys() if "state_" in k[:6]
            ]
            seq_lens = np.array([1] * len(input_dict["obs"])) \
                if state_batches else None

            actions, state_out, extra_fetches, logp = \
                self._compute_action_helper(
@@ -340,7 +336,9 @@ class TorchPolicy(Policy):
            postprocessed_batch,
            max_seq_len=self.max_seq_len,
            shuffle=False,
            batch_divisibility_req=self.batch_divisibility_req)
            batch_divisibility_req=self.batch_divisibility_req,
            _use_trajectory_view_api=self.config["_use_trajectory_view_api"],
        )

        train_batch = self._lazy_tensor_dict(postprocessed_batch)

@@ -359,6 +357,7 @@ class TorchPolicy(Policy):
                loss_out, train_batch)

        assert len(loss_out) == len(self._optimizers)

        # assert not any(torch.isnan(l) for l in loss_out)
        fetches = self.extra_compute_grad_fetches()

@@ -224,16 +224,13 @@ def build_torch_policy(
            get_batch_divisibility_req=get_batch_divisibility_req,
        )

        if callable(training_view_requirements_fn):
            self.training_view_requirements.update(
                training_view_requirements_fn(self))

        if after_init:
            after_init(self, obs_space, action_space, config)

    @override(TorchPolicy)
    def training_view_requirements(self):
        req = super().training_view_requirements()
        if callable(training_view_requirements_fn):
            req.update(training_view_requirements_fn(self))
        return req

    @override(Policy)
    def postprocess_trajectory(self,
                               sample_batch,

@@ -29,8 +29,7 @@ class ViewRequirement:
    def __init__(self,
                 data_col: Optional[str] = None,
                 space: gym.Space = None,
                 shift: Union[int, List[int]] = 0,
                 created_during_postprocessing: bool = False):
                 shift: Union[int, List[int]] = 0):
        """Initializes a ViewRequirement object.

        Args:
@@ -47,11 +46,8 @@ class ViewRequirement:
                Example: For a view column "obs" in an Atari framestacking
                fashion, you can set `data_col="obs"` and
                `shift=[-3, -2, -1, 0]`.
            created_during_postprocessing (bool): Whether this column only gets
                created during postprocessing.
        """
        self.data_col = data_col
        self.space = space or gym.spaces.Box(
            float("-inf"), float("inf"), shape=())
        self.shift = shift
        self.created_during_postprocessing = created_during_postprocessing

@@ -63,7 +63,8 @@ def minibatches(samples, sgd_minibatch_size):
        raise NotImplementedError(
            "Minibatching not implemented for multi-agent in simple mode")

    if "state_in_0" in samples.data:
    # Replace with `if samples.seq_lens` check.
    if "state_in_0" in samples.data or "state_out_0" in samples.data:
        if log_once("not_shuffling_rnn_data_in_simple_mode"):
            logger.warning("Not shuffling RNN data for SGD in simple mode")
    else:
@@ -71,9 +72,22 @@ def minibatches(samples, sgd_minibatch_size):

    i = 0
    slices = []
    while i < samples.count:
        slices.append((i, i + sgd_minibatch_size))
        i += sgd_minibatch_size
    if samples.seq_lens:
        seq_no = 0
        while i < samples.count:
            seq_no_end = seq_no
            actual_count = 0
            while actual_count < sgd_minibatch_size and len(
                    samples.seq_lens) > seq_no_end:
                actual_count += samples.seq_lens[seq_no_end]
                seq_no_end += 1
            slices.append((seq_no, seq_no_end))
            i += actual_count
            seq_no = seq_no_end
    else:
        while i < samples.count:
            slices.append((i, i + sgd_minibatch_size))
            i += sgd_minibatch_size
    random.shuffle(slices)

    for i, j in slices:
@@ -100,7 +114,7 @@ def do_minibatch_sgd(samples, policies, local_worker, num_sgd_iter,
        samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples}, samples.count)

    fetches = {}
    for policy_id, policy in policies.items():
    for policy_id in policies.keys():
        if policy_id not in samples.policy_batches:
            continue

@@ -43,6 +43,9 @@ EnvID = int
# Represents an episode id.
EpisodeID = int

# Represents an "unroll" (maybe across different sub-envs in a vector env).
UnrollID = int

# A dict keyed by agent ids, e.g. {"agent-1": value}.
MultiAgentDict = Dict[AgentID, Any]
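(The ViewRequirement docstring above mentions an Atari framestacking use of list-valued `shift`. A hedged sketch of that example, using an illustrative observation space — the constructor arguments are exactly those shown in the hunk above:)

from gym.spaces import Box

from ray.rllib.policy.view_requirement import ViewRequirement

# View column "obs" built from the underlying "obs" data, stacking the last
# three frames plus the current one.
framestack_req = ViewRequirement(
    data_col="obs",
    space=Box(0.0, 255.0, shape=(84, 84, 1)),  # illustrative obs space
    shift=[-3, -2, -1, 0])
print(framestack_req.shift)  # [-3, -2, -1, 0]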
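(And a standalone worked sketch of the new seq-lens-aware branch in `minibatches()` above: whole sequences are grouped until roughly `sgd_minibatch_size` timesteps are covered, and the resulting slices index sequence numbers, not timesteps. This is a re-implementation for illustration only, not RLlib code:)

def sequence_minibatch_slices(seq_lens, count, sgd_minibatch_size):
    # Mirror of the seq_lens branch above: accumulate whole sequences until
    # at least `sgd_minibatch_size` timesteps are covered per slice.
    slices = []
    i = 0
    seq_no = 0
    while i < count:
        seq_no_end = seq_no
        actual_count = 0
        while actual_count < sgd_minibatch_size and seq_no_end < len(seq_lens):
            actual_count += seq_lens[seq_no_end]
            seq_no_end += 1
        slices.append((seq_no, seq_no_end))
        i += actual_count
        seq_no = seq_no_end
    return slices


print(sequence_minibatch_slices([5, 10, 2, 8, 4], count=29,
                                sgd_minibatch_size=12))
# [(0, 2), (2, 5)]: 5+10=15 >= 12 timesteps, then 2+8+4=14 >= 12.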