Mirror of https://github.com/vale981/ray (synced 2025-03-06 10:31:39 -05:00)
Parent: e72147de38
Commit: 3ad9365e1d
12 changed files with 68 additions and 55 deletions
@@ -31,7 +31,8 @@ class _SampleCollector(metaclass=ABCMeta):
 
     @abstractmethod
     def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
-                     policy_id: PolicyID, init_obs: TensorType) -> None:
+                     policy_id: PolicyID, t: int,
+                     init_obs: TensorType) -> None:
         """Adds an initial obs (after reset) to this collector.
 
         Since the very first observation in an environment is collected w/o
@@ -48,6 +49,8 @@ class _SampleCollector(metaclass=ABCMeta):
                 values for.
             env_id (EnvID): The environment index (in a vectorized setup).
             policy_id (PolicyID): Unique id for policy controlling the agent.
+            t (int): The time step (episode length - 1). The initial obs has
+                ts=-1(!), then an action/reward/next-obs at t=0, etc..
             init_obs (TensorType): Initial observation (after env.reset()).
 
         Examples:
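
To make the new `t` convention concrete, here is a small hypothetical timeline (illustrative only, not part of the commit):

    # env.reset()   -> init_obs recorded at t = -1 (episode.length == 0)
    # env.step(a0)  -> action a0, reward r0, next obs recorded at t = 0
    # env.step(a1)  -> action a1, reward r1, next obs recorded at t = 1
    # The caller therefore passes t = episode.length - 1 for the initial obs.
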
@@ -172,9 +175,10 @@ class _SampleCollector(metaclass=ABCMeta):
                 MultiAgentBatch. Used for batch_mode=`complete_episodes`.
 
         Returns:
-            Any: An ID that can be used in `build_multi_agent_batch` to
-                retrieve the samples that have been postprocessed as a
-                ready-built MultiAgentBatch.
+            Optional[MultiAgentBatch]: If `build` is True, the
+                SampleBatch or MultiAgentBatch built from `episode` (either
+                just from that episode or from the `_PolicyCollectorGroup`
+                in the `episode.batch_builder` property).
         """
         raise NotImplementedError
 
@@ -52,17 +52,19 @@ class _AgentCollector:
         # each time a (non-initial!) observation is added.
         self.count = 0
 
-    def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID,
-                     env_id: EnvID, init_obs: TensorType,
+    def add_init_obs(self, episode_id: EpisodeID, agent_index: int,
+                     env_id: EnvID, t: int, init_obs: TensorType,
                      view_requirements: Dict[str, ViewRequirement]) -> None:
         """Adds an initial observation (after reset) to the Agent's trajectory.
 
         Args:
             episode_id (EpisodeID): Unique ID for the episode we are adding the
                 initial observation for.
-            agent_id (AgentID): Unique ID for the agent we are adding the
-                initial observation for.
+            agent_index (int): Unique int index (starting from 0) for the agent
+                within its episode.
             env_id (EnvID): The environment index (in a vectorized setup).
+            t (int): The time step (episode length - 1). The initial obs has
+                ts=-1(!), then an action/reward/next-obs at t=0, etc..
             init_obs (TensorType): The initial observation tensor (after
                 `env.reset()`).
             view_requirements (Dict[str, ViewRequirements])
@@ -72,10 +74,15 @@ class _AgentCollector:
                 single_row={
                     SampleBatch.OBS: init_obs,
                     SampleBatch.EPS_ID: episode_id,
-                    SampleBatch.AGENT_INDEX: agent_id,
+                    SampleBatch.AGENT_INDEX: agent_index,
                     "env_id": env_id,
+                    "t": t,
                 })
         self.buffers[SampleBatch.OBS].append(init_obs)
+        self.buffers[SampleBatch.EPS_ID].append(episode_id)
+        self.buffers[SampleBatch.AGENT_INDEX].append(agent_index)
+        self.buffers["env_id"].append(env_id)
+        self.buffers["t"].append(t)
 
     def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \
             None:
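
For orientation, a rough sketch of the per-agent buffers right after this call (values invented for illustration, and ignoring any `shift_before` padding that `_build_buffers` may prepend):

    # After add_init_obs(episode_id=123, agent_index=0, env_id=0, t=-1, init_obs=obs0):
    #   self.buffers[SampleBatch.OBS]         == [obs0]
    #   self.buffers[SampleBatch.EPS_ID]      == [123]
    #   self.buffers[SampleBatch.AGENT_INDEX] == [0]
    #   self.buffers["env_id"]                == [0]
    #   self.buffers["t"]                     == [-1]
    # Each later add_action_reward_next_obs() call then appends one entry per column.
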
@@ -133,7 +140,7 @@ class _AgentCollector:
                 continue
             # OBS are already shifted by -1 (the initial obs starts one ts
             # before all other data columns).
-            shift = view_req.shift - \
+            shift = view_req.data_rel_pos - \
                 (1 if data_col == SampleBatch.OBS else 0)
             if data_col not in np_data:
                 np_data[data_col] = to_float_np_array(self.buffers[data_col])
@@ -187,7 +194,10 @@ class _AgentCollector:
         for col, data in single_row.items():
             if col in self.buffers:
                 continue
-            shift = self.shift_before - (1 if col == SampleBatch.OBS else 0)
+            shift = self.shift_before - (1 if col in [
+                SampleBatch.OBS, SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
+                "env_id", "t"
+            ] else 0)
             # Python primitive or dict (e.g. INFOs).
             if isinstance(data, (int, float, bool, str, dict)):
                 self.buffers[col] = [0 for _ in range(shift)]
@@ -360,7 +370,7 @@ class _SimpleListCollector(_SampleCollector):
 
     @override(_SampleCollector)
     def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
-                     env_id: EnvID, policy_id: PolicyID,
+                     env_id: EnvID, policy_id: PolicyID, t: int,
                      init_obs: TensorType) -> None:
         # Make sure our mappings are up to date.
         agent_key = (episode.episode_id, agent_id)
@@ -378,8 +388,9 @@ class _SimpleListCollector(_SampleCollector):
             self.agent_collectors[agent_key] = _AgentCollector()
         self.agent_collectors[agent_key].add_init_obs(
             episode_id=episode.episode_id,
-            agent_id=agent_id,
+            agent_index=episode._agent_index(agent_id),
             env_id=env_id,
+            t=t,
             init_obs=init_obs,
             view_requirements=view_reqs)
 
@@ -429,7 +440,7 @@ class _SimpleListCollector(_SampleCollector):
             # Create the batch of data from the different buffers.
             data_col = view_req.data_col or view_col
             time_indices = \
-                view_req.shift - (
+                view_req.data_rel_pos - (
                     1 if data_col in [SampleBatch.OBS, "t", "env_id",
                                       SampleBatch.EPS_ID,
                                       SampleBatch.AGENT_INDEX] else 0)
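
A hand-computed example of the index arithmetic above (assuming `data_rel_pos == 0` and `data_col == SampleBatch.OBS`):

    # time_indices = 0 - 1 == -1  -> the newest entry of the obs buffer,
    # matching the "OBS is shifted by -1" convention noted in _AgentCollector.
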
@@ -272,9 +272,11 @@ class RolloutWorker(ParallelIteratorWorker):
             output_creator (Callable[[IOContext], OutputWriter]): Function that
                 returns an OutputWriter object for saving generated
                 experiences.
-            remote_worker_envs (bool): If using num_envs > 1, whether to create
-                those new envs in remote processes instead of in the current
-                process. This adds overheads, but can make sense if your envs
+            remote_worker_envs (bool): If using num_envs_per_worker > 1,
+                whether to create those new envs in remote processes instead of
+                in the current process. This adds overheads, but can make sense
+                if your envs are expensive to step/reset (e.g., for StarCraft).
+                Use this cautiously, overheads are significant!
             remote_env_batch_wait_ms (float): Timeout that remote workers
                 are waiting when polling environments. 0 (continue when at
                 least one env is ready) is a reasonable default, but optimal
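
For context, a minimal config sketch exercising the option described above (hand-written, not part of the diff; key names follow the standard trainer config):

    # Step several env copies per worker, each in its own remote process.
    # Only worthwhile when env.step()/env.reset() is expensive.
    config = {
        "num_envs_per_worker": 4,
        "remote_worker_envs": True,
        "remote_env_batch_wait_ms": 0,  # continue as soon as one env is ready
    }
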
@@ -1040,7 +1040,8 @@ def _process_observations_w_trajectory_view_api(
         # Record transition info if applicable.
         if last_observation is None:
             _sample_collector.add_init_obs(episode, agent_id, env_id,
-                                           policy_id, filtered_obs)
+                                           policy_id, episode.length - 1,
+                                           filtered_obs)
         else:
             # Add actions, rewards, next-obs to collectors.
             values_dict = {
@@ -1158,7 +1159,8 @@ def _process_observations_w_trajectory_view_api(
 
                 # Add initial obs to buffer.
                 _sample_collector.add_init_obs(
-                    new_episode, agent_id, env_id, policy_id, filtered_obs)
+                    new_episode, agent_id, env_id, policy_id,
+                    new_episode.length - 1, filtered_obs)
 
                 item = PolicyEvalData(
                     env_id, agent_id, filtered_obs,
@@ -59,7 +59,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
                 assert view_req_policy[key].data_col is None
             else:
                 assert view_req_policy[key].data_col == SampleBatch.OBS
-                assert view_req_policy[key].shift == 1
+                assert view_req_policy[key].data_rel_pos == 1
         rollout_worker = trainer.workers.local_worker()
         sample_batch = rollout_worker.sample()
         expected_count = \
@@ -99,10 +99,10 @@ class TestTrajectoryViewAPI(unittest.TestCase):
 
             if key == SampleBatch.PREV_ACTIONS:
                 assert view_req_policy[key].data_col == SampleBatch.ACTIONS
-                assert view_req_policy[key].shift == -1
+                assert view_req_policy[key].data_rel_pos == -1
             elif key == SampleBatch.PREV_REWARDS:
                 assert view_req_policy[key].data_col == SampleBatch.REWARDS
-                assert view_req_policy[key].shift == -1
+                assert view_req_policy[key].data_rel_pos == -1
             elif key not in [
                     SampleBatch.NEXT_OBS, SampleBatch.PREV_ACTIONS,
                     SampleBatch.PREV_REWARDS
@@ -110,7 +110,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
                 assert view_req_policy[key].data_col is None
             else:
                 assert view_req_policy[key].data_col == SampleBatch.OBS
-                assert view_req_policy[key].shift == 1
+                assert view_req_policy[key].data_rel_pos == 1
         trainer.stop()
 
     def test_traj_view_simple_performance(self):
@@ -28,14 +28,16 @@ class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
             "t": ViewRequirement(),
             SampleBatch.OBS: ViewRequirement(),
             SampleBatch.PREV_ACTIONS: ViewRequirement(
-                SampleBatch.ACTIONS, space=self.action_space, shift=-1),
+                SampleBatch.ACTIONS, space=self.action_space, data_rel_pos=-1),
             SampleBatch.PREV_REWARDS: ViewRequirement(
-                SampleBatch.REWARDS, shift=-1),
+                SampleBatch.REWARDS, data_rel_pos=-1),
         }
         for i in range(2):
             self.model.inference_view_requirements["state_in_{}".format(i)] = \
                 ViewRequirement(
-                    "state_out_{}".format(i), shift=-1, space=self.state_space)
+                    "state_out_{}".format(i),
+                    data_rel_pos=-1,
+                    space=self.state_space)
             self.model.inference_view_requirements[
                 "state_out_{}".format(i)] = \
                 ViewRequirement(space=self.state_space)
@@ -43,7 +45,7 @@ class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
         self.view_requirements = dict(
             **{
                 SampleBatch.NEXT_OBS: ViewRequirement(
-                    SampleBatch.OBS, shift=1),
+                    SampleBatch.OBS, data_rel_pos=1),
                 SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
                 SampleBatch.REWARDS: ViewRequirement(),
                 SampleBatch.DONES: ViewRequirement(),
@@ -16,7 +16,7 @@ class AlwaysSameHeuristic(Policy):
         self.view_requirements.update({
             "state_in_0": ViewRequirement(
                 "state_out_0",
-                shift=-1,
+                data_rel_pos=-1,
                 space=gym.spaces.Box(0, 100, shape=(), dtype=np.int32))
         })
 
@@ -61,7 +61,8 @@ class ModelV2:
         self.time_major = self.model_config.get("_time_major")
         # Basic view requirement for all models: Use the observation as input.
         self.inference_view_requirements = {
-            SampleBatch.OBS: ViewRequirement(shift=0, space=self.obs_space),
+            SampleBatch.OBS: ViewRequirement(
+                data_rel_pos=0, space=self.obs_space),
         }
 
         # TODO: (sven): Get rid of `get_initial_state` once Trajectory
@@ -178,10 +178,10 @@ class LSTMWrapper(RecurrentNetwork):
         if model_config["lstm_use_prev_action"]:
             self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
                 ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
-                                shift=-1)
+                                data_rel_pos=-1)
         if model_config["lstm_use_prev_reward"]:
             self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
-                ViewRequirement(SampleBatch.REWARDS, shift=-1)
+                ViewRequirement(SampleBatch.REWARDS, data_rel_pos=-1)
 
     @override(RecurrentNetwork)
     def forward(self, input_dict: Dict[str, TensorType],
@@ -159,10 +159,10 @@ class LSTMWrapper(RecurrentNetwork, nn.Module):
         if model_config["lstm_use_prev_action"]:
             self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
                 ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
-                                shift=-1)
+                                data_rel_pos=-1)
         if model_config["lstm_use_prev_reward"]:
             self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
-                ViewRequirement(SampleBatch.REWARDS, shift=-1)
+                ViewRequirement(SampleBatch.REWARDS, data_rel_pos=-1)
 
     @override(RecurrentNetwork)
     def forward(self, input_dict: Dict[str, TensorType],
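
Both the TF and Torch LSTM wrappers key off the same model-config flags; a minimal sketch of a config that would activate these view requirements (hand-written, flag names taken from the code above):

    config = {
        "model": {
            "use_lstm": True,
            "lstm_use_prev_action": True,   # adds the PREV_ACTIONS requirement
            "lstm_use_prev_reward": True,   # adds the PREV_REWARDS requirement
        },
    }
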
@@ -564,13 +564,14 @@ class Policy(metaclass=ABCMeta):
             SampleBatch.OBS: ViewRequirement(space=self.observation_space),
             SampleBatch.NEXT_OBS: ViewRequirement(
                 data_col=SampleBatch.OBS,
-                shift=1,
+                data_rel_pos=1,
                 space=self.observation_space),
             SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
             SampleBatch.REWARDS: ViewRequirement(),
             SampleBatch.DONES: ViewRequirement(),
             SampleBatch.INFOS: ViewRequirement(),
             SampleBatch.EPS_ID: ViewRequirement(),
+            SampleBatch.UNROLL_ID: ViewRequirement(),
             SampleBatch.AGENT_INDEX: ViewRequirement(),
             SampleBatch.UNROLL_ID: ViewRequirement(),
             "t": ViewRequirement(),
@@ -617,7 +618,7 @@ class Policy(metaclass=ABCMeta):
         batch_for_postproc.count = self._dummy_batch.count
         postprocessed_batch = self.postprocess_trajectory(batch_for_postproc)
         if state_outs:
-            B = 4  # For RNNs, have B=2, T=[depends on sample_batch_size]
+            B = 4  # For RNNs, have B=4, T=[depends on sample_batch_size]
             # TODO: (sven) This hack will not work for attention net traj.
             # view setup.
             i = 0
@@ -657,7 +658,8 @@ class Policy(metaclass=ABCMeta):
         # Tag those only needed for post-processing.
         for key in batch_for_postproc.accessed_keys:
             if key not in train_batch.accessed_keys and \
-                    key in self.view_requirements:
+                    key in self.view_requirements and \
+                    key not in self.model.inference_view_requirements:
                 self.view_requirements[key].used_for_training = False
         # Remove those not needed at all (leave those that are needed
         # by Sampler to properly execute sample collection).
@@ -680,18 +682,6 @@
                             "postprocessing function.".format(key))
                 else:
                     del self.view_requirements[key]
-        # Add those data_cols (again) that are missing and have
-        # dependencies by view_cols.
-        for key in list(self.view_requirements.keys()):
-            vr = self.view_requirements[key]
-            if vr.data_col is not None and \
-                    vr.data_col not in self.view_requirements:
-                used_for_training = \
-                    vr.data_col in train_batch.accessed_keys
-                self.view_requirements[vr.data_col] = \
-                    ViewRequirement(
-                        space=vr.space,
-                        used_for_training=used_for_training)
 
     def _get_dummy_batch_from_view_requirements(
             self, batch_size: int = 1) -> SampleBatch:
@@ -727,7 +717,7 @@
             model.inference_view_requirements["state_in_{}".format(i)] = \
                 ViewRequirement(
                     "state_out_{}".format(i),
-                    shift=-1,
+                    data_rel_pos=-1,
                     space=Box(-1.0, 1.0, shape=state.shape))
             model.inference_view_requirements["state_out_{}".format(i)] = \
                 ViewRequirement(space=Box(-1.0, 1.0, shape=state.shape))
@@ -29,7 +29,7 @@ class ViewRequirement:
     def __init__(self,
                  data_col: Optional[str] = None,
                  space: gym.Space = None,
-                 shift: Union[int, List[int]] = 0,
+                 data_rel_pos: Union[int, List[int]] = 0,
                  used_for_training: bool = True):
         """Initializes a ViewRequirement object.
 
@@ -40,13 +40,14 @@ class ViewRequirement:
             space (gym.Space): The gym Space used in case we need to pad data
                 in inaccessible areas of the trajectory (t<0 or t>H).
                 Default: Simple box space, e.g. rewards.
-            shift (Union[int, List[int]]): Single shift value of list of
-                shift values to use relative to the underlying `data_col`.
+            data_rel_pos (Union[int, str, List[int]]): Single shift value or
+                list of relative positions to use (relative to the underlying
+                `data_col`).
                 Example: For a view column "prev_actions", you can set
-                `data_col="actions"` and `shift=-1`.
+                `data_col="actions"` and `data_rel_pos=-1`.
                 Example: For a view column "obs" in an Atari framestacking
                 fashion, you can set `data_col="obs"` and
-                `shift=[-3, -2, -1, 0]`.
+                `data_rel_pos=[-3, -2, -1, 0]`.
             used_for_training (bool): Whether the data will be used for
                 training. If False, the column will not be copied into the
                 final train batch.
@@ -54,5 +55,5 @@ class ViewRequirement:
         self.data_col = data_col
         self.space = space or gym.spaces.Box(
             float("-inf"), float("inf"), shape=())
-        self.shift = shift
+        self.data_rel_pos = data_rel_pos
         self.used_for_training = used_for_training
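
Finally, a short usage sketch of the renamed parameter (hand-written for illustration; import paths assumed from the RLlib source tree):

    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.policy.view_requirement import ViewRequirement

    view_reqs = {
        # "prev_actions" looks one step back into the "actions" column.
        SampleBatch.PREV_ACTIONS: ViewRequirement(
            data_col=SampleBatch.ACTIONS, data_rel_pos=-1),
        # Atari-style framestacking: the current obs plus the three before it.
        SampleBatch.OBS: ViewRequirement(
            data_col=SampleBatch.OBS, data_rel_pos=[-3, -2, -1, 0]),
    }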