[RLlib] Attention Net prep PR #2: Smaller cleanups. (#12449)

Sven Mika 2020-12-01 08:21:45 +01:00 committed by GitHub
parent e72147de38
commit 3ad9365e1d
12 changed files with 68 additions and 55 deletions


@@ -31,7 +31,8 @@ class _SampleCollector(metaclass=ABCMeta):
@abstractmethod
def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
policy_id: PolicyID, init_obs: TensorType) -> None:
policy_id: PolicyID, t: int,
init_obs: TensorType) -> None:
"""Adds an initial obs (after reset) to this collector.
Since the very first observation in an environment is collected w/o
@@ -48,6 +49,8 @@ class _SampleCollector(metaclass=ABCMeta):
values for.
env_id (EnvID): The environment index (in a vectorized setup).
policy_id (PolicyID): Unique id for policy controlling the agent.
t (int): The time step (= episode length - 1). The initial obs is at
t=-1(!); the first action/reward/next-obs then follows at t=0, etc.
init_obs (TensorType): Initial observation (after env.reset()).
Examples:
@@ -172,9 +175,10 @@ class _SampleCollector(metaclass=ABCMeta):
MultiAgentBatch. Used for batch_mode=`complete_episodes`.
Returns:
Any: An ID that can be used in `build_multi_agent_batch` to
retrieve the samples that have been postprocessed as a
ready-built MultiAgentBatch.
Optional[MultiAgentBatch]: If `build` is True, the
SampleBatch or MultiAgentBatch built from `episode` (either
just from that episode or from the `_PolicyCollectorGroup`
in the `episode.batch_builder` property).
"""
raise NotImplementedError
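For orientation, a minimal sketch of the updated call site (it mirrors the `_process_observations_w_trajectory_view_api` change further down in this diff; all local names are the objects the sampler already holds and are assumed here):

```python
# Sketch only: how the sampler passes the new `t` argument.
# The initial obs is booked at t = episode.length - 1, i.e. t = -1
# right after env.reset(); the first action/reward/next-obs row then
# lands at t = 0.
_sample_collector.add_init_obs(
    episode, agent_id, env_id, policy_id,
    episode.length - 1,  # t == -1 directly after reset
    filtered_obs)
```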


@@ -52,17 +52,19 @@ class _AgentCollector:
# each time a (non-initial!) observation is added.
self.count = 0
def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID,
env_id: EnvID, init_obs: TensorType,
def add_init_obs(self, episode_id: EpisodeID, agent_index: int,
env_id: EnvID, t: int, init_obs: TensorType,
view_requirements: Dict[str, ViewRequirement]) -> None:
"""Adds an initial observation (after reset) to the Agent's trajectory.
Args:
episode_id (EpisodeID): Unique ID for the episode we are adding the
initial observation for.
agent_id (AgentID): Unique ID for the agent we are adding the
initial observation for.
agent_index (int): Unique int index (starting from 0) for the agent
within its episode.
env_id (EnvID): The environment index (in a vectorized setup).
t (int): The time step (= episode length - 1). The initial obs is at
t=-1(!); the first action/reward/next-obs then follows at t=0, etc.
init_obs (TensorType): The initial observation tensor (after
`env.reset()`).
view_requirements (Dict[str, ViewRequirement])
@@ -72,10 +74,15 @@ class _AgentCollector:
single_row={
SampleBatch.OBS: init_obs,
SampleBatch.EPS_ID: episode_id,
SampleBatch.AGENT_INDEX: agent_id,
SampleBatch.AGENT_INDEX: agent_index,
"env_id": env_id,
"t": t,
})
self.buffers[SampleBatch.OBS].append(init_obs)
self.buffers[SampleBatch.EPS_ID].append(episode_id)
self.buffers[SampleBatch.AGENT_INDEX].append(agent_index)
self.buffers["env_id"].append(env_id)
self.buffers["t"].append(t)
def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \
None:
@@ -133,7 +140,7 @@ class _AgentCollector:
continue
# OBS are already shifted by -1 (the initial obs starts one ts
# before all other data columns).
shift = view_req.shift - \
shift = view_req.data_rel_pos - \
(1 if data_col == SampleBatch.OBS else 0)
if data_col not in np_data:
np_data[data_col] = to_float_np_array(self.buffers[data_col])
@@ -187,7 +194,10 @@ class _AgentCollector:
for col, data in single_row.items():
if col in self.buffers:
continue
shift = self.shift_before - (1 if col == SampleBatch.OBS else 0)
shift = self.shift_before - (1 if col in [
SampleBatch.OBS, SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
"env_id", "t"
] else 0)
# Python primitive or dict (e.g. INFOs).
if isinstance(data, (int, float, bool, str, dict)):
self.buffers[col] = [0 for _ in range(shift)]
@@ -360,7 +370,7 @@ class _SimpleListCollector(_SampleCollector):
@override(_SampleCollector)
def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
env_id: EnvID, policy_id: PolicyID,
env_id: EnvID, policy_id: PolicyID, t: int,
init_obs: TensorType) -> None:
# Make sure our mappings are up to date.
agent_key = (episode.episode_id, agent_id)
@@ -378,8 +388,9 @@ class _SimpleListCollector(_SampleCollector):
self.agent_collectors[agent_key] = _AgentCollector()
self.agent_collectors[agent_key].add_init_obs(
episode_id=episode.episode_id,
agent_id=agent_id,
agent_index=episode._agent_index(agent_id),
env_id=env_id,
t=t,
init_obs=init_obs,
view_requirements=view_reqs)
@@ -429,7 +440,7 @@ class _SimpleListCollector(_SampleCollector):
# Create the batch of data from the different buffers.
data_col = view_req.data_col or view_col
time_indices = \
view_req.shift - (
view_req.data_rel_pos - (
1 if data_col in [SampleBatch.OBS, "t", "env_id",
SampleBatch.EPS_ID,
SampleBatch.AGENT_INDEX] else 0)
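The extra `- 1` above compensates for the columns that already received one entry at `env.reset()` time (obs, eps_id, agent_index, env_id, "t"): their buffers are one element longer than the action/reward buffers. A tiny, self-contained toy (plain Python, not RLlib code) of the resulting indexing:

```python
# Toy illustration of the index correction: columns seeded at reset
# carry one extra leading entry compared to action/reward columns.
obs_buffer = ["obs@reset", "obs@0"]  # seeded at reset, then one env step
act_buffer = ["act@0"]               # first entry only appears at t = 0

# Current obs for the next forward pass: data_rel_pos = 0.
# The "- 1" turns that into Python index -1, i.e. the newest obs.
assert obs_buffer[0 - 1] == "obs@0"

# Previous action (PREV_ACTIONS view): data_rel_pos = -1, no correction
# needed because the action buffer only starts at t = 0.
assert act_buffer[-1] == "act@0"
```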


@@ -272,9 +272,11 @@ class RolloutWorker(ParallelIteratorWorker):
output_creator (Callable[[IOContext], OutputWriter]): Function that
returns an OutputWriter object for saving generated
experiences.
remote_worker_envs (bool): If using num_envs > 1, whether to create
those new envs in remote processes instead of in the current
process. This adds overheads, but can make sense if your envs
remote_worker_envs (bool): If using num_envs_per_worker > 1,
whether to create those new envs in remote processes instead of
in the current process. This adds overheads, but can make sense
if your envs are expensive to step/reset (e.g., for StarCraft).
Use this cautiously; the overheads are significant!
remote_env_batch_wait_ms (float): Timeout that remote workers
are waiting when polling environments. 0 (continue when at
least one env is ready) is a reasonable default, but optimal
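For reference, the renamed setting in a minimal config sketch (these keys are standard RLlib trainer config options; the values here are purely illustrative, not tuning advice):

```python
# Illustrative config snippet only; values are examples.
config = {
    "num_workers": 2,
    "num_envs_per_worker": 4,       # vectorize 4 envs per rollout worker
    "remote_worker_envs": True,     # step/reset each env in its own Ray actor
    "remote_env_batch_wait_ms": 0,  # 0: continue as soon as one env is ready
}
```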


@@ -1040,7 +1040,8 @@ def _process_observations_w_trajectory_view_api(
# Record transition info if applicable.
if last_observation is None:
_sample_collector.add_init_obs(episode, agent_id, env_id,
policy_id, filtered_obs)
policy_id, episode.length - 1,
filtered_obs)
else:
# Add actions, rewards, next-obs to collectors.
values_dict = {
@@ -1158,7 +1159,8 @@ def _process_observations_w_trajectory_view_api(
# Add initial obs to buffer.
_sample_collector.add_init_obs(
new_episode, agent_id, env_id, policy_id, filtered_obs)
new_episode, agent_id, env_id, policy_id,
new_episode.length - 1, filtered_obs)
item = PolicyEvalData(
env_id, agent_id, filtered_obs,


@@ -59,7 +59,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
assert view_req_policy[key].data_col is None
else:
assert view_req_policy[key].data_col == SampleBatch.OBS
assert view_req_policy[key].shift == 1
assert view_req_policy[key].data_rel_pos == 1
rollout_worker = trainer.workers.local_worker()
sample_batch = rollout_worker.sample()
expected_count = \
@@ -99,10 +99,10 @@ class TestTrajectoryViewAPI(unittest.TestCase):
if key == SampleBatch.PREV_ACTIONS:
assert view_req_policy[key].data_col == SampleBatch.ACTIONS
assert view_req_policy[key].shift == -1
assert view_req_policy[key].data_rel_pos == -1
elif key == SampleBatch.PREV_REWARDS:
assert view_req_policy[key].data_col == SampleBatch.REWARDS
assert view_req_policy[key].shift == -1
assert view_req_policy[key].data_rel_pos == -1
elif key not in [
SampleBatch.NEXT_OBS, SampleBatch.PREV_ACTIONS,
SampleBatch.PREV_REWARDS
@@ -110,7 +110,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
assert view_req_policy[key].data_col is None
else:
assert view_req_policy[key].data_col == SampleBatch.OBS
assert view_req_policy[key].shift == 1
assert view_req_policy[key].data_rel_pos == 1
trainer.stop()
def test_traj_view_simple_performance(self):


@@ -28,14 +28,16 @@ class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
"t": ViewRequirement(),
SampleBatch.OBS: ViewRequirement(),
SampleBatch.PREV_ACTIONS: ViewRequirement(
SampleBatch.ACTIONS, space=self.action_space, shift=-1),
SampleBatch.ACTIONS, space=self.action_space, data_rel_pos=-1),
SampleBatch.PREV_REWARDS: ViewRequirement(
SampleBatch.REWARDS, shift=-1),
SampleBatch.REWARDS, data_rel_pos=-1),
}
for i in range(2):
self.model.inference_view_requirements["state_in_{}".format(i)] = \
ViewRequirement(
"state_out_{}".format(i), shift=-1, space=self.state_space)
"state_out_{}".format(i),
data_rel_pos=-1,
space=self.state_space)
self.model.inference_view_requirements[
"state_out_{}".format(i)] = \
ViewRequirement(space=self.state_space)
@@ -43,7 +45,7 @@ class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
self.view_requirements = dict(
**{
SampleBatch.NEXT_OBS: ViewRequirement(
SampleBatch.OBS, shift=1),
SampleBatch.OBS, data_rel_pos=1),
SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
SampleBatch.REWARDS: ViewRequirement(),
SampleBatch.DONES: ViewRequirement(),


@@ -16,7 +16,7 @@ class AlwaysSameHeuristic(Policy):
self.view_requirements.update({
"state_in_0": ViewRequirement(
"state_out_0",
shift=-1,
data_rel_pos=-1,
space=gym.spaces.Box(0, 100, shape=(), dtype=np.int32))
})


@@ -61,7 +61,8 @@ class ModelV2:
self.time_major = self.model_config.get("_time_major")
# Basic view requirement for all models: Use the observation as input.
self.inference_view_requirements = {
SampleBatch.OBS: ViewRequirement(shift=0, space=self.obs_space),
SampleBatch.OBS: ViewRequirement(
data_rel_pos=0, space=self.obs_space),
}
# TODO: (sven): Get rid of `get_initial_state` once Trajectory


@@ -178,10 +178,10 @@ class LSTMWrapper(RecurrentNetwork):
if model_config["lstm_use_prev_action"]:
self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
shift=-1)
data_rel_pos=-1)
if model_config["lstm_use_prev_reward"]:
self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
ViewRequirement(SampleBatch.REWARDS, shift=-1)
ViewRequirement(SampleBatch.REWARDS, data_rel_pos=-1)
@override(RecurrentNetwork)
def forward(self, input_dict: Dict[str, TensorType],


@@ -159,10 +159,10 @@ class LSTMWrapper(RecurrentNetwork, nn.Module):
if model_config["lstm_use_prev_action"]:
self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
shift=-1)
data_rel_pos=-1)
if model_config["lstm_use_prev_reward"]:
self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
ViewRequirement(SampleBatch.REWARDS, shift=-1)
ViewRequirement(SampleBatch.REWARDS, data_rel_pos=-1)
@override(RecurrentNetwork)
def forward(self, input_dict: Dict[str, TensorType],


@@ -564,13 +564,14 @@ class Policy(metaclass=ABCMeta):
SampleBatch.OBS: ViewRequirement(space=self.observation_space),
SampleBatch.NEXT_OBS: ViewRequirement(
data_col=SampleBatch.OBS,
shift=1,
data_rel_pos=1,
space=self.observation_space),
SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
SampleBatch.REWARDS: ViewRequirement(),
SampleBatch.DONES: ViewRequirement(),
SampleBatch.INFOS: ViewRequirement(),
SampleBatch.EPS_ID: ViewRequirement(),
SampleBatch.UNROLL_ID: ViewRequirement(),
SampleBatch.AGENT_INDEX: ViewRequirement(),
SampleBatch.UNROLL_ID: ViewRequirement(),
"t": ViewRequirement(),
@@ -617,7 +618,7 @@ class Policy(metaclass=ABCMeta):
batch_for_postproc.count = self._dummy_batch.count
postprocessed_batch = self.postprocess_trajectory(batch_for_postproc)
if state_outs:
B = 4 # For RNNs, have B=2, T=[depends on sample_batch_size]
B = 4 # For RNNs, have B=4, T=[depends on sample_batch_size]
# TODO: (sven) This hack will not work for attention net traj.
# view setup.
i = 0
@@ -657,7 +658,8 @@ class Policy(metaclass=ABCMeta):
# Tag those only needed for post-processing.
for key in batch_for_postproc.accessed_keys:
if key not in train_batch.accessed_keys and \
key in self.view_requirements:
key in self.view_requirements and \
key not in self.model.inference_view_requirements:
self.view_requirements[key].used_for_training = False
# Remove those not needed at all (leave those that are needed
# by Sampler to properly execute sample collection).
@@ -680,18 +682,6 @@ class Policy(metaclass=ABCMeta):
"postprocessing function.".format(key))
else:
del self.view_requirements[key]
# Add those data_cols (again) that are missing and have
# dependencies by view_cols.
for key in list(self.view_requirements.keys()):
vr = self.view_requirements[key]
if vr.data_col is not None and \
vr.data_col not in self.view_requirements:
used_for_training = \
vr.data_col in train_batch.accessed_keys
self.view_requirements[vr.data_col] = \
ViewRequirement(
space=vr.space,
used_for_training=used_for_training)
def _get_dummy_batch_from_view_requirements(
self, batch_size: int = 1) -> SampleBatch:
@@ -727,7 +717,7 @@ class Policy(metaclass=ABCMeta):
model.inference_view_requirements["state_in_{}".format(i)] = \
ViewRequirement(
"state_out_{}".format(i),
shift=-1,
data_rel_pos=-1,
space=Box(-1.0, 1.0, shape=state.shape))
model.inference_view_requirements["state_out_{}".format(i)] = \
ViewRequirement(space=Box(-1.0, 1.0, shape=state.shape))


@@ -29,7 +29,7 @@ class ViewRequirement:
def __init__(self,
data_col: Optional[str] = None,
space: gym.Space = None,
shift: Union[int, List[int]] = 0,
data_rel_pos: Union[int, List[int]] = 0,
used_for_training: bool = True):
"""Initializes a ViewRequirement object.
@@ -40,13 +40,14 @@ class ViewRequirement:
space (gym.Space): The gym Space used in case we need to pad data
in inaccessible areas of the trajectory (t<0 or t>H).
Default: Simple box space, e.g. rewards.
shift (Union[int, List[int]]): Single shift value of list of
shift values to use relative to the underlying `data_col`.
data_rel_pos (Union[int, str, List[int]]): Single shift value or
list of relative positions to use (relative to the underlying
`data_col`).
Example: For a view column "prev_actions", you can set
`data_col="actions"` and `shift=-1`.
`data_col="actions"` and `data_rel_pos=-1`.
Example: For a view column "obs" in an Atari framestacking
fashion, you can set `data_col="obs"` and
`shift=[-3, -2, -1, 0]`.
`data_rel_pos=[-3, -2, -1, 0]`.
used_for_training (bool): Whether the data will be used for
training. If False, the column will not be copied into the
final train batch.
@@ -54,5 +55,5 @@ class ViewRequirement:
self.data_col = data_col
self.space = space or gym.spaces.Box(
float("-inf"), float("inf"), shape=())
self.shift = shift
self.data_rel_pos = data_rel_pos
self.used_for_training = used_for_training
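To make the renamed argument concrete, a short sketch based on the two docstring examples above (the import paths are assumed to match this codebase):

```python
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement

# "prev_actions" view: read the ACTIONS column, shifted back by one step.
prev_actions = ViewRequirement(
    data_col=SampleBatch.ACTIONS, data_rel_pos=-1)

# Atari-style framestacking view on "obs": the last four observations.
framestacked_obs = ViewRequirement(
    data_col=SampleBatch.OBS, data_rel_pos=[-3, -2, -1, 0])
```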