Mirror of https://github.com/vale981/ray (synced 2025-03-06 10:31:39 -05:00)
parent e72147de38
commit 3ad9365e1d
12 changed files with 68 additions and 55 deletions
@@ -31,7 +31,8 @@ class _SampleCollector(metaclass=ABCMeta):

     @abstractmethod
     def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
-                     policy_id: PolicyID, init_obs: TensorType) -> None:
+                     policy_id: PolicyID, t: int,
+                     init_obs: TensorType) -> None:
         """Adds an initial obs (after reset) to this collector.

         Since the very first observation in an environment is collected w/o
@@ -48,6 +49,8 @@ class _SampleCollector(metaclass=ABCMeta):
                 values for.
             env_id (EnvID): The environment index (in a vectorized setup).
             policy_id (PolicyID): Unique id for policy controlling the agent.
+            t (int): The time step (episode length - 1). The initial obs has
+                ts=-1(!), then an action/reward/next-obs at t=0, etc..
             init_obs (TensorType): Initial observation (after env.reset()).

         Examples:
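The `t` convention documented above places the reset observation one step before the first action/reward. A minimal sketch of the indexing (plain Python, not RLlib code; names are illustrative only):

    # Toy illustration of the t=-1 convention: the reset obs is stored at
    # t=-1, the first sampled step at t=0.
    timesteps = []
    episode_length = 0                    # right after env.reset()
    timesteps.append(episode_length - 1)  # initial obs -> t = -1
    for _ in range(3):                    # three env steps
        timesteps.append(episode_length)  # action/reward/next-obs at this t
        episode_length += 1
    print(timesteps)  # [-1, 0, 1, 2]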
@@ -172,9 +175,10 @@ class _SampleCollector(metaclass=ABCMeta):
             MultiAgentBatch. Used for batch_mode=`complete_episodes`.

         Returns:
-            Any: An ID that can be used in `build_multi_agent_batch` to
-                retrieve the samples that have been postprocessed as a
-                ready-built MultiAgentBatch.
+            Optional[MultiAgentBatch]: If `build` is True, the
+                SampleBatch or MultiAgentBatch built from `episode` (either
+                just from that episode or from the `_PolicyCollectorGroup`
+                in the `episode.batch_builder` property).
         """
         raise NotImplementedError
@@ -52,17 +52,19 @@ class _AgentCollector:
         # each time a (non-initial!) observation is added.
         self.count = 0

-    def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID,
-                     env_id: EnvID, init_obs: TensorType,
+    def add_init_obs(self, episode_id: EpisodeID, agent_index: int,
+                     env_id: EnvID, t: int, init_obs: TensorType,
                      view_requirements: Dict[str, ViewRequirement]) -> None:
         """Adds an initial observation (after reset) to the Agent's trajectory.

         Args:
             episode_id (EpisodeID): Unique ID for the episode we are adding the
                 initial observation for.
-            agent_id (AgentID): Unique ID for the agent we are adding the
-                initial observation for.
+            agent_index (int): Unique int index (starting from 0) for the agent
+                within its episode.
             env_id (EnvID): The environment index (in a vectorized setup).
+            t (int): The time step (episode length - 1). The initial obs has
+                ts=-1(!), then an action/reward/next-obs at t=0, etc..
             init_obs (TensorType): The initial observation tensor (after
                 `env.reset()`).
             view_requirements (Dict[str, ViewRequirements])
@@ -72,10 +74,15 @@ class _AgentCollector:
             single_row={
                 SampleBatch.OBS: init_obs,
                 SampleBatch.EPS_ID: episode_id,
-                SampleBatch.AGENT_INDEX: agent_id,
+                SampleBatch.AGENT_INDEX: agent_index,
                 "env_id": env_id,
+                "t": t,
             })
         self.buffers[SampleBatch.OBS].append(init_obs)
+        self.buffers[SampleBatch.EPS_ID].append(episode_id)
+        self.buffers[SampleBatch.AGENT_INDEX].append(agent_index)
+        self.buffers["env_id"].append(env_id)
+        self.buffers["t"].append(t)

     def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \
             None:
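With this change every per-agent buffer carries the episode ID, agent index, env ID, and time step from the reset observation onward. A toy stand-in (plain Python, not the actual `_AgentCollector`) for how those columns line up:

    # Toy buffers mimicking the columns appended above (illustrative only).
    buffers = {"obs": [], "eps_id": [], "agent_index": [], "env_id": [], "t": []}

    def add_init_obs(episode_id, agent_index, env_id, t, init_obs):
        # The reset obs is stored together with its bookkeeping columns.
        buffers["obs"].append(init_obs)
        buffers["eps_id"].append(episode_id)
        buffers["agent_index"].append(agent_index)
        buffers["env_id"].append(env_id)
        buffers["t"].append(t)

    add_init_obs(episode_id=123, agent_index=0, env_id=0, t=-1, init_obs=[0.0, 0.0])
    print(buffers["t"])  # [-1]: the initial obs sits one step before t=0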
@@ -133,7 +140,7 @@ class _AgentCollector:
                 continue
             # OBS are already shifted by -1 (the initial obs starts one ts
             # before all other data columns).
-            shift = view_req.shift - \
+            shift = view_req.data_rel_pos - \
                 (1 if data_col == SampleBatch.OBS else 0)
             if data_col not in np_data:
                 np_data[data_col] = to_float_np_array(self.buffers[data_col])
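To make the off-by-one above concrete: the obs buffer starts one step early (it holds the reset observation), so any view reading from the obs column gets one extra step subtracted. A small worked example of that arithmetic (plain Python mirroring the expression in the diff, not a call into RLlib):

    # Worked example of the shift computation shown above (illustrative).
    def effective_shift(data_rel_pos, data_col):
        # Views onto "obs" subtract one extra step because the obs buffer
        # already contains the reset observation at the front.
        return data_rel_pos - (1 if data_col == "obs" else 0)

    print(effective_shift(+1, "obs"))      # NEXT_OBS:      1 - 1 = 0
    print(effective_shift(-1, "actions"))  # PREV_ACTIONS: -1 - 0 = -1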
@@ -187,7 +194,10 @@ class _AgentCollector:
         for col, data in single_row.items():
             if col in self.buffers:
                 continue
-            shift = self.shift_before - (1 if col == SampleBatch.OBS else 0)
+            shift = self.shift_before - (1 if col in [
+                SampleBatch.OBS, SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
+                "env_id", "t"
+            ] else 0)
             # Python primitive or dict (e.g. INFOs).
             if isinstance(data, (int, float, bool, str, dict)):
                 self.buffers[col] = [0 for _ in range(shift)]
@@ -360,7 +370,7 @@ class _SimpleListCollector(_SampleCollector):

     @override(_SampleCollector)
     def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
-                     env_id: EnvID, policy_id: PolicyID,
+                     env_id: EnvID, policy_id: PolicyID, t: int,
                      init_obs: TensorType) -> None:
         # Make sure our mappings are up to date.
         agent_key = (episode.episode_id, agent_id)
@@ -378,8 +388,9 @@ class _SimpleListCollector(_SampleCollector):
             self.agent_collectors[agent_key] = _AgentCollector()
         self.agent_collectors[agent_key].add_init_obs(
             episode_id=episode.episode_id,
-            agent_id=agent_id,
+            agent_index=episode._agent_index(agent_id),
             env_id=env_id,
+            t=t,
             init_obs=init_obs,
             view_requirements=view_reqs)

@@ -429,7 +440,7 @@ class _SimpleListCollector(_SampleCollector):
             # Create the batch of data from the different buffers.
             data_col = view_req.data_col or view_col
             time_indices = \
-                view_req.shift - (
+                view_req.data_rel_pos - (
                     1 if data_col in [SampleBatch.OBS, "t", "env_id",
                                       SampleBatch.EPS_ID,
                                       SampleBatch.AGENT_INDEX] else 0)
@@ -272,9 +272,11 @@ class RolloutWorker(ParallelIteratorWorker):
             output_creator (Callable[[IOContext], OutputWriter]): Function that
                 returns an OutputWriter object for saving generated
                 experiences.
-            remote_worker_envs (bool): If using num_envs > 1, whether to create
-                those new envs in remote processes instead of in the current
-                process. This adds overheads, but can make sense if your envs
+            remote_worker_envs (bool): If using num_envs_per_worker > 1,
+                whether to create those new envs in remote processes instead of
+                in the current process. This adds overheads, but can make sense
+                if your envs are expensive to step/reset (e.g., for StarCraft).
+                Use this cautiously, overheads are significant!
             remote_env_batch_wait_ms (float): Timeout that remote workers
                 are waiting when polling environments. 0 (continue when at
                 least one env is ready) is a reasonable default, but optimal
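For context on the corrected docstring: `remote_worker_envs` only matters when `num_envs_per_worker` is greater than 1. A hedged sketch of how these options would typically sit together in a trainer config dict (values are placeholders, not recommendations):

    # Illustrative config snippet; values are placeholders.
    config = {
        "num_workers": 2,
        "num_envs_per_worker": 4,       # vectorize 4 envs per rollout worker
        "remote_worker_envs": True,     # step/reset those envs in remote processes
        "remote_env_batch_wait_ms": 0,  # 0: continue as soon as one env is ready
    }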
@@ -1040,7 +1040,8 @@ def _process_observations_w_trajectory_view_api(
             # Record transition info if applicable.
             if last_observation is None:
                 _sample_collector.add_init_obs(episode, agent_id, env_id,
-                                               policy_id, filtered_obs)
+                                               policy_id, episode.length - 1,
+                                               filtered_obs)
             else:
                 # Add actions, rewards, next-obs to collectors.
                 values_dict = {
@@ -1158,7 +1159,8 @@ def _process_observations_w_trajectory_view_api(

             # Add initial obs to buffer.
             _sample_collector.add_init_obs(
-                new_episode, agent_id, env_id, policy_id, filtered_obs)
+                new_episode, agent_id, env_id, policy_id,
+                new_episode.length - 1, filtered_obs)

             item = PolicyEvalData(
                 env_id, agent_id, filtered_obs,
@@ -59,7 +59,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
                 assert view_req_policy[key].data_col is None
             else:
                 assert view_req_policy[key].data_col == SampleBatch.OBS
-                assert view_req_policy[key].shift == 1
+                assert view_req_policy[key].data_rel_pos == 1
         rollout_worker = trainer.workers.local_worker()
         sample_batch = rollout_worker.sample()
         expected_count = \
@@ -99,10 +99,10 @@ class TestTrajectoryViewAPI(unittest.TestCase):

             if key == SampleBatch.PREV_ACTIONS:
                 assert view_req_policy[key].data_col == SampleBatch.ACTIONS
-                assert view_req_policy[key].shift == -1
+                assert view_req_policy[key].data_rel_pos == -1
             elif key == SampleBatch.PREV_REWARDS:
                 assert view_req_policy[key].data_col == SampleBatch.REWARDS
-                assert view_req_policy[key].shift == -1
+                assert view_req_policy[key].data_rel_pos == -1
             elif key not in [
                     SampleBatch.NEXT_OBS, SampleBatch.PREV_ACTIONS,
                     SampleBatch.PREV_REWARDS
@@ -110,7 +110,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
                 assert view_req_policy[key].data_col is None
             else:
                 assert view_req_policy[key].data_col == SampleBatch.OBS
-                assert view_req_policy[key].shift == 1
+                assert view_req_policy[key].data_rel_pos == 1
         trainer.stop()

     def test_traj_view_simple_performance(self):
@@ -28,14 +28,16 @@ class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
             "t": ViewRequirement(),
             SampleBatch.OBS: ViewRequirement(),
             SampleBatch.PREV_ACTIONS: ViewRequirement(
-                SampleBatch.ACTIONS, space=self.action_space, shift=-1),
+                SampleBatch.ACTIONS, space=self.action_space, data_rel_pos=-1),
             SampleBatch.PREV_REWARDS: ViewRequirement(
-                SampleBatch.REWARDS, shift=-1),
+                SampleBatch.REWARDS, data_rel_pos=-1),
         }
         for i in range(2):
             self.model.inference_view_requirements["state_in_{}".format(i)] = \
                 ViewRequirement(
-                    "state_out_{}".format(i), shift=-1, space=self.state_space)
+                    "state_out_{}".format(i),
+                    data_rel_pos=-1,
+                    space=self.state_space)
             self.model.inference_view_requirements[
                 "state_out_{}".format(i)] = \
                 ViewRequirement(space=self.state_space)
@@ -43,7 +45,7 @@ class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
         self.view_requirements = dict(
             **{
                 SampleBatch.NEXT_OBS: ViewRequirement(
-                    SampleBatch.OBS, shift=1),
+                    SampleBatch.OBS, data_rel_pos=1),
                 SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
                 SampleBatch.REWARDS: ViewRequirement(),
                 SampleBatch.DONES: ViewRequirement(),
@@ -16,7 +16,7 @@ class AlwaysSameHeuristic(Policy):
         self.view_requirements.update({
             "state_in_0": ViewRequirement(
                 "state_out_0",
-                shift=-1,
+                data_rel_pos=-1,
                 space=gym.spaces.Box(0, 100, shape=(), dtype=np.int32))
         })

@@ -61,7 +61,8 @@ class ModelV2:
         self.time_major = self.model_config.get("_time_major")
         # Basic view requirement for all models: Use the observation as input.
         self.inference_view_requirements = {
-            SampleBatch.OBS: ViewRequirement(shift=0, space=self.obs_space),
+            SampleBatch.OBS: ViewRequirement(
+                data_rel_pos=0, space=self.obs_space),
         }

         # TODO: (sven): Get rid of `get_initial_state` once Trajectory
@@ -178,10 +178,10 @@ class LSTMWrapper(RecurrentNetwork):
         if model_config["lstm_use_prev_action"]:
             self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
                 ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
-                                shift=-1)
+                                data_rel_pos=-1)
         if model_config["lstm_use_prev_reward"]:
             self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
-                ViewRequirement(SampleBatch.REWARDS, shift=-1)
+                ViewRequirement(SampleBatch.REWARDS, data_rel_pos=-1)

     @override(RecurrentNetwork)
     def forward(self, input_dict: Dict[str, TensorType],
@@ -159,10 +159,10 @@ class LSTMWrapper(RecurrentNetwork, nn.Module):
         if model_config["lstm_use_prev_action"]:
             self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
                 ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
-                                shift=-1)
+                                data_rel_pos=-1)
         if model_config["lstm_use_prev_reward"]:
             self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
-                ViewRequirement(SampleBatch.REWARDS, shift=-1)
+                ViewRequirement(SampleBatch.REWARDS, data_rel_pos=-1)

     @override(RecurrentNetwork)
     def forward(self, input_dict: Dict[str, TensorType],
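Both LSTM wrappers above (the TF and the Torch variant) pick up the renamed argument. The prev-action/prev-reward requirements are only installed when the corresponding model-config flags are set; a hedged sketch of such a model config (only the keys referenced in these hunks, values illustrative):

    # Illustrative model config; only the flags referenced above matter here.
    model_config = {
        "use_lstm": True,
        "lstm_use_prev_action": True,   # installs PREV_ACTIONS with data_rel_pos=-1
        "lstm_use_prev_reward": True,   # installs PREV_REWARDS with data_rel_pos=-1
    }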
@@ -564,13 +564,14 @@ class Policy(metaclass=ABCMeta):
             SampleBatch.OBS: ViewRequirement(space=self.observation_space),
             SampleBatch.NEXT_OBS: ViewRequirement(
                 data_col=SampleBatch.OBS,
-                shift=1,
+                data_rel_pos=1,
                 space=self.observation_space),
             SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
             SampleBatch.REWARDS: ViewRequirement(),
             SampleBatch.DONES: ViewRequirement(),
             SampleBatch.INFOS: ViewRequirement(),
             SampleBatch.EPS_ID: ViewRequirement(),
-            SampleBatch.UNROLL_ID: ViewRequirement(),
             SampleBatch.AGENT_INDEX: ViewRequirement(),
+            SampleBatch.UNROLL_ID: ViewRequirement(),
+            "t": ViewRequirement(),
@@ -617,7 +618,7 @@ class Policy(metaclass=ABCMeta):
         batch_for_postproc.count = self._dummy_batch.count
         postprocessed_batch = self.postprocess_trajectory(batch_for_postproc)
         if state_outs:
-            B = 4  # For RNNs, have B=2, T=[depends on sample_batch_size]
+            B = 4  # For RNNs, have B=4, T=[depends on sample_batch_size]
             # TODO: (sven) This hack will not work for attention net traj.
             # view setup.
             i = 0
@@ -657,7 +658,8 @@ class Policy(metaclass=ABCMeta):
         # Tag those only needed for post-processing.
         for key in batch_for_postproc.accessed_keys:
             if key not in train_batch.accessed_keys and \
-                    key in self.view_requirements:
+                    key in self.view_requirements and \
+                    key not in self.model.inference_view_requirements:
                 self.view_requirements[key].used_for_training = False
         # Remove those not needed at all (leave those that are needed
         # by Sampler to properly execute sample collection).
@@ -680,18 +682,6 @@ class Policy(metaclass=ABCMeta):
                             "postprocessing function.".format(key))
                 else:
                     del self.view_requirements[key]
-        # Add those data_cols (again) that are missing and have
-        # dependencies by view_cols.
-        for key in list(self.view_requirements.keys()):
-            vr = self.view_requirements[key]
-            if vr.data_col is not None and \
-                    vr.data_col not in self.view_requirements:
-                used_for_training = \
-                    vr.data_col in train_batch.accessed_keys
-                self.view_requirements[vr.data_col] = \
-                    ViewRequirement(
-                        space=vr.space,
-                        used_for_training=used_for_training)

     def _get_dummy_batch_from_view_requirements(
             self, batch_size: int = 1) -> SampleBatch:
@@ -727,7 +717,7 @@ class Policy(metaclass=ABCMeta):
             model.inference_view_requirements["state_in_{}".format(i)] = \
                 ViewRequirement(
                     "state_out_{}".format(i),
-                    shift=-1,
+                    data_rel_pos=-1,
                     space=Box(-1.0, 1.0, shape=state.shape))
             model.inference_view_requirements["state_out_{}".format(i)] = \
                 ViewRequirement(space=Box(-1.0, 1.0, shape=state.shape))
@@ -29,7 +29,7 @@ class ViewRequirement:
     def __init__(self,
                  data_col: Optional[str] = None,
                  space: gym.Space = None,
-                 shift: Union[int, List[int]] = 0,
+                 data_rel_pos: Union[int, List[int]] = 0,
                  used_for_training: bool = True):
        """Initializes a ViewRequirement object.

@@ -40,13 +40,14 @@ class ViewRequirement:
             space (gym.Space): The gym Space used in case we need to pad data
                 in inaccessible areas of the trajectory (t<0 or t>H).
                 Default: Simple box space, e.g. rewards.
-            shift (Union[int, List[int]]): Single shift value of list of
-                shift values to use relative to the underlying `data_col`.
+            data_rel_pos (Union[int, str, List[int]]): Single shift value or
+                list of relative positions to use (relative to the underlying
+                `data_col`).
                 Example: For a view column "prev_actions", you can set
-                `data_col="actions"` and `shift=-1`.
+                `data_col="actions"` and `data_rel_pos=-1`.
                 Example: For a view column "obs" in an Atari framestacking
                 fashion, you can set `data_col="obs"` and
-                `shift=[-3, -2, -1, 0]`.
+                `data_rel_pos=[-3, -2, -1, 0]`.
             used_for_training (bool): Whether the data will be used for
                 training. If False, the column will not be copied into the
                 final train batch.
@@ -54,5 +55,5 @@ class ViewRequirement:
         self.data_col = data_col
         self.space = space or gym.spaces.Box(
             float("-inf"), float("inf"), shape=())
-        self.shift = shift
+        self.data_rel_pos = data_rel_pos
         self.used_for_training = used_for_training
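Putting the rename together, the two examples from the ViewRequirement docstring would now be written as below. This is a hedged sketch of constructing the requirements by hand; the import paths are assumptions about the module layout at this commit, and "framestacked_obs" is a made-up column name used only for illustration.

    # Sketch only; import paths are assumptions about this commit's layout.
    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.policy.view_requirement import ViewRequirement

    view_requirements = {
        # "prev_actions" = the actions column, shifted one step back.
        SampleBatch.PREV_ACTIONS: ViewRequirement(
            data_col=SampleBatch.ACTIONS, data_rel_pos=-1),
        # Atari-style framestacking: the last four observations.
        "framestacked_obs": ViewRequirement(
            data_col=SampleBatch.OBS, data_rel_pos=[-3, -2, -1, 0]),
    }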