[RLlib] Attention Net prep PR #2: Smaller cleanups. (#12449)

Sven Mika, 2020-12-01 08:21:45 +01:00 (committed by GitHub)
commit 3ad9365e1d, parent e72147de38
12 changed files with 68 additions and 55 deletions

@@ -31,7 +31,8 @@ class _SampleCollector(metaclass=ABCMeta):
     @abstractmethod
     def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
-                     policy_id: PolicyID, init_obs: TensorType) -> None:
+                     policy_id: PolicyID, t: int,
+                     init_obs: TensorType) -> None:
         """Adds an initial obs (after reset) to this collector.

         Since the very first observation in an environment is collected w/o
@@ -48,6 +49,8 @@ class _SampleCollector(metaclass=ABCMeta):
                 values for.
             env_id (EnvID): The environment index (in a vectorized setup).
             policy_id (PolicyID): Unique id for policy controlling the agent.
+            t (int): The time step (episode length - 1). The initial obs has
+                t=-1(!), then an action/reward/next-obs at t=0, etc.
             init_obs (TensorType): Initial observation (after env.reset()).

         Examples:
@@ -172,9 +175,10 @@ class _SampleCollector(metaclass=ABCMeta):
                 MultiAgentBatch. Used for batch_mode=`complete_episodes`.

         Returns:
-            Any: An ID that can be used in `build_multi_agent_batch` to
-                retrieve the samples that have been postprocessed as a
-                ready-built MultiAgentBatch.
+            Optional[MultiAgentBatch]: If `build` is True, the
+                SampleBatch or MultiAgentBatch built from `episode` (either
+                just from that episode or from the `_PolicyCollectorGroup`
+                in the `episode.batch_builder` property).
         """
         raise NotImplementedError
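Note: the new `t` argument encodes the trajectory-view time convention spelled
out in the docstring above. A minimal, plain-Python sketch of that convention
(illustrative only, not RLlib code):

    # The reset obs is recorded at t = episode.length - 1 = -1; the first
    # action/reward/next-obs triple then lands at t=0.
    episode_length = 0             # right after env.reset()
    ts = [episode_length - 1]      # -1 for the initial observation

    for step in range(3):          # pretend we step the env three times
        episode_length += 1
        ts.append(episode_length - 1)

    assert ts == [-1, 0, 1, 2]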

@@ -52,17 +52,19 @@ class _AgentCollector:
         # each time a (non-initial!) observation is added.
         self.count = 0

-    def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID,
-                     env_id: EnvID, init_obs: TensorType,
+    def add_init_obs(self, episode_id: EpisodeID, agent_index: int,
+                     env_id: EnvID, t: int, init_obs: TensorType,
                      view_requirements: Dict[str, ViewRequirement]) -> None:
         """Adds an initial observation (after reset) to the Agent's trajectory.

         Args:
             episode_id (EpisodeID): Unique ID for the episode we are adding the
                 initial observation for.
-            agent_id (AgentID): Unique ID for the agent we are adding the
-                initial observation for.
+            agent_index (int): Unique int index (starting from 0) for the agent
+                within its episode.
             env_id (EnvID): The environment index (in a vectorized setup).
+            t (int): The time step (episode length - 1). The initial obs has
+                t=-1(!), then an action/reward/next-obs at t=0, etc.
             init_obs (TensorType): The initial observation tensor (after
                 `env.reset()`).
             view_requirements (Dict[str, ViewRequirements])
@@ -72,10 +74,15 @@ class _AgentCollector:
                 single_row={
                     SampleBatch.OBS: init_obs,
                     SampleBatch.EPS_ID: episode_id,
-                    SampleBatch.AGENT_INDEX: agent_id,
+                    SampleBatch.AGENT_INDEX: agent_index,
                     "env_id": env_id,
+                    "t": t,
                 })
         self.buffers[SampleBatch.OBS].append(init_obs)
+        self.buffers[SampleBatch.EPS_ID].append(episode_id)
+        self.buffers[SampleBatch.AGENT_INDEX].append(agent_index)
+        self.buffers["env_id"].append(env_id)
+        self.buffers["t"].append(t)

     def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \
             None:
@@ -133,7 +140,7 @@ class _AgentCollector:
                 continue
             # OBS are already shifted by -1 (the initial obs starts one ts
             # before all other data columns).
-            shift = view_req.shift - \
+            shift = view_req.data_rel_pos - \
                 (1 if data_col == SampleBatch.OBS else 0)
             if data_col not in np_data:
                 np_data[data_col] = to_float_np_array(self.buffers[data_col])
@@ -187,7 +194,10 @@ class _AgentCollector:
         for col, data in single_row.items():
             if col in self.buffers:
                 continue
-            shift = self.shift_before - (1 if col == SampleBatch.OBS else 0)
+            shift = self.shift_before - (1 if col in [
+                SampleBatch.OBS, SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
+                "env_id", "t"
+            ] else 0)
             # Python primitive or dict (e.g. INFOs).
             if isinstance(data, (int, float, bool, str, dict)):
                 self.buffers[col] = [0 for _ in range(shift)]
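Note: the last hunk widens the special-casing from OBS alone to all columns
that now receive a row at reset time. A plain-Python sketch of the pre-padding
rule (column names and the shift value are hypothetical):

    # shift_before is the longest lookback any view requirement needs (e.g. 1
    # for prev_actions); new buffers are front-filled with that many dummy rows.
    shift_before = 1
    reset_time_cols = {"obs", "eps_id", "agent_index", "env_id", "t"}

    def initial_padding(col):
        # Mirrors: shift_before - (1 if col in [OBS, EPS_ID, ...] else 0)
        return shift_before - (1 if col in reset_time_cols else 0)

    assert initial_padding("actions") == 1  # one dummy row before t=0
    assert initial_padding("obs") == 0      # the reset row already fills it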
@@ -360,7 +370,7 @@ class _SimpleListCollector(_SampleCollector):
     @override(_SampleCollector)
     def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
-                     env_id: EnvID, policy_id: PolicyID,
+                     env_id: EnvID, policy_id: PolicyID, t: int,
                      init_obs: TensorType) -> None:
         # Make sure our mappings are up to date.
         agent_key = (episode.episode_id, agent_id)
@@ -378,8 +388,9 @@ class _SimpleListCollector(_SampleCollector):
             self.agent_collectors[agent_key] = _AgentCollector()
         self.agent_collectors[agent_key].add_init_obs(
             episode_id=episode.episode_id,
-            agent_id=agent_id,
+            agent_index=episode._agent_index(agent_id),
             env_id=env_id,
+            t=t,
             init_obs=init_obs,
             view_requirements=view_reqs)
@@ -429,7 +440,7 @@ class _SimpleListCollector(_SampleCollector):
         # Create the batch of data from the different buffers.
         data_col = view_req.data_col or view_col
         time_indices = \
-            view_req.shift - (
+            view_req.data_rel_pos - (
                 1 if data_col in [SampleBatch.OBS, "t", "env_id",
                                   SampleBatch.EPS_ID,
                                   SampleBatch.AGENT_INDEX] else 0)
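Note: the `- 1` here corrects for reset-time columns being one row longer than
action-time columns. A plain-Python sketch of that alignment (hypothetical
buffers, not the actual collector code):

    # Rows for t = -1, 0, 1, 2: reset-time columns carry an extra leading row.
    obs_buffer = ["o_reset", "o0", "o1", "o2"]
    act_buffer = ["a0", "a1", "a2"]    # actions only exist from t=0 on

    def fetch(buffer, t, data_rel_pos, written_at_reset):
        # Skip the t=-1 row if the column has one, then apply the rel. pos.
        return buffer[t + (1 if written_at_reset else 0) + data_rel_pos]

    assert fetch(obs_buffer, t=1, data_rel_pos=0, written_at_reset=True) == "o1"
    assert fetch(obs_buffer, t=1, data_rel_pos=1, written_at_reset=True) == "o2"
    assert fetch(act_buffer, t=1, data_rel_pos=-1, written_at_reset=False) == "a0"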

@@ -272,9 +272,11 @@ class RolloutWorker(ParallelIteratorWorker):
             output_creator (Callable[[IOContext], OutputWriter]): Function that
                 returns an OutputWriter object for saving generated
                 experiences.
-            remote_worker_envs (bool): If using num_envs > 1, whether to create
-                those new envs in remote processes instead of in the current
-                process. This adds overheads, but can make sense if your envs
+            remote_worker_envs (bool): If using num_envs_per_worker > 1,
+                whether to create those new envs in remote processes instead of
+                in the current process. This adds overheads, but can make sense
+                if your envs are expensive to step/reset (e.g., for StarCraft).
+                Use this cautiously; overheads are significant!
             remote_env_batch_wait_ms (float): Timeout that remote workers
                 are waiting when polling environments. 0 (continue when at
                 least one env is ready) is a reasonable default, but optimal
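Note: a short config sketch of how these settings interact (keys are the
standard trainer config names; the values are only illustrative):

    config = {
        "num_workers": 2,
        "num_envs_per_worker": 4,       # vectorize 4 envs in each worker
        "remote_worker_envs": True,     # step each env in its own remote actor
        "remote_env_batch_wait_ms": 0,  # continue as soon as one env is ready
    }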

@@ -1040,7 +1040,8 @@ def _process_observations_w_trajectory_view_api(
         # Record transition info if applicable.
         if last_observation is None:
             _sample_collector.add_init_obs(episode, agent_id, env_id,
-                                           policy_id, filtered_obs)
+                                           policy_id, episode.length - 1,
+                                           filtered_obs)
         else:
             # Add actions, rewards, next-obs to collectors.
             values_dict = {
@@ -1158,7 +1159,8 @@ def _process_observations_w_trajectory_view_api(
             # Add initial obs to buffer.
             _sample_collector.add_init_obs(
-                new_episode, agent_id, env_id, policy_id, filtered_obs)
+                new_episode, agent_id, env_id, policy_id,
+                new_episode.length - 1, filtered_obs)

             item = PolicyEvalData(
                 env_id, agent_id, filtered_obs,

@@ -59,7 +59,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
                 assert view_req_policy[key].data_col is None
             else:
                 assert view_req_policy[key].data_col == SampleBatch.OBS
-                assert view_req_policy[key].shift == 1
+                assert view_req_policy[key].data_rel_pos == 1
         rollout_worker = trainer.workers.local_worker()
         sample_batch = rollout_worker.sample()
         expected_count = \
@@ -99,10 +99,10 @@ class TestTrajectoryViewAPI(unittest.TestCase):
             if key == SampleBatch.PREV_ACTIONS:
                 assert view_req_policy[key].data_col == SampleBatch.ACTIONS
-                assert view_req_policy[key].shift == -1
+                assert view_req_policy[key].data_rel_pos == -1
             elif key == SampleBatch.PREV_REWARDS:
                 assert view_req_policy[key].data_col == SampleBatch.REWARDS
-                assert view_req_policy[key].shift == -1
+                assert view_req_policy[key].data_rel_pos == -1
             elif key not in [
                     SampleBatch.NEXT_OBS, SampleBatch.PREV_ACTIONS,
                     SampleBatch.PREV_REWARDS
@@ -110,7 +110,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
                 assert view_req_policy[key].data_col is None
             else:
                 assert view_req_policy[key].data_col == SampleBatch.OBS
-                assert view_req_policy[key].shift == 1
+                assert view_req_policy[key].data_rel_pos == 1
         trainer.stop()

     def test_traj_view_simple_performance(self):

@@ -28,14 +28,16 @@ class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
             "t": ViewRequirement(),
             SampleBatch.OBS: ViewRequirement(),
             SampleBatch.PREV_ACTIONS: ViewRequirement(
-                SampleBatch.ACTIONS, space=self.action_space, shift=-1),
+                SampleBatch.ACTIONS, space=self.action_space, data_rel_pos=-1),
             SampleBatch.PREV_REWARDS: ViewRequirement(
-                SampleBatch.REWARDS, shift=-1),
+                SampleBatch.REWARDS, data_rel_pos=-1),
         }
         for i in range(2):
             self.model.inference_view_requirements["state_in_{}".format(i)] = \
                 ViewRequirement(
-                    "state_out_{}".format(i), shift=-1, space=self.state_space)
+                    "state_out_{}".format(i),
+                    data_rel_pos=-1,
+                    space=self.state_space)
             self.model.inference_view_requirements[
                 "state_out_{}".format(i)] = \
                 ViewRequirement(space=self.state_space)
@@ -43,7 +45,7 @@ class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
         self.view_requirements = dict(
             **{
                 SampleBatch.NEXT_OBS: ViewRequirement(
-                    SampleBatch.OBS, shift=1),
+                    SampleBatch.OBS, data_rel_pos=1),
                 SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
                 SampleBatch.REWARDS: ViewRequirement(),
                 SampleBatch.DONES: ViewRequirement(),
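Note: the state_in_i/state_out_i pairing above is the generic recurrent-view
pattern: the input state at step t is read from the output state one step
back. A hedged construction sketch (import path as of this commit):

    from ray.rllib.policy.view_requirement import ViewRequirement

    view_reqs = {}
    for i in range(2):
        # state_in_i at step t is served from state_out_i at step t-1.
        view_reqs["state_in_{}".format(i)] = ViewRequirement(
            "state_out_{}".format(i), data_rel_pos=-1)
        view_reqs["state_out_{}".format(i)] = ViewRequirement()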

@@ -16,7 +16,7 @@ class AlwaysSameHeuristic(Policy):
         self.view_requirements.update({
             "state_in_0": ViewRequirement(
                 "state_out_0",
-                shift=-1,
+                data_rel_pos=-1,
                 space=gym.spaces.Box(0, 100, shape=(), dtype=np.int32))
         })

@@ -61,7 +61,8 @@ class ModelV2:
         self.time_major = self.model_config.get("_time_major")
         # Basic view requirement for all models: Use the observation as input.
         self.inference_view_requirements = {
-            SampleBatch.OBS: ViewRequirement(shift=0, space=self.obs_space),
+            SampleBatch.OBS: ViewRequirement(
+                data_rel_pos=0, space=self.obs_space),
         }

         # TODO: (sven): Get rid of `get_initial_state` once Trajectory

@@ -178,10 +178,10 @@ class LSTMWrapper(RecurrentNetwork):
         if model_config["lstm_use_prev_action"]:
             self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
                 ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
-                                shift=-1)
+                                data_rel_pos=-1)
         if model_config["lstm_use_prev_reward"]:
             self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
-                ViewRequirement(SampleBatch.REWARDS, shift=-1)
+                ViewRequirement(SampleBatch.REWARDS, data_rel_pos=-1)

     @override(RecurrentNetwork)
     def forward(self, input_dict: Dict[str, TensorType],
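Note: the PREV_ACTIONS/PREV_REWARDS views are only added when the
corresponding model-config flags are set. A minimal config sketch (flag names
as read by the code above; values illustrative):

    config = {
        "model": {
            "use_lstm": True,
            "lstm_use_prev_action": True,   # feed a_{t-1} into the LSTM
            "lstm_use_prev_reward": True,   # feed r_{t-1} into the LSTM
        },
    }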

@@ -159,10 +159,10 @@ class LSTMWrapper(RecurrentNetwork, nn.Module):
         if model_config["lstm_use_prev_action"]:
             self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
                 ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
-                                shift=-1)
+                                data_rel_pos=-1)
         if model_config["lstm_use_prev_reward"]:
             self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
-                ViewRequirement(SampleBatch.REWARDS, shift=-1)
+                ViewRequirement(SampleBatch.REWARDS, data_rel_pos=-1)

     @override(RecurrentNetwork)
     def forward(self, input_dict: Dict[str, TensorType],

@@ -564,13 +564,14 @@ class Policy(metaclass=ABCMeta):
             SampleBatch.OBS: ViewRequirement(space=self.observation_space),
             SampleBatch.NEXT_OBS: ViewRequirement(
                 data_col=SampleBatch.OBS,
-                shift=1,
+                data_rel_pos=1,
                 space=self.observation_space),
             SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
             SampleBatch.REWARDS: ViewRequirement(),
             SampleBatch.DONES: ViewRequirement(),
             SampleBatch.INFOS: ViewRequirement(),
             SampleBatch.EPS_ID: ViewRequirement(),
+            SampleBatch.UNROLL_ID: ViewRequirement(),
             SampleBatch.AGENT_INDEX: ViewRequirement(),
             SampleBatch.UNROLL_ID: ViewRequirement(),
             "t": ViewRequirement(),
@@ -617,7 +618,7 @@ class Policy(metaclass=ABCMeta):
         batch_for_postproc.count = self._dummy_batch.count
         postprocessed_batch = self.postprocess_trajectory(batch_for_postproc)
         if state_outs:
-            B = 4  # For RNNs, have B=2, T=[depends on sample_batch_size]
+            B = 4  # For RNNs, have B=4, T=[depends on sample_batch_size]
             # TODO: (sven) This hack will not work for attention net traj.
             #  view setup.
             i = 0
@@ -657,7 +658,8 @@ class Policy(metaclass=ABCMeta):
             # Tag those only needed for post-processing.
             for key in batch_for_postproc.accessed_keys:
                 if key not in train_batch.accessed_keys and \
-                        key in self.view_requirements:
+                        key in self.view_requirements and \
+                        key not in self.model.inference_view_requirements:
                     self.view_requirements[key].used_for_training = False
             # Remove those not needed at all (leave those that are needed
             # by Sampler to properly execute sample collection).
@@ -680,18 +682,6 @@ class Policy(metaclass=ABCMeta):
                             "postprocessing function.".format(key))
                 else:
                     del self.view_requirements[key]
-        # Add those data_cols (again) that are missing and have
-        # dependencies by view_cols.
-        for key in list(self.view_requirements.keys()):
-            vr = self.view_requirements[key]
-            if vr.data_col is not None and \
-                    vr.data_col not in self.view_requirements:
-                used_for_training = \
-                    vr.data_col in train_batch.accessed_keys
-                self.view_requirements[vr.data_col] = \
-                    ViewRequirement(
-                        space=vr.space,
-                        used_for_training=used_for_training)

     def _get_dummy_batch_from_view_requirements(
             self, batch_size: int = 1) -> SampleBatch:
@@ -727,7 +717,7 @@ class Policy(metaclass=ABCMeta):
             model.inference_view_requirements["state_in_{}".format(i)] = \
                 ViewRequirement(
                     "state_out_{}".format(i),
-                    shift=-1,
+                    data_rel_pos=-1,
                     space=Box(-1.0, 1.0, shape=state.shape))
             model.inference_view_requirements["state_out_{}".format(i)] = \
                 ViewRequirement(space=Box(-1.0, 1.0, shape=state.shape))
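Note: the extra condition in the `used_for_training` hunk keeps
inference-critical columns from being dropped out of the train batch. A
plain-Python sketch of the resulting rule (the key sets are hypothetical):

    postproc_accessed = {"obs", "rewards", "infos", "state_in_0"}
    train_accessed = {"obs", "rewards"}
    inference_needed = {"obs", "state_in_0"}  # inference_view_requirements

    # A key is tagged used_for_training=False only if training never read it
    # AND the model does not need it to compute actions.
    tagged_off = {
        k for k in postproc_accessed
        if k not in train_accessed and k not in inference_needed
    }
    assert tagged_off == {"infos"}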

@@ -29,7 +29,7 @@ class ViewRequirement:
     def __init__(self,
                  data_col: Optional[str] = None,
                  space: gym.Space = None,
-                 shift: Union[int, List[int]] = 0,
+                 data_rel_pos: Union[int, List[int]] = 0,
                  used_for_training: bool = True):
         """Initializes a ViewRequirement object.
@@ -40,13 +40,14 @@ class ViewRequirement:
             space (gym.Space): The gym Space used in case we need to pad data
                 in inaccessible areas of the trajectory (t<0 or t>H).
                 Default: Simple box space, e.g. rewards.
-            shift (Union[int, List[int]]): Single shift value of list of
-                shift values to use relative to the underlying `data_col`.
+            data_rel_pos (Union[int, str, List[int]]): Single shift value or
+                list of relative positions to use (relative to the underlying
+                `data_col`).
                 Example: For a view column "prev_actions", you can set
-                `data_col="actions"` and `shift=-1`.
+                `data_col="actions"` and `data_rel_pos=-1`.
                 Example: For a view column "obs" in an Atari framestacking
                 fashion, you can set `data_col="obs"` and
-                `shift=[-3, -2, -1, 0]`.
+                `data_rel_pos=[-3, -2, -1, 0]`.
             used_for_training (bool): Whether the data will be used for
                 training. If False, the column will not be copied into the
                 final train batch.
@@ -54,5 +55,5 @@ class ViewRequirement:
         self.data_col = data_col
         self.space = space or gym.spaces.Box(
             float("-inf"), float("inf"), shape=())
-        self.shift = shift
+        self.data_rel_pos = data_rel_pos
         self.used_for_training = used_for_training
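Note: the two docstring examples, written out (a sketch; import paths per
RLlib's layout at this commit):

    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.policy.view_requirement import ViewRequirement

    view_reqs = {
        # "prev_actions": the ACTIONS column, shifted one step back.
        SampleBatch.PREV_ACTIONS: ViewRequirement(
            data_col=SampleBatch.ACTIONS, data_rel_pos=-1),
        # Atari-style framestacking: the last three obs plus the current one.
        "stacked_obs": ViewRequirement(
            data_col=SampleBatch.OBS, data_rel_pos=[-3, -2, -1, 0]),
    }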