[RLlib] Trajectory view API: Simple List Collector (on by default for PPO); LSTM-agnostic (#11056)
This commit is contained in:
parent 0d93b1de93
commit 36bda8432b
40 changed files with 1154 additions and 1173 deletions
|
@ -37,6 +37,9 @@ DEFAULT_CONFIG = with_common_config({
|
|||
# Workers sample async. Note that this increases the effective
|
||||
# rollout_fragment_length by up to 5x due to async buffering of batches.
|
||||
"sample_async": True,
|
||||
# Switch on Trajectory View API for A2/3C by default.
|
||||
# NOTE: Only supported for PyTorch so far.
|
||||
"_use_trajectory_view_api": True,
|
||||
})
|
||||
# __sphinx_doc_end__
|
||||
# yapf: enable
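
Together with the framework check in `Trainer._validate_config` (further below in this diff), the flag can be exercised from user code roughly like this; the choice of `A2CTrainer` and `CartPole-v0` is only illustrative:

import ray
from ray.rllib.agents.a3c import A2CTrainer

ray.init()
# The trajectory view API is already on by default here; it is shown
# explicitly only to make the relevant settings visible.
trainer = A2CTrainer(
    env="CartPole-v0",
    config={
        "framework": "torch",              # only PyTorch is supported so far
        "_use_trajectory_view_api": True,
    })
print(trainer.train()["episode_reward_mean"])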
|
||||
|
|
|
@ -1,8 +1,13 @@
|
|||
import gym
|
||||
from typing import Dict
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages, \
|
||||
Postprocessing
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.torch_policy_template import build_torch_policy
|
||||
from ray.rllib.policy.view_requirement import ViewRequirement
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
@ -82,6 +87,42 @@ class ValueNetworkMixin:
|
|||
return self.model.value_function()[0]
|
||||
|
||||
|
||||
def view_requirements_fn(policy: Policy) -> Dict[str, ViewRequirement]:
|
||||
"""Function defining the view requirements for training/postprocessing.
|
||||
|
||||
These go on top of the Policy's Model's own view requirements used for
|
||||
the action computing forward passes.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy that requires the returned
|
||||
ViewRequirements.
|
||||
|
||||
Returns:
|
||||
Dict[str, ViewRequirement]: The Policy's view requirements.
|
||||
"""
|
||||
ret = {
|
||||
# Next obs are needed for PPO postprocessing, but not in loss.
|
||||
SampleBatch.NEXT_OBS: ViewRequirement(
|
||||
SampleBatch.OBS, shift=1, used_for_training=False),
|
||||
# Created during postprocessing.
|
||||
Postprocessing.ADVANTAGES: ViewRequirement(shift=0),
|
||||
Postprocessing.VALUE_TARGETS: ViewRequirement(shift=0),
|
||||
# Needed for PPO's loss function.
|
||||
SampleBatch.ACTION_DIST_INPUTS: ViewRequirement(shift=0),
|
||||
SampleBatch.ACTION_LOGP: ViewRequirement(shift=0),
|
||||
SampleBatch.VF_PREDS: ViewRequirement(shift=0),
|
||||
}
|
||||
# If policy is recurrent, have to add state_out for PPO postprocessing
|
||||
# (calculating GAE from next-obs and last state-out).
|
||||
if policy.is_recurrent():
|
||||
init_state = policy.get_initial_state()
|
||||
for i, s in enumerate(init_state):
|
||||
ret["state_out_{}".format(i)] = ViewRequirement(
|
||||
space=gym.spaces.Box(-1.0, 1.0, shape=(s.shape[0], )),
|
||||
used_for_training=False)
|
||||
return ret
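
# Illustration only (shift semantics of the requirements above): with
# `ViewRequirement(SampleBatch.OBS, shift=1)`, the NEXT_OBS view is simply the
# OBS data read one timestep ahead, so no separate next-obs column is stored:
#   OBS buffer   : [o0, o1, o2, o3]
#   OBS view     :  o0, o1, o2          (shift=0)
#   NEXT_OBS view:      o1, o2, o3      (data_col=OBS, shift=1)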
|
||||
|
||||
|
||||
A3CTorchPolicy = build_torch_policy(
|
||||
name="A3CTorchPolicy",
|
||||
get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
|
||||
|
@ -91,4 +132,6 @@ A3CTorchPolicy = build_torch_policy(
|
|||
extra_action_out_fn=model_value_predictions,
|
||||
extra_grad_process_fn=apply_grad_clipping,
|
||||
optimizer_fn=torch_optimizer,
|
||||
mixins=[ValueNetworkMixin])
|
||||
mixins=[ValueNetworkMixin],
|
||||
view_requirements_fn=view_requirements_fn,
|
||||
)
|
||||
|
|
|
@ -8,6 +8,7 @@ See `pg_[tf|torch]_policy.py` for the definition of the policy loss.
|
|||
Detailed documentation: https://docs.ray.io/en/master/rllib-algorithms.html#pg
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Type
|
||||
|
||||
from ray.rllib.agents.trainer import with_common_config
|
||||
|
@ -17,6 +18,8 @@ from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy
|
|||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.utils.typing import TrainerConfigDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
|
||||
|
@ -27,6 +30,9 @@ DEFAULT_CONFIG = with_common_config({
|
|||
"num_workers": 0,
|
||||
# Learning rate.
|
||||
"lr": 0.0004,
|
||||
# Switch on Trajectory View API for PG by default.
|
||||
# NOTE: Only supported for PyTorch so far.
|
||||
"_use_trajectory_view_api": True,
|
||||
})
|
||||
|
||||
# __sphinx_doc_end__
|
||||
|
|
|
@ -5,6 +5,7 @@ PyTorch policy class used for PG.
|
|||
from typing import Dict, List, Type, Union
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy import view_requirements_fn
|
||||
from ray.rllib.agents.pg.utils import post_process_advantages
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
|
||||
|
@ -77,4 +78,6 @@ PGTorchPolicy = build_torch_policy(
|
|||
get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG,
|
||||
loss_fn=pg_torch_loss,
|
||||
stats_fn=pg_loss_stats,
|
||||
postprocess_fn=post_process_advantages)
|
||||
postprocess_fn=post_process_advantages,
|
||||
view_requirements_fn=view_requirements_fn,
|
||||
)
|
||||
|
|
|
@ -72,6 +72,8 @@ DEFAULT_CONFIG = ppo.PPOTrainer.merge_trainer_configs(
|
|||
"truncate_episodes": True,
|
||||
# This is auto set based on sample batch size.
|
||||
"train_batch_size": -1,
|
||||
# Trajectory View API not supported yet for DD-PPO.
|
||||
"_use_trajectory_view_api": False,
|
||||
},
|
||||
_allow_unknown_configs=True,
|
||||
)
|
||||
|
|
|
@ -89,6 +89,9 @@ DEFAULT_CONFIG = with_common_config({
|
|||
# Whether to fake GPUs (using CPUs).
|
||||
# Set this to True for debugging on non-GPU machines (set `num_gpus` > 0).
|
||||
"_fake_gpus": False,
|
||||
# Switch on Trajectory View API for PPO by default.
|
||||
# NOTE: Only supported for PyTorch so far.
|
||||
"_use_trajectory_view_api": True,
|
||||
})
|
||||
|
||||
# __sphinx_doc_end__
|
||||
|
@ -126,7 +129,8 @@ def validate_config(config: TrainerConfigDict) -> None:
|
|||
if config["batch_mode"] == "truncate_episodes" and not config["use_gae"]:
|
||||
raise ValueError(
|
||||
"Episode truncation is not supported without a value "
|
||||
"function. Consider setting batch_mode=complete_episodes.")
|
||||
"function (to estimate the return at the end of the truncated "
|
||||
"trajectory). Consider setting batch_mode=complete_episodes.")
|
||||
|
||||
# Multi-gpu not supported for PyTorch and tf-eager.
|
||||
if config["framework"] in ["tf2", "tfe", "torch"]:
|
||||
|
|
|
@ -7,7 +7,8 @@ import numpy as np
|
|||
from typing import Dict, List, Type, Union
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping, \
|
||||
view_requirements_fn
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \
|
||||
setup_config
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
|
@ -18,7 +19,6 @@ from ray.rllib.policy.sample_batch import SampleBatch
|
|||
from ray.rllib.policy.torch_policy import EntropyCoeffSchedule, \
|
||||
LearningRateSchedule
|
||||
from ray.rllib.policy.torch_policy_template import build_torch_policy
|
||||
from ray.rllib.policy.view_requirement import ViewRequirement
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import convert_to_torch_tensor, \
|
||||
explained_variance, sequence_mask
|
||||
|
@ -255,34 +255,6 @@ def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
|
|||
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
|
||||
|
||||
|
||||
def training_view_requirements_fn(
|
||||
policy: Policy) -> Dict[str, ViewRequirement]:
|
||||
"""Function defining the view requirements for training the policy.
|
||||
|
||||
These go on top of the Policy's Model's own view requirements used for
|
||||
action computing forward passes.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy that requires the returned
|
||||
ViewRequirements.
|
||||
|
||||
Returns:
|
||||
Dict[str, ViewRequirement]: The Policy's view requirements.
|
||||
"""
|
||||
return {
|
||||
# Next obs are needed for PPO postprocessing.
|
||||
SampleBatch.NEXT_OBS: ViewRequirement(SampleBatch.OBS, shift=1),
|
||||
# VF preds are needed for the loss.
|
||||
SampleBatch.VF_PREDS: ViewRequirement(shift=0),
|
||||
# Needed for postprocessing.
|
||||
SampleBatch.ACTION_DIST_INPUTS: ViewRequirement(shift=0),
|
||||
SampleBatch.ACTION_LOGP: ViewRequirement(shift=0),
|
||||
# Created during postprocessing.
|
||||
Postprocessing.ADVANTAGES: ViewRequirement(shift=0),
|
||||
Postprocessing.VALUE_TARGETS: ViewRequirement(shift=0),
|
||||
}
|
||||
|
||||
|
||||
# Build a child class of `TorchPolicy`, given the custom functions defined
|
||||
# above.
|
||||
PPOTorchPolicy = build_torch_policy(
|
||||
|
@ -299,5 +271,5 @@ PPOTorchPolicy = build_torch_policy(
|
|||
LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
|
||||
ValueNetworkMixin
|
||||
],
|
||||
training_view_requirements_fn=training_view_requirements_fn,
|
||||
view_requirements_fn=view_requirements_fn,
|
||||
)
|
||||
|
|
|
@ -1047,9 +1047,10 @@ class Trainer(Trainable):
|
|||
def _validate_config(config: PartialTrainerConfigDict):
|
||||
if config.get("_use_trajectory_view_api") and \
|
||||
config.get("framework") != "torch":
|
||||
raise ValueError(
|
||||
logger.info(
|
||||
"`_use_trajectory_view_api` only supported for PyTorch so "
|
||||
"far!")
|
||||
"far! Will run w/o.")
|
||||
config["_use_trajectory_view_api"] = False
|
||||
elif not config.get("_use_trajectory_view_api") and \
|
||||
config.get("model", {}).get("_time_major"):
|
||||
raise ValueError("`model._time_major` only supported "
|
||||
|
|
3
rllib/env/base_env.py
vendored
|
@ -183,7 +183,8 @@ class BaseEnv:
|
|||
reset the entire Env (i.e. all sub-envs).
|
||||
|
||||
Returns:
|
||||
obs (dict|None): Resetted observation or None if not supported.
|
||||
Optional[MultiAgentDict]: The reset (multi-agent) observation dict
|
||||
or None if reset is not supported.
|
||||
"""
|
||||
return None
|
||||
|
||||
|
|
0
rllib/evaluation/collectors/__init__.py
Normal file
|
@ -1,9 +1,10 @@
|
|||
from abc import abstractmethod, ABCMeta
|
||||
import logging
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, Union
|
||||
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
from ray.rllib.utils.typing import AgentID, EpisodeID, PolicyID, \
|
||||
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
|
||||
from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, \
|
||||
TensorType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -29,7 +30,7 @@ class _SampleCollector(metaclass=ABCMeta):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID,
|
||||
def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
|
||||
policy_id: PolicyID, init_obs: TensorType) -> None:
|
||||
"""Adds an initial obs (after reset) to this collector.
|
||||
|
||||
|
@ -41,10 +42,11 @@ class _SampleCollector(metaclass=ABCMeta):
|
|||
called for that same agent/episode-pair.
|
||||
|
||||
Args:
|
||||
episode_id (EpisodeID): Unique id for the episode we are adding
|
||||
values for.
|
||||
episode (MultiAgentEpisode): The MultiAgentEpisode, for which we
|
||||
are adding an Agent's initial observation.
|
||||
agent_id (AgentID): Unique id for the agent we are adding
|
||||
values for.
|
||||
env_id (EnvID): The environment index (in a vectorized setup).
|
||||
policy_id (PolicyID): Unique id for policy controlling the agent.
|
||||
init_obs (TensorType): Initial observation (after env.reset()).
|
||||
|
||||
|
@ -52,7 +54,7 @@ class _SampleCollector(metaclass=ABCMeta):
|
|||
>>> obs = env.reset()
|
||||
>>> collector.add_init_obs(12345, 0, "pol0", obs)
|
||||
>>> obs, r, done, info = env.step(action)
|
||||
>>> collector.add_action_reward_next_obs(12345, 0, "pol0", {
|
||||
>>> collector.add_action_reward_next_obs(12345, 0, "pol0", False, {
|
||||
... "action": action, "obs": obs, "reward": r, "done": done
|
||||
... })
|
||||
"""
|
||||
|
@ -60,7 +62,8 @@ class _SampleCollector(metaclass=ABCMeta):
|
|||
|
||||
@abstractmethod
|
||||
def add_action_reward_next_obs(self, episode_id: EpisodeID,
|
||||
agent_id: AgentID, policy_id: PolicyID,
|
||||
agent_id: AgentID, env_id: EnvID,
|
||||
policy_id: PolicyID, agent_done: bool,
|
||||
values: Dict[str, TensorType]) -> None:
|
||||
"""Add the given dictionary (row) of values to this collector.
|
||||
|
||||
|
@ -74,7 +77,10 @@ class _SampleCollector(metaclass=ABCMeta):
|
|||
values for.
|
||||
agent_id (AgentID): Unique id for the agent we are adding
|
||||
values for.
|
||||
env_id (EnvID): The environment index (in a vectorized setup).
|
||||
policy_id (PolicyID): Unique id for policy controlling the agent.
|
||||
agent_done (bool): Whether the given agent is done with its
|
||||
trajectory (the multi-agent episode may still be ongoing).
|
||||
values (Dict[str, TensorType]): Row of values to add for this
|
||||
agent. This row must contain the keys SampleBatch.ACTION,
|
||||
REWARD, NEW_OBS, and DONE.
|
||||
|
@ -83,12 +89,22 @@ class _SampleCollector(metaclass=ABCMeta):
|
|||
>>> obs = env.reset()
|
||||
>>> collector.add_init_obs(12345, 0, "pol0", obs)
|
||||
>>> obs, r, done, info = env.step(action)
|
||||
>>> collector.add_action_reward_next_obs(12345, 0, "pol0", {
|
||||
>>> collector.add_action_reward_next_obs(12345, 0, "pol0", False, {
|
||||
... "action": action, "obs": obs, "reward": r, "done": done
|
||||
... })
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def episode_step(self, episode_id: EpisodeID) -> None:
|
||||
"""Increases the episode step counter (across all agents) by one.
|
||||
|
||||
Args:
|
||||
episode_id (EpisodeID): Unique id for the episode we are stepping
|
||||
through (across all agents in that episode).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def total_env_steps(self) -> int:
|
||||
"""Returns total number of steps taken in the env (sum of all agents).
|
||||
|
@ -126,19 +142,11 @@ class _SampleCollector(metaclass=ABCMeta):
|
|||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def has_non_postprocessed_data(self) -> bool:
|
||||
"""Returns whether there is pending, unprocessed data.
|
||||
|
||||
Returns:
|
||||
bool: True if there is at least some data that has not been
|
||||
postprocessed yet.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def postprocess_trajectories_so_far(
|
||||
self, episode: Optional[MultiAgentEpisode] = None) -> None:
|
||||
"""Apply postprocessing to unprocessed data (in one or all episodes).
|
||||
def postprocess_episode(self,
|
||||
episode: MultiAgentEpisode,
|
||||
is_done: bool = False,
|
||||
check_dones: bool = False) -> None:
|
||||
"""Postprocesses all agents' trajectories in a given episode.
|
||||
|
||||
Generates (single-trajectory) SampleBatches for all Policies/Agents and
|
||||
calls Policy.postprocess_trajectory on each of these. Postprocessing
|
||||
|
@ -148,38 +156,46 @@ class _SampleCollector(metaclass=ABCMeta):
|
|||
correctly added to the buffers.
|
||||
|
||||
Args:
|
||||
episode (Optional[MultiAgentEpisode]): The Episode object for which
|
||||
to post-process data. If not provided, postprocess data for all
|
||||
episodes.
|
||||
episode (MultiAgentEpisode): The Episode object for which
|
||||
to post-process data.
|
||||
is_done (bool): Whether the given episode is actually terminated
|
||||
(all agents are done).
|
||||
check_dones (bool): Whether we need to check that all agents'
|
||||
trajectories have dones=True at the end.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def check_missing_dones(self, episode_id: EpisodeID) -> None:
|
||||
"""Checks whether given episode is properly terminated with done=True.
|
||||
|
||||
This applies to all agents in the episode.
|
||||
def build_multi_agent_batch(self, env_steps: int) -> \
|
||||
Union[MultiAgentBatch, SampleBatch]:
|
||||
"""Builds a MultiAgentBatch of size=env_steps from the collected data.
|
||||
|
||||
Args:
|
||||
episode_id (EpisodeID): The episode ID to check for proper
|
||||
termination.
|
||||
env_steps (int): The sum of all env-steps (across all agents) taken
|
||||
so far.
|
||||
|
||||
Raises:
|
||||
ValueError: If `episode` has no done=True at the end.
|
||||
Returns:
|
||||
Union[MultiAgentBatch, SampleBatch]: Returns the accumulated
|
||||
sample batches for each policy inside one MultiAgentBatch
|
||||
object (or a simple SampleBatch if only one policy).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_multi_agent_batch_and_reset(self):
|
||||
"""Returns the accumulated sample batches for each policy.
|
||||
def try_build_truncated_episode_multi_agent_batch(self) -> \
|
||||
Union[MultiAgentBatch, SampleBatch, None]:
|
||||
"""Tries to build an MA-batch, if `rollout_fragment_length` is reached.
|
||||
|
||||
Any unprocessed rows will be first postprocessed with a policy
|
||||
postprocessor. The internal state of this builder will be reset to
|
||||
start the next batch.
|
||||
Any unprocessed data will be first postprocessed with a policy
|
||||
postprocessor.
|
||||
This is usually called to collect samples for policy training.
|
||||
If not enough data has been collected yet (`rollout_fragment_length`),
|
||||
returns None.
|
||||
|
||||
Returns:
|
||||
MultiAgentBatch: Returns the accumulated sample batches for each
|
||||
policy inside one MultiAgentBatch object.
|
||||
Union[MultiAgentBatch, SampleBatch, None]: Returns the accumulated
|
||||
sample batches for each policy inside one MultiAgentBatch
|
||||
object (or a simple SampleBatch if only one policy) or None
|
||||
if `self.rollout_fragment_length` has not been reached yet.
|
||||
"""
|
||||
raise NotImplementedError
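
For orientation, a rough, hypothetical sketch of how the sampler drives a concrete _SampleCollector for a single agent; `env`, `episode`, `collector` and the action computation are placeholders assumed to exist, not APIs defined by this interface:

from ray.rllib.policy.sample_batch import SampleBatch

obs = env.reset()
collector.add_init_obs(episode=episode, agent_id=0, env_id=0,
                       policy_id="pol0", init_obs=obs)
done = False
while not done:
    input_dict = collector.get_inference_input_dict("pol0")
    action = ...  # Policy forward pass on `input_dict` (elided).
    obs, reward, done, info = env.step(action)
    collector.episode_step(episode.episode_id)
    collector.add_action_reward_next_obs(
        episode.episode_id, 0, 0, "pol0", agent_done=done,
        values={SampleBatch.ACTIONS: action, SampleBatch.REWARDS: reward,
                SampleBatch.DONES: done, SampleBatch.NEXT_OBS: obs})
# Once the episode has terminated, postprocess it and build the train batch.
collector.postprocess_episode(episode, is_done=True, check_dones=True)
train_batch = collector.build_multi_agent_batch(collector.total_env_steps())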
|
589
rllib/evaluation/collectors/simple_list_collector.py
Normal file
|
@ -0,0 +1,589 @@
|
|||
import collections
|
||||
import logging
|
||||
import numpy as np
|
||||
from typing import List, Any, Dict, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
|
||||
from ray.rllib.evaluation.collectors.sample_collector import _SampleCollector
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
|
||||
from ray.rllib.policy.view_requirement import ViewRequirement
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.debug import summarize
|
||||
from ray.rllib.utils.typing import AgentID, EpisodeID, EnvID, PolicyID, \
|
||||
TensorType
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.torch_ops import convert_to_non_torch_type
|
||||
from ray.util.debug import log_once
|
||||
|
||||
_, tf, _ = try_import_tf()
|
||||
torch, _ = try_import_torch()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.rllib.agents.callbacks import DefaultCallbacks
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def to_float_np_array(v: List[Any]) -> np.ndarray:
|
||||
if torch.is_tensor(v[0]):
|
||||
raise ValueError
|
||||
v = convert_to_non_torch_type(v)
|
||||
arr = np.array(v)
|
||||
if arr.dtype == np.float64:
|
||||
return arr.astype(np.float32) # save some memory
|
||||
return arr
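
# Illustrative example: to_float_np_array([0.5, 1.5]) returns
# np.array([0.5, 1.5], dtype=np.float32), downcasting the default float64 to
# float32 to save memory.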
|
||||
|
||||
|
||||
class _AgentCollector:
|
||||
"""Collects samples for one agent in one trajectory (episode).
|
||||
|
||||
The agent may be part of a multi-agent environment. Samples are stored in
|
||||
lists including some possible automatic "shift" buffer at the beginning to
|
||||
be able to save memory when storing things like NEXT_OBS, PREV_REWARDS,
|
||||
etc., which are specified using the trajectory view API.
|
||||
"""
|
||||
|
||||
_next_unroll_id = 0 # disambiguates unrolls within a single episode
|
||||
|
||||
def __init__(self, shift_before: int = 0):
|
||||
self.shift_before = max(shift_before, 1)
|
||||
self.buffers: Dict[str, List] = {}
|
||||
# The simple timestep count for this agent. Gets increased by one
|
||||
# each time a (non-initial!) observation is added.
|
||||
self.count = 0
|
||||
|
||||
def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID,
|
||||
env_id: EnvID, init_obs: TensorType,
|
||||
view_requirements: Dict[str, ViewRequirement]) -> None:
|
||||
"""Adds an initial observation (after reset) to the Agent's trajectory.
|
||||
|
||||
Args:
|
||||
episode_id (EpisodeID): Unique ID for the episode we are adding the
|
||||
initial observation for.
|
||||
agent_id (AgentID): Unique ID for the agent we are adding the
|
||||
initial observation for.
|
||||
env_id (EnvID): The environment index (in a vectorized setup).
|
||||
init_obs (TensorType): The initial observation tensor (after
|
||||
`env.reset()`).
|
||||
view_requirements (Dict[str, ViewRequirement]): The view requirements dict for this agent.
|
||||
"""
|
||||
if SampleBatch.OBS not in self.buffers:
|
||||
self._build_buffers(
|
||||
single_row={
|
||||
SampleBatch.OBS: init_obs,
|
||||
SampleBatch.EPS_ID: episode_id,
|
||||
SampleBatch.AGENT_INDEX: agent_id,
|
||||
"env_id": env_id,
|
||||
})
|
||||
self.buffers[SampleBatch.OBS].append(init_obs)
|
||||
|
||||
def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \
|
||||
None:
|
||||
"""Adds the given dictionary (row) of values to the Agent's trajectory.
|
||||
|
||||
Args:
|
||||
values (Dict[str, TensorType]): Data dict (interpreted as a single
|
||||
row) to be added to buffer. Must contain keys:
|
||||
SampleBatch.ACTIONS, REWARDS, DONES, and NEXT_OBS.
|
||||
"""
|
||||
|
||||
assert SampleBatch.OBS not in values
|
||||
values[SampleBatch.OBS] = values[SampleBatch.NEXT_OBS]
|
||||
del values[SampleBatch.NEXT_OBS]
|
||||
|
||||
for k, v in values.items():
|
||||
if k not in self.buffers:
|
||||
self._build_buffers(single_row=values)
|
||||
self.buffers[k].append(v)
|
||||
self.count += 1
|
||||
|
||||
def build(self, view_requirements: Dict[str, ViewRequirement]) -> \
|
||||
SampleBatch:
|
||||
"""Builds a SampleBatch from the thus-far collected agent data.
|
||||
|
||||
If the episode/trajectory has no DONE=True at the end, will copy
|
||||
the necessary n timesteps at the end of the trajectory back to the
|
||||
beginning of the buffers and wait for new samples coming in.
|
||||
SampleBatches created by this method will be ready for postprocessing
|
||||
by a Policy.
|
||||
|
||||
Args:
|
||||
view_requirements (Dict[str, ViewRequirement]): The view
|
||||
requirements dict needed to build the SampleBatch from the raw
|
||||
buffers (which may have data shifts as well as mappings from
|
||||
view-col to data-col in them).
|
||||
Returns:
|
||||
SampleBatch: The built SampleBatch for this agent, ready to go into
|
||||
postprocessing.
|
||||
"""
|
||||
|
||||
# TODO: measure performance gains when using a UsageTrackingDict
|
||||
# instead of a SampleBatch for postprocessing (this would eliminate
|
||||
# copies (for creating this SampleBatch) of many unused columns for
|
||||
# no reason (not used by postprocessor)).
|
||||
|
||||
batch_data = {}
|
||||
np_data = {}
|
||||
for view_col, view_req in view_requirements.items():
|
||||
# Create the batch of data from the different buffers.
|
||||
data_col = view_req.data_col or view_col
|
||||
# Some columns don't exist yet (get created during postprocessing).
|
||||
# -> skip.
|
||||
if data_col not in self.buffers:
|
||||
continue
|
||||
shift = view_req.shift - \
|
||||
(1 if data_col == SampleBatch.OBS else 0)
|
||||
if data_col not in np_data:
|
||||
np_data[data_col] = to_float_np_array(self.buffers[data_col])
|
||||
if shift == 0:
|
||||
batch_data[view_col] = np_data[data_col][self.shift_before:]
|
||||
else:
|
||||
batch_data[view_col] = np_data[data_col][self.shift_before +
|
||||
shift:shift]
|
||||
batch = SampleBatch(batch_data)
|
||||
|
||||
if SampleBatch.UNROLL_ID not in batch.data:
|
||||
batch.data[SampleBatch.UNROLL_ID] = np.repeat(
|
||||
_AgentCollector._next_unroll_id, batch.count)
|
||||
_AgentCollector._next_unroll_id += 1
|
||||
|
||||
# This trajectory is continuing -> Copy data at the end (in the size of
|
||||
# self.shift_before) to the beginning of buffers and erase everything
|
||||
# else.
|
||||
if not self.buffers[SampleBatch.DONES][-1]:
|
||||
# Copy data to beginning of buffer and cut lists.
|
||||
if self.shift_before > 0:
|
||||
for k, data in self.buffers.items():
|
||||
self.buffers[k] = data[-self.shift_before:]
|
||||
self.count = 0
|
||||
|
||||
return batch
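
# Worked example for the slicing above (assuming the default shift_before=1):
# after one reset and three steps, the buffers hold
#   OBS     : [o0, o1, o2, o3]
#   ACTIONS : [ 0, a0, a1, a2]   (one zero-padded element at the front)
# and build() returns a SampleBatch with
#   SampleBatch.OBS      -> [o0, o1, o2]
#   SampleBatch.NEXT_OBS -> [o1, o2, o3]   (view of OBS with shift=+1)
#   SampleBatch.ACTIONS  -> [a0, a1, a2]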
|
||||
|
||||
def _build_buffers(self, single_row: Dict[str, TensorType]) -> None:
|
||||
"""Builds the buffers for sample collection, given an example data row.
|
||||
|
||||
Args:
|
||||
single_row (Dict[str, TensorType]): A single row (keys=column
|
||||
names) of data to base the buffers on.
|
||||
"""
|
||||
for col, data in single_row.items():
|
||||
if col in self.buffers:
|
||||
continue
|
||||
shift = self.shift_before - (1 if col == SampleBatch.OBS else 0)
|
||||
# Python primitive.
|
||||
if isinstance(data, (int, float, bool, str)):
|
||||
self.buffers[col] = [0 for _ in range(shift)]
|
||||
# np.ndarray, torch.Tensor, or tf.Tensor.
|
||||
else:
|
||||
shape = data.shape
|
||||
dtype = data.dtype
|
||||
if torch and isinstance(data, torch.Tensor):
|
||||
self.buffers[col] = \
|
||||
[torch.zeros(shape, dtype=dtype, device=data.device)
|
||||
for _ in range(shift)]
|
||||
elif tf and isinstance(data, tf.Tensor):
|
||||
self.buffers[col] = \
|
||||
[tf.zeros(shape=shape, dtype=dtype)
|
||||
for _ in range(shift)]
|
||||
else:
|
||||
self.buffers[col] = \
|
||||
[np.zeros(shape=shape, dtype=dtype)
|
||||
for _ in range(shift)]
|
||||
|
||||
|
||||
class _PolicyCollector:
|
||||
"""Collects already postprocessed (single agent) samples for one policy.
|
||||
|
||||
Samples come in through already postprocessed SampleBatches, which
|
||||
contain single episode/trajectory data for a single agent and are then
|
||||
appended to this policy's buffers.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initializes a _PolicyCollector instance."""
|
||||
|
||||
self.buffers: Dict[str, List] = collections.defaultdict(list)
|
||||
# The total timestep count for all agents that use this policy.
|
||||
# NOTE: This is not an env-step count (across n agents). AgentA and
|
||||
# agentB, both using this policy, acting in the same episode and both
|
||||
# doing n steps would increase the count by 2*n.
|
||||
self.count = 0
|
||||
|
||||
def add_postprocessed_batch_for_training(
|
||||
self, batch: SampleBatch,
|
||||
view_requirements: Dict[str, ViewRequirement]) -> None:
|
||||
"""Adds a postprocessed SampleBatch (single agent) to our buffers.
|
||||
|
||||
Args:
|
||||
batch (SampleBatch): A single agent (one trajectory) SampleBatch
|
||||
to be added to the Policy's buffers.
|
||||
view_requirements (Dict[str, ViewRequirement]): The view
|
||||
requirements for the policy. This is so we know whether a
|
||||
view-column needs to be copied at all (not needed for
|
||||
training).
|
||||
"""
|
||||
for view_col, data in batch.items():
|
||||
# Skip columns that are not used for training.
|
||||
if view_col in view_requirements and \
|
||||
not view_requirements[view_col].used_for_training:
|
||||
continue
|
||||
self.buffers[view_col].extend(data)
|
||||
# Add the agent's trajectory length to our count.
|
||||
self.count += batch.count
|
||||
|
||||
def build(self):
|
||||
"""Builds a SampleBatch for this policy from the collected data.
|
||||
|
||||
Also resets all buffers for further sample collection for this policy.
|
||||
|
||||
Returns:
|
||||
SampleBatch: The SampleBatch with all thus-far collected data for
|
||||
this policy.
|
||||
"""
|
||||
# Create batch from our buffers.
|
||||
batch = SampleBatch(self.buffers)
|
||||
assert SampleBatch.UNROLL_ID in batch.data
|
||||
# Clear buffers for future samples.
|
||||
self.buffers.clear()
|
||||
# Reset count to 0.
|
||||
self.count = 0
|
||||
return batch
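
# Illustrative: if agent A contributed a postprocessed 50-step trajectory and
# agent B (mapped to the same policy) a 30-step one, the built SampleBatch has
# count == 80, while the UNROLL_ID column still distinguishes the two
# original trajectories.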
|
||||
|
||||
|
||||
class _SimpleListCollector(_SampleCollector):
|
||||
"""Util to build SampleBatches for each policy in a multi-agent env.
|
||||
|
||||
Input data is per-agent, while output data is per-policy. There is an M:N
|
||||
mapping between agents and policies. We retain one local batch builder
|
||||
per agent. When an agent is done, its local batch is appended into the
|
||||
corresponding policy batch for the agent's policy.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
policy_map: Dict[PolicyID, Policy],
|
||||
clip_rewards: Union[bool, float],
|
||||
callbacks: "DefaultCallbacks",
|
||||
multiple_episodes_in_batch: bool = True,
|
||||
rollout_fragment_length: int = 200):
|
||||
"""Initializes a _SimpleListCollector instance.
|
||||
|
||||
Args:
|
||||
policy_map (Dict[str, Policy]): Maps policy ids to policy
|
||||
instances.
|
||||
clip_rewards (Union[bool, float]): Whether to clip rewards before
|
||||
postprocessing (at +/-1.0), or the actual +/- value to clip to.
|
||||
callbacks (DefaultCallbacks): RLlib callbacks.
|
||||
"""
|
||||
|
||||
self.policy_map = policy_map
|
||||
self.clip_rewards = clip_rewards
|
||||
self.callbacks = callbacks
|
||||
self.multiple_episodes_in_batch = multiple_episodes_in_batch
|
||||
self.rollout_fragment_length = rollout_fragment_length
|
||||
self.large_batch_threshold: int = max(
|
||||
1000, rollout_fragment_length *
|
||||
10) if rollout_fragment_length != float("inf") else 5000
|
||||
|
||||
# Build one collector per Policy.
|
||||
self.policy_collectors = {
|
||||
pid: _PolicyCollector()
|
||||
for pid in policy_map.keys()
|
||||
}
|
||||
self.policy_collectors_env_steps = 0
|
||||
# Whenever we observe a new episode+agent, add a new
|
||||
# _AgentCollector.
|
||||
self.agent_collectors: Dict[Tuple[EpisodeID, AgentID],
|
||||
_AgentCollector] = {}
|
||||
# Internal agent-key-to-policy map.
|
||||
self.agent_key_to_policy = {}
|
||||
|
||||
# Agents to collect data from for the next forward pass (per policy).
|
||||
self.forward_pass_agent_keys = {pid: [] for pid in policy_map.keys()}
|
||||
self.forward_pass_size = {pid: 0 for pid in policy_map.keys()}
|
||||
|
||||
# Maps episode ID to the episode's current step count and its MultiAgentEpisode object.
|
||||
self.episode_steps: Dict[EpisodeID, int] = collections.defaultdict(int)
|
||||
self.episodes: Dict[EpisodeID, MultiAgentEpisode] = {}
|
||||
|
||||
@override(_SampleCollector)
|
||||
def episode_step(self, episode_id: EpisodeID) -> None:
|
||||
self.episode_steps[episode_id] += 1
|
||||
|
||||
env_steps = \
|
||||
self.policy_collectors_env_steps + self.episode_steps[episode_id]
|
||||
if (env_steps > self.large_batch_threshold
|
||||
and log_once("large_batch_warning")):
|
||||
logger.warning(
|
||||
"More than {} observations for {} env steps ".format(
|
||||
env_steps, env_steps) +
|
||||
"are buffered in the sampler. If this is more than you "
|
||||
"expected, check that that you set a horizon on your "
|
||||
"environment correctly and that it terminates at some point. "
|
||||
"Note: In multi-agent environments, `rollout_fragment_length` "
|
||||
"sets the batch size based on (across-agents) environment "
|
||||
"steps, not the steps of individual agents, which can result "
|
||||
"in unexpectedly large batches." +
|
||||
("Also, you may be in evaluation waiting for your Env to "
|
||||
"terminate (batch_mode=`complete_episodes`). Make sure it "
|
||||
"does at some point."
|
||||
if not self.multiple_episodes_in_batch else ""))
|
||||
|
||||
@override(_SampleCollector)
|
||||
def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
|
||||
env_id: EnvID, policy_id: PolicyID,
|
||||
init_obs: TensorType) -> None:
|
||||
# Make sure our mappings are up to date.
|
||||
agent_key = (episode.episode_id, agent_id)
|
||||
if agent_key not in self.agent_key_to_policy:
|
||||
self.agent_key_to_policy[agent_key] = policy_id
|
||||
else:
|
||||
assert self.agent_key_to_policy[agent_key] == policy_id
|
||||
policy = self.policy_map[policy_id]
|
||||
view_reqs = policy.model.inference_view_requirements if \
|
||||
hasattr(policy, "model") else policy.view_requirements
|
||||
|
||||
# Add initial obs to Trajectory.
|
||||
assert agent_key not in self.agent_collectors
|
||||
# TODO: determine exact shift-before based on the view-req shifts.
|
||||
self.agent_collectors[agent_key] = _AgentCollector()
|
||||
self.agent_collectors[agent_key].add_init_obs(
|
||||
episode_id=episode.episode_id,
|
||||
agent_id=agent_id,
|
||||
env_id=env_id,
|
||||
init_obs=init_obs,
|
||||
view_requirements=view_reqs)
|
||||
|
||||
self.episodes[episode.episode_id] = episode
|
||||
|
||||
self._add_to_next_inference_call(agent_key, env_id)
|
||||
|
||||
@override(_SampleCollector)
|
||||
def add_action_reward_next_obs(self, episode_id: EpisodeID,
|
||||
agent_id: AgentID, env_id: EnvID,
|
||||
policy_id: PolicyID, agent_done: bool,
|
||||
values: Dict[str, TensorType]) -> None:
|
||||
# Make sure, episode/agent already has some (at least init) data.
|
||||
agent_key = (episode_id, agent_id)
|
||||
assert self.agent_key_to_policy[agent_key] == policy_id
|
||||
assert agent_key in self.agent_collectors
|
||||
|
||||
# Include the current agent id for multi-agent algorithms.
|
||||
if agent_id != _DUMMY_AGENT_ID:
|
||||
values["agent_id"] = agent_id
|
||||
|
||||
# Add action/reward/next-obs (and other data) to Trajectory.
|
||||
self.agent_collectors[agent_key].add_action_reward_next_obs(values)
|
||||
|
||||
if not agent_done:
|
||||
self._add_to_next_inference_call(agent_key, env_id)
|
||||
|
||||
@override(_SampleCollector)
|
||||
def total_env_steps(self) -> int:
|
||||
return sum(a.count for a in self.agent_collectors.values())
|
||||
|
||||
@override(_SampleCollector)
|
||||
def get_inference_input_dict(self, policy_id: PolicyID) -> \
|
||||
Dict[str, TensorType]:
|
||||
policy = self.policy_map[policy_id]
|
||||
keys = self.forward_pass_agent_keys[policy_id]
|
||||
buffers = {k: self.agent_collectors[k].buffers for k in keys}
|
||||
view_reqs = policy.model.inference_view_requirements if \
|
||||
hasattr(policy, "model") else policy.view_requirements
|
||||
|
||||
input_dict = {}
|
||||
for view_col, view_req in view_reqs.items():
|
||||
# Create the batch of data from the different buffers.
|
||||
data_col = view_req.data_col or view_col
|
||||
time_indices = \
|
||||
view_req.shift - (
|
||||
1 if data_col in [SampleBatch.OBS, "t", "env_id",
|
||||
SampleBatch.EPS_ID,
|
||||
SampleBatch.AGENT_INDEX] else 0)
|
||||
data_list = []
|
||||
for k in keys:
|
||||
if data_col not in buffers[k]:
|
||||
self.agent_collectors[k]._build_buffers({
|
||||
data_col: view_req.space.sample()
|
||||
})
|
||||
data_list.append(buffers[k][data_col][time_indices])
|
||||
input_dict[view_col] = np.array(data_list)
|
||||
|
||||
self._reset_inference_calls(policy_id)
|
||||
|
||||
return input_dict
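
# Shape example (illustrative): if two agents controlled by `policy_id` are
# awaiting their next action and the observation space is a Box of shape (4,),
# the returned dict contains e.g. {"obs": array of shape (2, 4), ...}, with
# one row per agent key registered via `_add_to_next_inference_call()`.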
|
||||
|
||||
@override(_SampleCollector)
|
||||
def postprocess_episode(self,
|
||||
episode: MultiAgentEpisode,
|
||||
is_done: bool = False,
|
||||
check_dones: bool = False) -> None:
|
||||
episode_id = episode.episode_id
|
||||
|
||||
# TODO: (sven) Once we implement multi-agent communication channels,
|
||||
# we have to resolve the restriction of only sending other agent
|
||||
# batches from the same policy to the postprocess methods.
|
||||
# Build SampleBatches for the given episode.
|
||||
pre_batches = {}
|
||||
for (eps_id, agent_id), collector in self.agent_collectors.items():
|
||||
# Build only if there is data and agent is part of given episode.
|
||||
if collector.count == 0 or eps_id != episode_id:
|
||||
continue
|
||||
policy = self.policy_map[self.agent_key_to_policy[(eps_id,
|
||||
agent_id)]]
|
||||
pre_batch = collector.build(policy.view_requirements)
|
||||
pre_batches[agent_id] = (policy, pre_batch)
|
||||
|
||||
# Apply postprocessor.
|
||||
post_batches = {}
|
||||
if self.clip_rewards is True:
|
||||
for _, (_, pre_batch) in pre_batches.items():
|
||||
pre_batch["rewards"] = np.sign(pre_batch["rewards"])
|
||||
elif self.clip_rewards:
|
||||
for _, (_, pre_batch) in pre_batches.items():
|
||||
pre_batch["rewards"] = np.clip(
|
||||
pre_batch["rewards"],
|
||||
a_min=-self.clip_rewards,
|
||||
a_max=self.clip_rewards)
|
||||
|
||||
for agent_id, (_, pre_batch) in pre_batches.items():
|
||||
# Entire episode is said to be done.
|
||||
if is_done:
|
||||
# Error if no DONE at end of this agent's trajectory.
|
||||
if check_dones and not pre_batch[SampleBatch.DONES][-1]:
|
||||
raise ValueError(
|
||||
"Episode {} terminated for all agents, but we still "
|
||||
"don't have a last observation for agent {} (policy "
|
||||
"{}). ".format(
|
||||
episode_id, agent_id, self.agent_key_to_policy[(
|
||||
episode_id, agent_id)]) +
|
||||
"Please ensure that you include the last observations "
|
||||
"of all live agents when setting done[__all__] to "
|
||||
"True. Alternatively, set no_done_at_end=True to "
|
||||
"allow this.")
|
||||
# If (only this?) agent is done, erase its buffer entirely.
|
||||
if pre_batch[SampleBatch.DONES][-1]:
|
||||
del self.agent_collectors[(episode_id, agent_id)]
|
||||
|
||||
other_batches = pre_batches.copy()
|
||||
del other_batches[agent_id]
|
||||
policy = self.policy_map[self.agent_key_to_policy[(episode_id,
|
||||
agent_id)]]
|
||||
if any(pre_batch["dones"][:-1]) or len(set(
|
||||
pre_batch["eps_id"])) > 1:
|
||||
raise ValueError(
|
||||
"Batches sent to postprocessing must only contain steps "
|
||||
"from a single trajectory.", pre_batch)
|
||||
# Call the Policy's Exploration's postprocess method.
|
||||
post_batches[agent_id] = pre_batch
|
||||
if getattr(policy, "exploration", None) is not None:
|
||||
policy.exploration.postprocess_trajectory(
|
||||
policy, post_batches[agent_id],
|
||||
getattr(policy, "_sess", None))
|
||||
post_batches[agent_id] = policy.postprocess_trajectory(
|
||||
post_batches[agent_id], other_batches, episode)
|
||||
|
||||
if log_once("after_post"):
|
||||
logger.info(
|
||||
"Trajectory fragment after postprocess_trajectory():\n\n{}\n".
|
||||
format(summarize(post_batches)))
|
||||
|
||||
# Append into policy batches and reset.
|
||||
from ray.rllib.evaluation.rollout_worker import get_global_worker
|
||||
for agent_id, post_batch in sorted(post_batches.items()):
|
||||
pid = self.agent_key_to_policy[(episode_id, agent_id)]
|
||||
policy = self.policy_map[pid]
|
||||
self.callbacks.on_postprocess_trajectory(
|
||||
worker=get_global_worker(),
|
||||
episode=episode,
|
||||
agent_id=agent_id,
|
||||
policy_id=pid,
|
||||
policies=self.policy_map,
|
||||
postprocessed_batch=post_batch,
|
||||
original_batches=pre_batches)
|
||||
# Add the postprocessed SampleBatch to the policy collectors for
|
||||
# training.
|
||||
self.policy_collectors[pid].add_postprocessed_batch_for_training(
|
||||
post_batch, policy.view_requirements)
|
||||
|
||||
env_steps = self.episode_steps[episode_id]
|
||||
self.policy_collectors_env_steps += env_steps
|
||||
|
||||
if is_done:
|
||||
del self.episode_steps[episode_id]
|
||||
del self.episodes[episode_id]
|
||||
else:
|
||||
self.episode_steps[episode_id] = 0
|
||||
|
||||
@override(_SampleCollector)
|
||||
def build_multi_agent_batch(self, env_steps: int) -> \
|
||||
Union[MultiAgentBatch, SampleBatch]:
|
||||
ma_batch = MultiAgentBatch.wrap_as_needed(
|
||||
{
|
||||
pid: collector.build()
|
||||
for pid, collector in self.policy_collectors.items()
|
||||
if collector.count > 0
|
||||
},
|
||||
env_steps=env_steps)
|
||||
self.policy_collectors_env_steps = 0
|
||||
return ma_batch
|
||||
|
||||
@override(_SampleCollector)
|
||||
def try_build_truncated_episode_multi_agent_batch(self) -> \
|
||||
Union[MultiAgentBatch, SampleBatch, None]:
|
||||
# Have something to loop through, even if there are currently no
|
||||
# ongoing episodes.
|
||||
episode_steps = self.episode_steps or {"_fake_id": 0}
|
||||
# Loop through ongoing episodes and see whether their length plus
|
||||
# what's already in the policy collectors reaches the fragment-len.
|
||||
for episode_id, count in episode_steps.items():
|
||||
env_steps = self.policy_collectors_env_steps + count
|
||||
# Reached the fragment-len -> We should build an MA-Batch.
|
||||
if env_steps >= self.rollout_fragment_length:
|
||||
# If we reached the fragment-len only because of `episode_id`
|
||||
# (still ongoing) -> postprocess `episode_id` first.
|
||||
if self.policy_collectors_env_steps < \
|
||||
self.rollout_fragment_length:
|
||||
self.postprocess_episode(
|
||||
self.episodes[episode_id], is_done=False)
|
||||
# Otherwise, create MA-batch only from what's already in our
|
||||
# policy buffers (do not include `episode_id`'s data).
|
||||
else:
|
||||
env_steps = self.policy_collectors_env_steps
|
||||
# Build the MA-batch and return.
|
||||
ma_batch = self.build_multi_agent_batch(env_steps=env_steps)
|
||||
return ma_batch
|
||||
return None
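
# Numeric walk-through (illustrative): with rollout_fragment_length=200,
# policy_collectors_env_steps=150 and an ongoing episode at 60 steps,
# env_steps == 210 >= 200; since 150 < 200, that episode is postprocessed
# first (is_done=False) and the returned MultiAgentBatch covers all 210
# env steps.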
|
||||
|
||||
def _add_to_next_inference_call(self, agent_key: Tuple[EpisodeID, AgentID],
|
||||
env_id: EnvID) -> None:
|
||||
"""Adds an Agent key (episode+agent IDs) to the next inference call.
|
||||
|
||||
This makes sure that the agent's current data (in the trajectory) is
|
||||
used for generating the next input_dict for a
|
||||
`Policy.compute_actions()` call.
|
||||
|
||||
Args:
|
||||
agent_key (Tuple[EpisodeID, AgentID]: A unique agent key (across
|
||||
vectorized environments).
|
||||
env_id (EnvID): The environment index (in a vectorized setup).
|
||||
"""
|
||||
policy_id = self.agent_key_to_policy[agent_key]
|
||||
idx = self.forward_pass_size[policy_id]
|
||||
if idx == 0:
|
||||
self.forward_pass_agent_keys[policy_id].clear()
|
||||
self.forward_pass_agent_keys[policy_id].append(agent_key)
|
||||
self.forward_pass_size[policy_id] += 1
|
||||
|
||||
def _reset_inference_calls(self, policy_id: PolicyID) -> None:
|
||||
"""Resets internal inference input-dict registries.
|
||||
|
||||
Calling `self.get_inference_input_dict()` after this method is called
|
||||
would return an empty input-dict.
|
||||
|
||||
Args:
|
||||
policy_id (PolicyID): The policy ID for which to reset the
|
||||
inference pointers.
|
||||
"""
|
||||
self.forward_pass_size[policy_id] = 0
|
|
@ -1,249 +0,0 @@
|
|||
import logging
|
||||
from typing import Dict, Optional, TYPE_CHECKING
|
||||
|
||||
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
from ray.rllib.evaluation.per_policy_sample_collector import \
|
||||
_PerPolicySampleCollector
|
||||
from ray.rllib.evaluation.sample_collector import _SampleCollector
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import MultiAgentBatch
|
||||
from ray.rllib.utils import force_list
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.debug import summarize
|
||||
from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, \
|
||||
TensorType
|
||||
from ray.util.debug import log_once
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.rllib.agents.callbacks import DefaultCallbacks
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _MultiAgentSampleCollector(_SampleCollector):
|
||||
"""Builds SampleBatches for each policy (and agent) in a multi-agent env.
|
||||
|
||||
Note: This is an experimental class only used when
|
||||
`config._use_trajectory_view_api` = True.
|
||||
Once `_use_trajectory_view_api` becomes the default in configs:
|
||||
This class will deprecate the `SampleBatchBuilder` class.
|
||||
|
||||
Input data is collected in central per-policy buffers, which
|
||||
efficiently pre-allocate memory (over n timesteps) and re-use the same
|
||||
memory even for succeeding agents and episodes.
|
||||
Input_dicts for action computations, SampleBatches for postprocessing, and
|
||||
train_batch dicts are - if possible - created from the central per-policy
|
||||
buffers via views to avoid copying of data.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy_map: Dict[PolicyID, Policy],
|
||||
callbacks: "DefaultCallbacks",
|
||||
# TODO: (sven) make `num_agents` flexibly grow in size.
|
||||
num_agents: int = 100,
|
||||
num_timesteps=None,
|
||||
time_major: Optional[bool] = False):
|
||||
"""Initializes a _MultiAgentSampleCollector object.
|
||||
|
||||
Args:
|
||||
policy_map (Dict[PolicyID,Policy]): Maps policy ids to policy
|
||||
instances.
|
||||
callbacks (DefaultCallbacks): RLlib callbacks (configured in the
|
||||
Trainer config dict). Used for trajectory postprocessing event.
|
||||
num_agents (int): The max number of agent slots to pre-allocate
|
||||
in the buffer.
|
||||
num_timesteps (int): The max number of timesteps to pre-allocate
|
||||
in the buffer.
|
||||
time_major (Optional[bool]): Whether to preallocate buffers and
|
||||
collect samples in time-major fashion (TxBx...).
|
||||
"""
|
||||
|
||||
self.policy_map = policy_map
|
||||
self.callbacks = callbacks
|
||||
if num_agents == float("inf") or num_agents is None:
|
||||
num_agents = 1000
|
||||
self.num_agents = int(num_agents)
|
||||
|
||||
# Collect SampleBatches per-policy in _PerPolicySampleCollectors.
|
||||
self.policy_sample_collectors = {}
|
||||
for pid, policy in policy_map.items():
|
||||
# Figure out max-shifts (before and after).
|
||||
view_reqs = policy.training_view_requirements
|
||||
max_shift_before = 0
|
||||
max_shift_after = 0
|
||||
for vr in view_reqs.values():
|
||||
shift = force_list(vr.shift)
|
||||
if max_shift_before > shift[0]:
|
||||
max_shift_before = shift[0]
|
||||
if max_shift_after < shift[-1]:
|
||||
max_shift_after = shift[-1]
|
||||
# Figure out num_timesteps and num_agents.
|
||||
kwargs = {"time_major": time_major}
|
||||
if policy.is_recurrent():
|
||||
kwargs["num_timesteps"] = \
|
||||
policy.config["model"]["max_seq_len"]
|
||||
kwargs["time_major"] = True
|
||||
elif num_timesteps is not None:
|
||||
kwargs["num_timesteps"] = num_timesteps
|
||||
|
||||
self.policy_sample_collectors[pid] = _PerPolicySampleCollector(
|
||||
num_agents=self.num_agents,
|
||||
shift_before=-max_shift_before,
|
||||
shift_after=max_shift_after,
|
||||
**kwargs)
|
||||
|
||||
# Internal agent-to-policy map.
|
||||
self.agent_to_policy = {}
|
||||
# Number of "inference" steps taken in the environment.
|
||||
# Regardless of the number of agents involved in each of these steps.
|
||||
self.count = 0
|
||||
|
||||
@override(_SampleCollector)
|
||||
def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID,
|
||||
env_id: EnvID, policy_id: PolicyID,
|
||||
obs: TensorType) -> None:
|
||||
# Make sure our mappings are up to date.
|
||||
if agent_id not in self.agent_to_policy:
|
||||
self.agent_to_policy[agent_id] = policy_id
|
||||
else:
|
||||
assert self.agent_to_policy[agent_id] == policy_id
|
||||
|
||||
# Add initial obs to Trajectory.
|
||||
self.policy_sample_collectors[policy_id].add_init_obs(
|
||||
episode_id, agent_id, env_id, chunk_num=0, init_obs=obs)
|
||||
|
||||
@override(_SampleCollector)
|
||||
def add_action_reward_next_obs(self, episode_id: EpisodeID,
|
||||
agent_id: AgentID, env_id: EnvID,
|
||||
policy_id: PolicyID, agent_done: bool,
|
||||
values: Dict[str, TensorType]) -> None:
|
||||
assert policy_id in self.policy_sample_collectors
|
||||
|
||||
# Make sure our mappings are up to date.
|
||||
if agent_id not in self.agent_to_policy:
|
||||
self.agent_to_policy[agent_id] = policy_id
|
||||
else:
|
||||
assert self.agent_to_policy[agent_id] == policy_id
|
||||
|
||||
# Include the current agent id for multi-agent algorithms.
|
||||
if agent_id != _DUMMY_AGENT_ID:
|
||||
values["agent_id"] = agent_id
|
||||
|
||||
# Add action/reward/next-obs (and other data) to Trajectory.
|
||||
self.policy_sample_collectors[policy_id].add_action_reward_next_obs(
|
||||
episode_id, agent_id, env_id, agent_done, values)
|
||||
|
||||
@override(_SampleCollector)
|
||||
def total_env_steps(self) -> int:
|
||||
return sum(a.timesteps_since_last_reset
|
||||
for a in self.policy_sample_collectors.values())
|
||||
|
||||
def total(self):
|
||||
# TODO: (sven) deprecate; use `self.total_env_steps`, instead.
|
||||
# Sampler is currently still using `total()`.
|
||||
return self.total_env_steps()
|
||||
|
||||
@override(_SampleCollector)
|
||||
def get_inference_input_dict(self, policy_id: PolicyID) -> \
|
||||
Dict[str, TensorType]:
|
||||
policy = self.policy_map[policy_id]
|
||||
view_reqs = policy.model.inference_view_requirements
|
||||
return self.policy_sample_collectors[
|
||||
policy_id].get_inference_input_dict(view_reqs)
|
||||
|
||||
@override(_SampleCollector)
|
||||
def has_non_postprocessed_data(self) -> bool:
|
||||
return self.total_env_steps() > 0
|
||||
|
||||
@override(_SampleCollector)
|
||||
def postprocess_trajectories_so_far(
|
||||
self, episode: Optional[MultiAgentEpisode] = None) -> None:
|
||||
# Loop through each per-policy collector and create a view (for each
|
||||
# agent as SampleBatch) from its buffers for post-processing
|
||||
all_agent_batches = {}
|
||||
for pid, rc in self.policy_sample_collectors.items():
|
||||
policy = self.policy_map[pid]
|
||||
view_reqs = policy.training_view_requirements
|
||||
agent_batches = rc.get_postprocessing_sample_batches(
|
||||
episode, view_reqs)
|
||||
|
||||
for agent_key, batch in agent_batches.items():
|
||||
other_batches = None
|
||||
if len(agent_batches) > 1:
|
||||
other_batches = agent_batches.copy()
|
||||
del other_batches[agent_key]
|
||||
|
||||
agent_batches[agent_key] = policy.postprocess_trajectory(
|
||||
batch, other_batches, episode)
|
||||
# Call the Policy's Exploration's postprocess method.
|
||||
if getattr(policy, "exploration", None) is not None:
|
||||
agent_batches[
|
||||
agent_key] = policy.exploration.postprocess_trajectory(
|
||||
policy, agent_batches[agent_key],
|
||||
getattr(policy, "_sess", None))
|
||||
|
||||
# Add new columns' data to buffers.
|
||||
for col in agent_batches[agent_key].new_columns:
|
||||
data = agent_batches[agent_key].data[col]
|
||||
rc._build_buffers({col: data[0]})
|
||||
timesteps = data.shape[0]
|
||||
rc.buffers[col][rc.shift_before:rc.shift_before +
|
||||
timesteps, rc.agent_key_to_slot[
|
||||
agent_key]] = data
|
||||
|
||||
all_agent_batches.update(agent_batches)
|
||||
|
||||
if log_once("after_post"):
|
||||
logger.info("Trajectory fragment after postprocess_trajectory():"
|
||||
"\n\n{}\n".format(summarize(all_agent_batches)))
|
||||
|
||||
# Append into policy batches and reset
|
||||
from ray.rllib.evaluation.rollout_worker import get_global_worker
|
||||
for agent_key, batch in sorted(all_agent_batches.items()):
|
||||
self.callbacks.on_postprocess_trajectory(
|
||||
worker=get_global_worker(),
|
||||
episode=episode,
|
||||
agent_id=agent_key[0],
|
||||
policy_id=self.agent_to_policy[agent_key[0]],
|
||||
policies=self.policy_map,
|
||||
postprocessed_batch=batch,
|
||||
original_batches=None) # TODO: (sven) do we really need this?
|
||||
|
||||
@override(_SampleCollector)
|
||||
def check_missing_dones(self, episode_id: EpisodeID) -> None:
|
||||
for pid, rc in self.policy_sample_collectors.items():
|
||||
for agent_key in rc.agent_key_to_slot.keys():
|
||||
# Only check for given episode and only for last chunk
|
||||
# (all previous chunks for that agent in the episode are
|
||||
# non-terminal).
|
||||
if (agent_key[1] == episode_id
|
||||
and rc.agent_key_to_chunk_num[agent_key[:2]] ==
|
||||
agent_key[2]):
|
||||
t = rc.agent_key_to_timestep[agent_key] - 1
|
||||
b = rc.agent_key_to_slot[agent_key]
|
||||
if not rc.buffers["dones"][t][b]:
|
||||
raise ValueError(
|
||||
"Episode {} terminated for all agents, but we "
|
||||
"still don't have a last observation for "
|
||||
"agent {} (policy {}). ".format(agent_key[0], pid)
|
||||
+ "Please ensure that you include the last "
|
||||
"observations of all live agents when setting "
|
||||
"'__all__' done to True. Alternatively, set "
|
||||
"no_done_at_end=True to allow this.")
|
||||
|
||||
@override(_SampleCollector)
|
||||
def get_multi_agent_batch_and_reset(self):
|
||||
self.postprocess_trajectories_so_far()
|
||||
policy_batches = {}
|
||||
for pid, rc in self.policy_sample_collectors.items():
|
||||
policy = self.policy_map[pid]
|
||||
view_reqs = policy.training_view_requirements
|
||||
policy_batches[pid] = rc.get_train_sample_batch_and_reset(
|
||||
view_reqs)
|
||||
|
||||
ma_batch = MultiAgentBatch.wrap_as_needed(policy_batches, self.count)
|
||||
# Reset our across-all-agents env step count.
|
||||
self.count = 0
|
||||
return ma_batch
|
|
@ -1,499 +0,0 @@
|
|||
import logging
|
||||
import numpy as np
|
||||
from typing import Dict, Optional
|
||||
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.view_requirement import ViewRequirement
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, TensorType
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
torch, _ = try_import_torch()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _PerPolicySampleCollector:
|
||||
"""A class for efficiently collecting samples for a single (fixed) policy.
|
||||
|
||||
Can be used by a _MultiAgentSampleCollector for its different policies.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_agents: Optional[int] = None,
|
||||
num_timesteps: Optional[int] = None,
|
||||
time_major: bool = True,
|
||||
shift_before: int = 0,
|
||||
shift_after: int = 0):
|
||||
"""Initializes a _PerPolicySampleCollector object.
|
||||
|
||||
Args:
|
||||
num_agents (int): The max number of agent slots to pre-allocate
|
||||
in the buffer.
|
||||
num_timesteps (int): The max number of timesteps to pre-allocate
|
||||
in the buffer.
|
||||
time_major (Optional[bool]): Whether to preallocate buffers and
|
||||
collect samples in time-major fashion (TxBx...).
|
||||
shift_before (int): The additional number of time slots to
|
||||
pre-allocate at the beginning of a time window (for possible
|
||||
underlying data column shifts, e.g. PREV_ACTIONS).
|
||||
shift_after (int): The additional number of time slots to
|
||||
pre-allocate at the end of a time window (for possible
|
||||
underlying data column shifts, e.g. NEXT_OBS).
|
||||
"""
|
||||
|
||||
self.num_agents = num_agents or 100
|
||||
self.num_timesteps = num_timesteps
|
||||
self.time_major = time_major
|
||||
# `shift_before` must be at least 1 for the init obs timestep.
|
||||
self.shift_before = max(shift_before, 1)
|
||||
self.shift_after = shift_after
|
||||
|
||||
# The offset on the agent dim to start the next SampleBatch build from.
|
||||
self.sample_batch_offset = 0
|
||||
|
||||
# The actual underlying data-buffers.
|
||||
self.buffers = {}
|
||||
self.postprocessed_agents = [False] * self.num_agents
|
||||
|
||||
# Next agent-slot to be used by a new agent/env combination.
|
||||
self.agent_slot_cursor = 0
|
||||
# Maps agent/episode ID/chunk-num to an agent slot.
|
||||
self.agent_key_to_slot = {}
|
||||
# Maps agent/episode ID to the last chunk-num.
|
||||
self.agent_key_to_chunk_num = {}
|
||||
# Maps agent slot number to agent keys.
|
||||
self.slot_to_agent_key = [None] * self.num_agents
|
||||
# Maps agent/episode ID/chunk-num to a time step cursor.
|
||||
self.agent_key_to_timestep = {}
|
||||
|
||||
# Total timesteps taken in the env over all agents since last reset.
|
||||
self.timesteps_since_last_reset = 0
|
||||
|
||||
# Indices (T,B) to pick from the buffers for the next forward pass.
|
||||
self.forward_pass_indices = [[], []]
|
||||
self.forward_pass_size = 0
|
||||
# Maps index from the forward pass batch to (agent_id, episode_id,
|
||||
# env_id) tuple.
|
||||
self.forward_pass_index_to_agent_info = {}
|
||||
self.agent_key_to_forward_pass_index = {}
|
||||
|
||||
    def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID,
                     env_id: EnvID, chunk_num: int,
                     init_obs: TensorType) -> None:
        """Adds a single initial observation (after env.reset()) to the buffer.

        Args:
            episode_id (EpisodeID): Unique ID for the episode we are adding the
                initial observation for.
            agent_id (AgentID): Unique ID for the agent we are adding the
                initial observation for.
            env_id (EnvID): The env ID to which `init_obs` belongs.
            chunk_num (int): The time-chunk number (0-based). Some episodes
                may last for longer than self.num_timesteps and therefore
                have to be chopped into chunks.
            init_obs (TensorType): Initial observation (after env.reset()).
        """
        agent_key = (agent_id, episode_id, chunk_num)
        agent_slot = self.agent_slot_cursor
        self.agent_key_to_slot[agent_key] = agent_slot
        self.agent_key_to_chunk_num[agent_key[:2]] = chunk_num
        self.slot_to_agent_key[agent_slot] = agent_key
        self._next_agent_slot()

        if SampleBatch.OBS not in self.buffers:
            self._build_buffers(
                single_row={
                    SampleBatch.OBS: init_obs,
                    SampleBatch.EPS_ID: episode_id,
                    SampleBatch.AGENT_INDEX: agent_id,
                    "env_id": env_id,
                })
        if self.time_major:
            self.buffers[SampleBatch.OBS][self.shift_before - 1, agent_slot] = \
                init_obs
        else:
            self.buffers[SampleBatch.OBS][agent_slot, self.shift_before - 1] = \
                init_obs
        self.agent_key_to_timestep[agent_key] = self.shift_before

        self._add_to_next_inference_call(agent_key, env_id, agent_slot,
                                         self.shift_before - 1)

    def add_action_reward_next_obs(
            self, episode_id: EpisodeID, agent_id: AgentID, env_id: EnvID,
            agent_done: bool, values: Dict[str, TensorType]) -> None:
        """Add the given dictionary (row) of values to this batch.

        Args:
            episode_id (EpisodeID): Unique ID for the episode we are adding the
                values for.
            agent_id (AgentID): Unique ID for the agent we are adding the
                values for.
            env_id (EnvID): The env ID to which the given data belongs.
            agent_done (bool): Whether next obs should not be used for an
                upcoming inference call. Default: False = next-obs should be
                used for upcoming inference.
            values (Dict[str, TensorType]): Data dict (interpreted as a single
                row) to be added to buffer. Must contain keys:
                SampleBatch.ACTIONS, REWARDS, DONES, and NEXT_OBS.
        """
        assert (SampleBatch.ACTIONS in values and SampleBatch.REWARDS in values
                and SampleBatch.NEXT_OBS in values
                and SampleBatch.DONES in values)

        assert SampleBatch.OBS not in values
        values[SampleBatch.OBS] = values[SampleBatch.NEXT_OBS]
        del values[SampleBatch.NEXT_OBS]

        chunk_num = self.agent_key_to_chunk_num[(agent_id, episode_id)]
        agent_key = (agent_id, episode_id, chunk_num)
        agent_slot = self.agent_key_to_slot[agent_key]
        ts = self.agent_key_to_timestep[agent_key]
        for k, v in values.items():
            if k not in self.buffers:
                self._build_buffers(single_row=values)
            if self.time_major:
                self.buffers[k][ts, agent_slot] = v
            else:
                self.buffers[k][agent_slot, ts] = v
        self.agent_key_to_timestep[agent_key] += 1

        # Time-axis is "full" -> Cut-over to new chunk (only if not DONE).
        if self.agent_key_to_timestep[
                agent_key] - self.shift_before == self.num_timesteps and \
                not values[SampleBatch.DONES]:
            self._new_chunk_from(agent_slot, agent_key,
                                 self.agent_key_to_timestep[agent_key])

        self.timesteps_since_last_reset += 1

        if not agent_done:
            self._add_to_next_inference_call(agent_key, env_id, agent_slot, ts)
    def get_inference_input_dict(self, view_reqs: Dict[str, ViewRequirement]
                                 ) -> Dict[str, TensorType]:
        """Returns an input_dict for an (inference) forward pass.

        The input_dict can then be used for action computations inside a
        Policy via `Policy.compute_actions_from_input_dict()`.

        Args:
            view_reqs (Dict[str, ViewRequirement]): The view requirements
                dict to use.

        Returns:
            Dict[str, TensorType]: The input_dict to be passed into the ModelV2
                for inference/training.

        Examples:
            >>> obs, r, done, info = env.step(action)
            >>> collector.add_action_reward_next_obs(
            ...     12345, 0, 0, done, {
            ...         "actions": action, "new_obs": obs,
            ...         "rewards": r, "dones": done})
            >>> input_dict = collector.get_inference_input_dict(
            ...     policy.model.inference_view_requirements)
            >>> action = policy.compute_actions_from_input_dict(input_dict)
            >>> # repeat
        """
        input_dict = {}
        for view_col, view_req in view_reqs.items():
            # Create the batch of data from the different buffers.
            data_col = view_req.data_col or view_col
            if data_col not in self.buffers:
                self._build_buffers({data_col: view_req.space.sample()})

            indices = self.forward_pass_indices
            if self.time_major:
                input_dict[view_col] = self.buffers[data_col][indices]
            else:
                if isinstance(view_req.shift, (list, tuple)):
                    time_indices = \
                        np.array(view_req.shift) + np.array(indices[0])
                    input_dict[view_col] = self.buffers[data_col][indices[1],
                                                                  time_indices]
                else:
                    input_dict[view_col] = \
                        self.buffers[data_col][indices[1], indices[0]]

        self._reset_inference_call()

        return input_dict
    def get_postprocessing_sample_batches(
            self,
            episode: MultiAgentEpisode,
            view_reqs: Dict[str, ViewRequirement]) -> \
            Dict[AgentID, SampleBatch]:
        """Returns a SampleBatch object ready for postprocessing.

        Args:
            episode (MultiAgentEpisode): The MultiAgentEpisode object to
                get the to-be-postprocessed SampleBatches for.
            view_reqs (Dict[str, ViewRequirement]): The view requirements dict
                to use for creating the SampleBatch from our buffers.

        Returns:
            Dict[AgentID, SampleBatch]: The sample batch objects to be passed
                to `Policy.postprocess_trajectory()`.
        """
        # Loop through all agents and create a SampleBatch
        # (as "view"; no copying).

        # Construct the SampleBatch-dict.
        sample_batch_data = {}

        range_ = self.agent_slot_cursor - self.sample_batch_offset
        if range_ < 0:
            range_ = self.num_agents + range_
        for i in range(range_):
            agent_slot = self.sample_batch_offset + i
            if agent_slot >= self.num_agents:
                agent_slot = agent_slot % self.num_agents
            # Do not postprocess the same slot twice.
            if self.postprocessed_agents[agent_slot]:
                continue
            agent_key = self.slot_to_agent_key[agent_slot]
            # Skip other episodes (if episode provided).
            if episode and agent_key[1] != episode.episode_id:
                continue
            end = self.agent_key_to_timestep[agent_key]
            # Do not build any empty SampleBatches.
            if end == self.shift_before:
                continue
            self.postprocessed_agents[agent_slot] = True

            assert agent_key not in sample_batch_data
            sample_batch_data[agent_key] = {}
            batch = sample_batch_data[agent_key]

            for view_col, view_req in view_reqs.items():
                data_col = view_req.data_col or view_col
                # Skip columns that will only get added through postprocessing
                # (these may not even exist yet).
                if data_col not in self.buffers:
                    continue

                shift = view_req.shift
                if data_col == SampleBatch.OBS:
                    shift -= 1

                batch[view_col] = self.buffers[data_col][
                    self.shift_before + shift:end + shift, agent_slot]

        batches = {}
        for agent_key, data in sample_batch_data.items():
            batches[agent_key] = SampleBatch(data)
        return batches
    def get_train_sample_batch_and_reset(self, view_reqs) -> SampleBatch:
        """Returns the accumulated sample batch for this policy.

        This is usually called to collect samples for policy training.

        Returns:
            SampleBatch: Returns the accumulated sample batch for this
                policy.
        """
        seq_lens_w_0s = [
            self.agent_key_to_timestep[k] - self.shift_before
            for k in self.slot_to_agent_key if k is not None
        ]
        # We have an agent-axis buffer "rollover" (new SampleBatch will be
        # built from last n agent records plus first m agent records in
        # buffer).
        if self.agent_slot_cursor < self.sample_batch_offset:
            rollover = -(self.num_agents - self.sample_batch_offset)
            seq_lens_w_0s = seq_lens_w_0s[rollover:] + seq_lens_w_0s[:rollover]
        first_zero_len = len(seq_lens_w_0s)
        if seq_lens_w_0s[-1] == 0:
            first_zero_len = seq_lens_w_0s.index(0)
            # Assert that all zeros lie at the end of the seq_lens array.
            assert all(seq_lens_w_0s[i] == 0
                       for i in range(first_zero_len, len(seq_lens_w_0s)))

        t_start = self.shift_before
        t_end = t_start + self.num_timesteps

        # The agent_slot cursor that points to the newest agent-slot that
        # actually already has at least 1 timestep of data (thus it excludes
        # just-rolled over chunks (which only have the initial obs in them)).
        valid_agent_cursor = \
            (self.agent_slot_cursor -
             (len(seq_lens_w_0s) - first_zero_len)) % self.num_agents

        # Construct the view dict.
        view = {}
        for view_col, view_req in view_reqs.items():
            data_col = view_req.data_col or view_col
            assert data_col in self.buffers
            # For OBS, indices must be shifted by -1.
            shift = view_req.shift
            shift += 0 if data_col != SampleBatch.OBS else -1
            # If agent_slot has been rolled-over to beginning, we have to copy
            # here.
            if valid_agent_cursor < self.sample_batch_offset:
                time_slice = self.buffers[data_col][t_start + shift:t_end +
                                                    shift]
                one_ = time_slice[:, self.sample_batch_offset:]
                two_ = time_slice[:, :valid_agent_cursor]
                if torch and isinstance(time_slice, torch.Tensor):
                    view[view_col] = torch.cat([one_, two_], dim=1)
                else:
                    view[view_col] = np.concatenate([one_, two_], axis=1)
            else:
                view[view_col] = \
                    self.buffers[data_col][
                        t_start + shift:t_end + shift,
                        self.sample_batch_offset:valid_agent_cursor]

        # Copy all still ongoing trajectories to new agent slots
        # (including the ones that just started (are seq_len=0)).
        new_chunk_args = []
        for i, seq_len in enumerate(seq_lens_w_0s):
            if seq_len < self.num_timesteps:
                agent_slot = (self.sample_batch_offset + i) % self.num_agents
                if not self.buffers[SampleBatch.
                                    DONES][seq_len - 1 +
                                           self.shift_before][agent_slot]:
                    agent_key = self.slot_to_agent_key[agent_slot]
                    new_chunk_args.append(
                        (agent_slot, agent_key,
                         self.agent_key_to_timestep[agent_key]))
        # Cut out all 0 seq-lens.
        seq_lens = seq_lens_w_0s[:first_zero_len]
        batch = SampleBatch(
            view, _seq_lens=np.array(seq_lens), _time_major=self.time_major)

        # Reset everything for new data.
        self.postprocessed_agents = [False] * self.num_agents
        self.agent_key_to_slot.clear()
        self.agent_key_to_chunk_num.clear()
        self.slot_to_agent_key = [None] * self.num_agents
        self.agent_key_to_timestep.clear()
        self.timesteps_since_last_reset = 0
        self.forward_pass_size = 0
        self.sample_batch_offset = self.agent_slot_cursor

        for args in new_chunk_args:
            self._new_chunk_from(*args)

        return batch
    def _build_buffers(self, single_row: Dict[str, TensorType]) -> None:
        """Builds the internal data buffers based on a single given row.

        This may be called several times in the lifetime of this instance
        to add new columns to the buffer. Columns in `single_row` that already
        exist in the buffer will be ignored.

        Args:
            single_row (Dict[str, TensorType]): A single data row with one or
                more columns (str as key, np.ndarray|tensor as data) to be used
                as template to build the pre-allocated buffer.
        """
        time_size = self.num_timesteps + self.shift_before + self.shift_after
        for col, data in single_row.items():
            if col in self.buffers:
                continue
            base_shape = (time_size, self.num_agents) if self.time_major else \
                (self.num_agents, time_size)
            # Python primitive -> np.array.
            if isinstance(data, (int, float, bool)):
                t_ = type(data)
                dtype = np.float32 if t_ == float else \
                    np.int32 if type(data) == int else np.bool_
                self.buffers[col] = np.zeros(shape=base_shape, dtype=dtype)
            # np.ndarray, torch.Tensor, or tf.Tensor.
            else:
                shape = base_shape + data.shape
                dtype = data.dtype
                if torch and isinstance(data, torch.Tensor):
                    self.buffers[col] = torch.zeros(
                        *shape, dtype=dtype, device=data.device)
                elif tf and isinstance(data, tf.Tensor):
                    self.buffers[col] = tf.zeros(shape=shape, dtype=dtype)
                else:
                    self.buffers[col] = np.zeros(shape=shape, dtype=dtype)

    def _next_agent_slot(self):
        """Starts a new agent slot at the end of the agent-axis.

        Also makes sure the new slot is not taken yet.
        """
        self.agent_slot_cursor += 1
        if self.agent_slot_cursor >= self.num_agents:
            self.agent_slot_cursor = 0
        # Just make sure there is space in our buffer.
        assert self.slot_to_agent_key[self.agent_slot_cursor] is None

    def _new_chunk_from(self, agent_slot, agent_key, timestep):
        """Creates a new time-window (chunk) given an agent.

        The agent may already have an unfinished episode going on (in a
        previous chunk). The end of that previous chunk will be copied to the
        beginning of the new one for proper data-shift handling (e.g.
        PREV_ACTIONS/REWARDS).

        Args:
            agent_slot (int): The agent to start a new chunk for (from an
                ongoing episode (chunk)).
            agent_key (Tuple[AgentID, EpisodeID, int]): The internal key to
                identify an active agent in some episode.
            timestep (int): The timestep in the old chunk being continued.
        """
        new_agent_slot = self.agent_slot_cursor
        # Increase chunk num by 1.
        new_agent_key = agent_key[:2] + (agent_key[2] + 1, )
        # Copy relevant timesteps at end of old chunk into new one.
        if self.time_major:
            for k in self.buffers.keys():
                self.buffers[k][0:self.shift_before, new_agent_slot] = \
                    self.buffers[k][
                        timestep - self.shift_before:timestep, agent_slot]
        else:
            for k in self.buffers.keys():
                self.buffers[k][new_agent_slot, 0:self.shift_before] = \
                    self.buffers[k][
                        agent_slot, timestep - self.shift_before:timestep]

        self.agent_key_to_slot[new_agent_key] = new_agent_slot
        self.agent_key_to_chunk_num[new_agent_key[:2]] = new_agent_key[2]
        self.slot_to_agent_key[new_agent_slot] = new_agent_key
        self._next_agent_slot()
        self.agent_key_to_timestep[new_agent_key] = self.shift_before

    def _add_to_next_inference_call(self, agent_key, env_id, agent_slot,
                                    timestep):
        """Registers given T and B (agent_slot) for get_inference_input_dict.

        Calling `get_inference_input_dict` will produce an input_dict (for
        Policy.compute_actions_from_input_dict) with all registered agent/time
        indices and then automatically reset the registry.

        Args:
            agent_key (Tuple[AgentID, EpisodeID, int]): The internal key to
                identify an active agent in some episode.
            env_id (EnvID): The env ID of the given agent.
            agent_slot (int): The agent_slot to register (B axis).
            timestep (int): The timestep to register (T axis).
        """
        idx = self.forward_pass_size
        self.forward_pass_index_to_agent_info[idx] = (agent_key[0],
                                                      agent_key[1], env_id)
        self.agent_key_to_forward_pass_index[agent_key[:2]] = idx
        if self.forward_pass_size == 0:
            self.forward_pass_indices[0].clear()
            self.forward_pass_indices[1].clear()
        self.forward_pass_indices[0].append(timestep)
        self.forward_pass_indices[1].append(agent_slot)
        self.forward_pass_size += 1

    def _reset_inference_call(self):
        """Resets indices for the next inference call.

        After calling this, new calls to `add_init_obs()` and
        `add_action_reward_next_obs()` will count for the next input_dict
        returned by `get_inference_input_dict()`.
        """
        self.forward_pass_size = 0
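As an aside, here is a minimal usage sketch of the collector above for a hypothetical single-agent setup (the env shapes, the IDs, and the commented-out view-requirements lookup are placeholders for illustration, not part of this commit):

import numpy as np

# Small pre-allocated window: 4 agent slots x 8 timesteps (time-major).
collector = _PerPolicySampleCollector(
    num_agents=4, num_timesteps=8, time_major=True, shift_before=1)

# t=0: store the reset observation; this also registers the (T, B) index
# for the next inference call.
collector.add_init_obs(
    episode_id=0, agent_id=0, env_id=0, chunk_num=0,
    init_obs=np.zeros(4, dtype=np.float32))

# An input_dict for Policy.compute_actions_from_input_dict() would now be
# built from the model's view requirements, e.g.:
# input_dict = collector.get_inference_input_dict(
#     policy.model.inference_view_requirements)

# t=1: after env.step(), append one row; NEXT_OBS is written one slot ahead
# and becomes the OBS of the next timestep.
collector.add_action_reward_next_obs(
    episode_id=0, agent_id=0, env_id=0, agent_done=False,
    values={
        "actions": np.array([0]),
        "rewards": 1.0,
        "dones": False,
        "new_obs": np.ones(4, dtype=np.float32),
    })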
@@ -9,13 +9,14 @@ from typing import Any, Callable, Dict, List, Iterable, Optional, Set, Tuple,\
    TYPE_CHECKING, Union

from ray.util.debug import log_once
from ray.rllib.evaluation.collectors.sample_collector import \
    _SampleCollector
from ray.rllib.evaluation.collectors.simple_list_collector import \
    _SimpleListCollector
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.multi_agent_sample_collector import \
    _MultiAgentSampleCollector
from ray.rllib.evaluation.rollout_metrics import RolloutMetrics
from ray.rllib.evaluation.sample_batch_builder import \
    MultiAgentSampleBatchBuilder
from ray.rllib.evaluation.sample_collector import _SampleCollector
from ray.rllib.policy.policy import clip_action, Policy
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.models.preprocessors import Preprocessor

@@ -188,8 +189,9 @@ class SyncSampler(SamplerInput):
        self.extra_batches = queue.Queue()
        self.perf_stats = _PerfStats()
        if _use_trajectory_view_api:
            self.sample_collector = _MultiAgentSampleCollector(
                policies, callbacks)
            self.sample_collector = _SimpleListCollector(
                policies, clip_rewards, callbacks, multiple_episodes_in_batch,
                rollout_fragment_length)
        else:
            self.sample_collector = None

@@ -333,8 +335,9 @@ class AsyncSampler(threading.Thread, SamplerInput):
        self.observation_fn = observation_fn
        self._use_trajectory_view_api = _use_trajectory_view_api
        if _use_trajectory_view_api:
            self.sample_collector = _MultiAgentSampleCollector(
                policies, callbacks)
            self.sample_collector = _SimpleListCollector(
                policies, clip_rewards, callbacks, multiple_episodes_in_batch,
                rollout_fragment_length)
        else:
            self.sample_collector = None
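The branch above is selected solely by the `_use_trajectory_view_api` config flag. A minimal sketch of toggling it on a PPO config (the CartPole env and worker counts are illustrative only; per this commit, the flag is only supported for PyTorch so far):

import ray
from ray.rllib.agents import ppo

config = ppo.DEFAULT_CONFIG.copy()
config["framework"] = "torch"
config["num_workers"] = 0
# Route sampling through the new collector path.
config["_use_trajectory_view_api"] = True

ray.init()
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
print(trainer.train()["episode_reward_mean"])
trainer.stop()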
@@ -537,7 +540,6 @@ def _env_runner(

    active_episodes: Dict[str, MultiAgentEpisode] = \
        NewEpisodeDefaultDict(new_episode)
    eval_results = None

    while True:
        perf_stats.iters += 1

@@ -564,7 +566,6 @@ def _env_runner(
                base_env=base_env,
                policies=policies,
                active_episodes=active_episodes,
                prev_policy_outputs=eval_results,
                unfiltered_obs=unfiltered_obs,
                rewards=rewards,
                dones=dones,

@@ -572,13 +573,11 @@ def _env_runner(
                horizon=horizon,
                preprocessors=preprocessors,
                obs_filters=obs_filters,
                rollout_fragment_length=rollout_fragment_length,
                multiple_episodes_in_batch=multiple_episodes_in_batch,
                callbacks=callbacks,
                soft_horizon=soft_horizon,
                no_done_at_end=no_done_at_end,
                observation_fn=observation_fn,
                perf_stats=perf_stats,
                _sample_collector=_sample_collector,
            )
        else:

@@ -601,7 +600,6 @@ def _env_runner(
                soft_horizon=soft_horizon,
                no_done_at_end=no_done_at_end,
                observation_fn=observation_fn,
                perf_stats=perf_stats,
            )
        perf_stats.raw_obs_processing_time += time.time() - t1
        for o in outputs:

@@ -669,7 +667,6 @@ def _process_observations(
        soft_horizon: bool,
        no_done_at_end: bool,
        observation_fn: "ObservationFunction",
        perf_stats: _PerfStats,
) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[
        RolloutMetrics, SampleBatchType]]]:
    """Record new data from the environment and prepare for policy evaluation.

@@ -682,8 +679,6 @@ def _process_observations(
            SampleBatchBuilder object for recycling.
        active_episodes (Dict[str, MultiAgentEpisode]): Mapping from
            episode ID to currently ongoing MultiAgentEpisode object.
        prev_policy_outputs (Dict[str,List]): The prev policy output dict
            (by policy-id -> List[action, state outs, extra fetches]).
        unfiltered_obs (dict): Doubly keyed dict of env-ids -> agent ids
            -> unfiltered observation tensor, returned by a `BaseEnv.poll()`
            call.

@@ -862,6 +857,7 @@ def _process_observations(
            # and add it to "outputs".
            if (all_agents_done and not multiple_episodes_in_batch) or \
                    batch_builder.count >= rollout_fragment_length:
                batch_builder.postprocess_batch_so_far(episode)
                outputs.append(batch_builder.build_and_reset(episode))
            # Make sure postprocessor stays within one episode.
            elif all_agents_done:

@@ -887,10 +883,14 @@ def _process_observations(
                episode=episode,
                env_index=env_id,
            )
            # Horizon hit and we have a soft horizon (no hard env reset).
            if hit_horizon and soft_horizon:
                episode.soft_reset()
                resetted_obs: Dict[AgentID, EnvObsType] = agent_obs
            # Env actually ended OR horizon hit and no soft horizon ->
            # Try hard env-reset.
            else:
                # Remove episode from active ones.
                del active_episodes[env_id]
                resetted_obs: Dict[AgentID, EnvObsType] = base_env.try_reset(
                    env_id)

@@ -939,8 +939,6 @@ def _process_observations_w_trajectory_view_api(
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        active_episodes: Dict[str, MultiAgentEpisode],
        prev_policy_outputs: Dict[PolicyID, Tuple[TensorStructType, StateBatch,
                                                  dict]],
        unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
        rewards: Dict[EnvID, Dict[AgentID, float]],
        dones: Dict[EnvID, Dict[AgentID, bool]],

@@ -948,13 +946,11 @@ def _process_observations_w_trajectory_view_api(
        horizon: int,
        preprocessors: Dict[PolicyID, Preprocessor],
        obs_filters: Dict[PolicyID, Filter],
        rollout_fragment_length: int,
        multiple_episodes_in_batch: bool,
        callbacks: "DefaultCallbacks",
        soft_horizon: bool,
        no_done_at_end: bool,
        observation_fn: "ObservationFunction",
        perf_stats: _PerfStats,
        _sample_collector: _SampleCollector,
) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[
        RolloutMetrics, SampleBatchType]]]:

@@ -964,41 +960,20 @@ def _process_observations_w_trajectory_view_api(

    # Output objects.
    active_envs: Set[EnvID] = set()
    to_eval: Set[PolicyID] = set()
    to_eval: Dict[PolicyID, List[PolicyEvalData]] = defaultdict(list)
    outputs: List[Union[RolloutMetrics, SampleBatchType]] = []

    large_batch_threshold: int = max(1000, rollout_fragment_length * 10) if \
        rollout_fragment_length != float("inf") else 5000

    # For each environment.
    # For each (vectorized) sub-environment.
    # type: EnvID, Dict[AgentID, EnvObsType]
    for env_id, agent_obs in unfiltered_obs.items():
    for env_id, all_agents_obs in unfiltered_obs.items():
        is_new_episode: bool = env_id not in active_episodes
        episode: MultiAgentEpisode = active_episodes[env_id]

        if not is_new_episode:
            _sample_collector.episode_step(episode.episode_id)
            episode.length += 1
            _sample_collector.count += 1
            episode._add_agent_rewards(rewards[env_id])

        if (_sample_collector.total_env_steps() > large_batch_threshold
                and log_once("large_batch_warning")):
            logger.warning(
                "More than {} observations for {} env steps ".format(
                    _sample_collector.total_env_steps(),
                    _sample_collector.count) +
                "are buffered in the sampler. If this is more than you "
                "expected, check that you set a horizon on your "
                "environment correctly and that it terminates at some point. "
                "Note: In multi-agent environments, `rollout_fragment_length` "
                "sets the batch size based on (across-agents) environment "
                "steps, not the steps of individual agents, which can result "
                "in unexpectedly large batches." +
                ("Also, you may be in evaluation waiting for your Env to "
                 "terminate (batch_mode=`complete_episodes`). Make sure it "
                 "does at some point."
                 if not multiple_episodes_in_batch else ""))

        # Check episode termination conditions.
        if dones[env_id]["__all__"] or episode.length >= horizon:
            hit_horizon = (episode.length >= horizon
@@ -1023,19 +998,19 @@ def _process_observations_w_trajectory_view_api(

        # Custom observation function is applied before preprocessing.
        if observation_fn:
            agent_obs: Dict[AgentID, EnvObsType] = observation_fn(
                agent_obs=agent_obs,
            all_agents_obs: Dict[AgentID, EnvObsType] = observation_fn(
                agent_obs=all_agents_obs,
                worker=worker,
                base_env=base_env,
                policies=policies,
                episode=episode)
            if not isinstance(agent_obs, dict):
            if not isinstance(all_agents_obs, dict):
                raise ValueError(
                    "observe() must return a dict of agent observations")

        # For each agent in the environment.
        # type: AgentID, EnvObsType
        for agent_id, raw_obs in agent_obs.items():
        for agent_id, raw_obs in all_agents_obs.items():
            assert agent_id != "__all__"
            policy_id: PolicyID = episode.policy_for(agent_id)
            prep_obs: EnvObsType = _get_or_raise(preprocessors,

@@ -1058,38 +1033,41 @@ def _process_observations_w_trajectory_view_api(

            # Record transition info if applicable.
            if last_observation is None:
                _sample_collector.add_init_obs(episode.episode_id, agent_id,
                                               env_id, policy_id, filtered_obs)
                _sample_collector.add_init_obs(episode, agent_id, env_id,
                                               policy_id, filtered_obs)
            else:
                rc = _sample_collector.policy_sample_collectors[policy_id]
                eval_idx = rc.agent_key_to_forward_pass_index[(
                    agent_id, episode.episode_id)]
                # Add actions, rewards, next-obs to collectors.
                values_dict = {
                    "t": episode.length - 1,
                    "eps_id": episode.episode_id,
                    "env_id": env_id,
                    "agent_index": episode._agent_index(agent_id),
                    # Action (slot 0) taken at timestep t.
                    "actions": prev_policy_outputs[policy_id][0][eval_idx],
                    "actions": episode.last_action_for(agent_id),
                    # Reward received after taking a at timestep t.
                    "rewards": rewards[env_id][agent_id],
                    # After taking a, did we reach terminal?
                    # After taking action=a, did we reach terminal?
                    "dones": (False if (no_done_at_end
                              or (hit_horizon and soft_horizon)) else
                              agent_done),
                    # Next observation.
                    "new_obs": filtered_obs,
                }
                # TODO: (sven) add env infos to buffers as well.
                for k, v in prev_policy_outputs[policy_id][2].items():
                    values_dict[k] = v[eval_idx]
                for i, v in enumerate(prev_policy_outputs[policy_id][1]):
                    values_dict["state_out_{}".format(i)] = v[eval_idx]
                # Add extra-action-fetches to collectors.
                values_dict.update(**episode.last_pi_info_for(agent_id))
                _sample_collector.add_action_reward_next_obs(
                    episode.episode_id, agent_id, env_id, policy_id,
                    agent_done, values_dict)

            if not agent_done:
                to_eval.add(policy_id)
                item = PolicyEvalData(
                    env_id, agent_id, filtered_obs, infos[env_id].get(
                        agent_id, {}), None if last_observation is None else
                    episode.rnn_state_for(agent_id), None
                    if last_observation is None else
                    episode.last_action_for(agent_id),
                    rewards[env_id][agent_id] or 0.0)
                to_eval[policy_id].append(item)

        # Invoke the step callback after the step is logged to the episode
        callbacks.on_episode_step(

@@ -1098,36 +1076,22 @@ def _process_observations_w_trajectory_view_api(
            episode=episode,
            env_index=env_id)

        # Cut the batch if ...
        # - all-agents-done and not packing multiple episodes into one
        #   (batch_mode="complete_episodes")
        # - or if we've exceeded the rollout_fragment_length.
        if _sample_collector.has_non_postprocessed_data():
            # Sanity check, whether all agents have done=True, if done[__all__]
            # is True.
            if dones[env_id]["__all__"] and not no_done_at_end:
                _sample_collector.check_missing_dones(
                    episode_id=episode.episode_id)

        # Reached end of episode and we are not allowed to pack the
        # next episode into the same SampleBatch -> Build the SampleBatch
        # and add it to "outputs".
        if (all_agents_done and not multiple_episodes_in_batch) or \
                _sample_collector.count >= rollout_fragment_length:
            # TODO: (sven) Case: rollout_fragment_length reached: Do not
            #  store any data in `episode` anymore
            #  (useless for get_view_requirements when t<<-1, e.g.
            #  attention), but keep last episode data around in
            #  SampleBatchBuilder
            #  to be able to still reference into it
            #  should a model require this.
            outputs.append(_sample_collector.get_multi_agent_batch_and_reset())
        # Episode is done for all agents
        # (dones[__all__] == True or hit horizon).
        # Make sure postprocessor stays within one episode.
        elif all_agents_done:
            _sample_collector.postprocess_trajectories_so_far(episode)

        # Episode is done.
        if all_agents_done:
            is_done = dones[env_id]["__all__"]
            check_dones = is_done and not no_done_at_end
            _sample_collector.postprocess_episode(
                episode, is_done=is_done, check_dones=check_dones)
            # We are not allowed to pack the next episode into the same
            # SampleBatch (batch_mode=complete_episodes) -> Build the
            # MultiAgentBatch from a single episode and add it to "outputs".
            if not multiple_episodes_in_batch:
                ma_sample_batch = \
                    _sample_collector.build_multi_agent_batch(episode.length)
                outputs.append(ma_sample_batch)

            # Call each policy's Exploration.on_episode_end method.
            for p in policies.values():
                if getattr(p, "exploration", None) is not None:

@@ -1144,44 +1108,56 @@ def _process_observations_w_trajectory_view_api(
                episode=episode,
                env_index=env_id,
            )
            # Horizon hit and we have a soft horizon (no hard env reset).
            if hit_horizon and soft_horizon:
                episode.soft_reset()
                resetted_obs: Dict[AgentID, EnvObsType] = agent_obs
                resetted_obs: Dict[AgentID, EnvObsType] = all_agents_obs
            else:
                del active_episodes[env_id]
                resetted_obs: Dict[AgentID, EnvObsType] = base_env.try_reset(
                    env_id)
            if resetted_obs is None:
                # Reset not supported, drop this env from the ready list.
            if resetted_obs is None:
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment.")
                elif resetted_obs != ASYNC_RESET_RETURN:
                    # Creates a new episode if this is not async return.
                    # If reset is async, we will get its result in some future poll
                    episode: MultiAgentEpisode = active_episodes[env_id]
            # If reset is async, we will get its result in some future poll.
            elif resetted_obs != ASYNC_RESET_RETURN:
                new_episode: MultiAgentEpisode = active_episodes[env_id]
                if observation_fn:
                    resetted_obs: Dict[AgentID, EnvObsType] = observation_fn(
                        agent_obs=resetted_obs,
                        worker=worker,
                        base_env=base_env,
                        policies=policies,
                        episode=episode)
                        episode=new_episode)
                # type: AgentID, EnvObsType
                for agent_id, raw_obs in resetted_obs.items():
                    policy_id: PolicyID = episode.policy_for(agent_id)
                    policy_id: PolicyID = new_episode.policy_for(agent_id)
                    prep_obs: EnvObsType = _get_or_raise(
                        preprocessors, policy_id).transform(raw_obs)
                    filtered_obs: EnvObsType = _get_or_raise(
                        obs_filters, policy_id)(prep_obs)
                    episode._set_last_observation(agent_id, filtered_obs)
                    new_episode._set_last_observation(agent_id, filtered_obs)

                    # Add initial obs to buffer.
                    _sample_collector.add_init_obs(episode.episode_id,
                                                   agent_id, env_id, policy_id,
                                                   filtered_obs)
                    to_eval.add(policy_id)
                    _sample_collector.add_init_obs(
                        new_episode, agent_id, env_id, policy_id, filtered_obs)

                    item = PolicyEvalData(
                        env_id, agent_id, filtered_obs,
                        episode.last_info_for(agent_id) or {},
                        episode.rnn_state_for(agent_id), None, 0.0)
                    to_eval[policy_id].append(item)

    # Try to build something.
    if multiple_episodes_in_batch:
        sample_batch = \
            _sample_collector.try_build_truncated_episode_multi_agent_batch()
        if sample_batch is not None:
            outputs.append(sample_batch)

    return active_envs, to_eval, outputs


@@ -1306,7 +1282,7 @@ def _do_policy_eval_w_trajectory_view_api(
        logger.info("Inputs to compute_actions():\n\n{}\n".format(
            summarize(to_eval)))

    for policy_id in to_eval:
    for policy_id in to_eval.keys():
        policy: Policy = _get_or_raise(policies, policy_id)
        input_dict = _sample_collector.get_inference_input_dict(policy_id)
        eval_results[policy_id] = \

@@ -1373,7 +1349,7 @@ def _process_policy_eval_results(
        actions_to_send[env_id] = {}  # at minimum send empty dict

    # type: PolicyID, List[PolicyEvalData]
    for policy_id in to_eval:
    for policy_id, eval_data in to_eval.items():
        actions: TensorStructType = eval_results[policy_id][0]
        actions = convert_to_numpy(actions)

@@ -1385,10 +1361,11 @@ def _process_policy_eval_results(
        if isinstance(actions, list):
            actions = np.array(actions)

        # Add RNN state info.
        eval_data = None
        if not _use_trajectory_view_api:
            eval_data = to_eval[policy_id]
        # Store RNN state ins/outs and extra-action fetches to episode.
        if _use_trajectory_view_api:
            for f_i, column in enumerate(rnn_out_cols):
                pi_info_cols["state_out_{}".format(f_i)] = column
        else:
            rnn_in_cols: StateBatch = _to_column_format(
                [t.rnn_state for t in eval_data])

@@ -1413,14 +1390,6 @@ def _process_policy_eval_results(
        else:
            clipped_action = action

            # Trajectory View API: Do not store data directly in episode
            # (entire episode is stored in Trajectory and kept until
            # end of episode).
            if _use_trajectory_view_api:
                agent_id, episode_id, env_id = \
                    _sample_collector.policy_sample_collectors[
                        policy_id].forward_pass_index_to_agent_info[i]
            else:
                env_id: int = eval_data[i].env_id
                agent_id: AgentID = eval_data[i].agent_id
            episode: MultiAgentEpisode = active_episodes[env_id]

@@ -1430,8 +1399,8 @@ def _process_policy_eval_results(
                             for k, v in pi_info_cols.items()})
            if env_id in off_policy_actions and \
                    agent_id in off_policy_actions[env_id]:
                episode._set_last_action(
                    agent_id, off_policy_actions[env_id][agent_id])
                episode._set_last_action(agent_id,
                                         off_policy_actions[env_id][agent_id])
            else:
                episode._set_last_action(agent_id, action)
@@ -9,8 +9,9 @@ from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.examples.policy.episode_env_aware_policy import \
    EpisodeEnvAwarePolicy
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.test_utils import framework_iterator
from ray.rllib.utils.test_utils import framework_iterator, check


class TestTrajectoryViewAPI(unittest.TestCase):

@@ -30,9 +31,9 @@ class TestTrajectoryViewAPI(unittest.TestCase):
        trainer = ppo.PPOTrainer(config, env="CartPole-v0")
        policy = trainer.get_policy()
        view_req_model = policy.model.inference_view_requirements
        view_req_policy = policy.training_view_requirements
        assert len(view_req_model) == 1
        assert len(view_req_policy) == 10
        view_req_policy = policy.view_requirements
        assert len(view_req_model) == 1, view_req_model
        assert len(view_req_policy) == 11, view_req_policy
        for key in [
                SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS,
                SampleBatch.DONES, SampleBatch.NEXT_OBS,
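The counts asserted above can also be inspected directly; a quick, hypothetical snippet for doing so (the trainer/env choice is just a placeholder):

from ray.rllib.agents import ppo

config = dict(ppo.DEFAULT_CONFIG, **{"framework": "torch", "num_workers": 0})
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
policy = trainer.get_policy()
# Model-level requirements (action-computation forward passes) vs. the
# policy-level ones (training/postprocessing).
print(sorted(policy.model.inference_view_requirements.keys()))
print(sorted(policy.view_requirements.keys()))
trainer.stop()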
@@ -62,9 +63,9 @@ class TestTrajectoryViewAPI(unittest.TestCase):
        trainer = ppo.PPOTrainer(config, env="CartPole-v0")
        policy = trainer.get_policy()
        view_req_model = policy.model.inference_view_requirements
        view_req_policy = policy.training_view_requirements
        assert len(view_req_model) == 7  # obs, prev_a, prev_r, 4xstates
        assert len(view_req_policy) == 16
        view_req_policy = policy.view_requirements
        assert len(view_req_model) == 7, view_req_model
        assert len(view_req_policy) == 17, view_req_policy
        for key in [
                SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS,
                SampleBatch.DONES, SampleBatch.NEXT_OBS,

@@ -90,7 +91,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
                assert view_req_policy[key].shift == 1
        trainer.stop()

    def test_traj_view_lstm_performance(self):
    def test_traj_view_simple_performance(self):
        """Test whether PPOTrainer runs faster w/ `_use_trajectory_view_api`.
        """
        config = copy.deepcopy(ppo.DEFAULT_CONFIG)

@@ -102,17 +103,15 @@ class TestTrajectoryViewAPI(unittest.TestCase):
        from ray.tune import register_env
        register_env("ma_env", lambda c: RandomMultiAgentEnv({
            "num_agents": 2,
            "p_done": 0.01,
            "p_done": 0.0,
            "max_episode_len": 104,
            "action_space": action_space,
            "observation_space": obs_space
        }))

        config["num_workers"] = 3
        config["num_envs_per_worker"] = 8
        config["num_sgd_iter"] = 6
        config["model"]["use_lstm"] = True
        config["model"]["lstm_use_prev_action_reward"] = True
        config["model"]["max_seq_len"] = 100
        config["num_sgd_iter"] = 1  # Put less weight on training.

        policies = {
            "pol0": (None, obs_space, action_space, {}),

@@ -125,72 +124,80 @@ class TestTrajectoryViewAPI(unittest.TestCase):
            "policies": policies,
            "policy_mapping_fn": policy_fn,
        }
        num_iterations = 1
        num_iterations = 2
        # Only works in torch so far.
        for _ in framework_iterator(config, frameworks="torch"):
            print("w/ traj. view API (and time-major)")
            print("w/ traj. view API")
            config["_use_trajectory_view_api"] = True
            config["model"]["_time_major"] = True
            trainer = ppo.PPOTrainer(config=config, env="ma_env")
            learn_time_w = 0.0
            sampler_perf = {}
            sampler_perf_w = {}
            start = time.time()
            for i in range(num_iterations):
                out = trainer.train()
                ts = out["timesteps_total"]
                sampler_perf_ = out["sampler_perf"]
                sampler_perf = {
                    k: sampler_perf.get(k, 0.0) + sampler_perf_[k]
                sampler_perf_w = {
                    k:
                    sampler_perf_w.get(k, 0.0) + (sampler_perf_[k] * 1000 / ts)
                    for k, v in sampler_perf_.items()
                }
                delta = out["timers"]["learn_time_ms"] / 1000
                delta = out["timers"]["learn_time_ms"] / ts
                learn_time_w += delta
                print("{}={}s".format(i, delta))
            sampler_perf = {
                k: sampler_perf[k] / (num_iterations if "mean_" in k else 1)
                for k, v in sampler_perf.items()
            sampler_perf_w = {
                k: sampler_perf_w[k] / (num_iterations if "mean_" in k else 1)
                for k, v in sampler_perf_w.items()
            }
            duration_w = time.time() - start
            print("Duration: {}s "
                  "sampler-perf.={} learn-time/iter={}s".format(
                      duration_w, sampler_perf, learn_time_w / num_iterations))
                      duration_w, sampler_perf_w,
                      learn_time_w / num_iterations))
            trainer.stop()

            print("w/o traj. view API (and w/o time-major)")
            print("w/o traj. view API")
            config["_use_trajectory_view_api"] = False
            config["model"]["_time_major"] = False
            trainer = ppo.PPOTrainer(config=config, env="ma_env")
            learn_time_wo = 0.0
            sampler_perf = {}
            sampler_perf_wo = {}
            start = time.time()
            for i in range(num_iterations):
                out = trainer.train()
                ts = out["timesteps_total"]
                sampler_perf_ = out["sampler_perf"]
                sampler_perf = {
                    k: sampler_perf.get(k, 0.0) + sampler_perf_[k]
                sampler_perf_wo = {
                    k: sampler_perf_wo.get(k, 0.0) +
                    (sampler_perf_[k] * 1000 / ts)
                    for k, v in sampler_perf_.items()
                }
                delta = out["timers"]["learn_time_ms"] / 1000
                delta = out["timers"]["learn_time_ms"] / ts
                learn_time_wo += delta
                print("{}={}s".format(i, delta))
            sampler_perf = {
                k: sampler_perf[k] / (num_iterations if "mean_" in k else 1)
                for k, v in sampler_perf.items()
            sampler_perf_wo = {
                k: sampler_perf_wo[k] / (num_iterations if "mean_" in k else 1)
                for k, v in sampler_perf_wo.items()
            }
            duration_wo = time.time() - start
            print("Duration: {}s "
                  "sampler-perf.={} learn-time/iter={}s".format(
                      duration_wo, sampler_perf,
                      duration_wo, sampler_perf_wo,
                      learn_time_wo / num_iterations))
            trainer.stop()

            # Assert `_use_trajectory_view_api` is much faster.
            # Assert `_use_trajectory_view_api` is faster.
            self.assertLess(sampler_perf_w["mean_raw_obs_processing_ms"],
                            sampler_perf_wo["mean_raw_obs_processing_ms"])
            self.assertLess(sampler_perf_w["mean_action_processing_ms"],
                            sampler_perf_wo["mean_action_processing_ms"])
            self.assertLess(duration_w, duration_wo)
            self.assertLess(learn_time_w, learn_time_wo * 0.6)

    def test_traj_view_lstm_functionality(self):
        action_space = Box(-float("inf"), float("inf"), shape=(2, ))
        action_space = Box(-float("inf"), float("inf"), shape=(3, ))
        obs_space = Box(float("-inf"), float("inf"), (4, ))
        max_seq_len = 50
        rollout_fragment_length = 200
        assert rollout_fragment_length % max_seq_len == 0
        policies = {
            "pol0": (EpisodeEnvAwarePolicy, obs_space, action_space, {}),
        }

@@ -198,77 +205,162 @@ class TestTrajectoryViewAPI(unittest.TestCase):
        def policy_fn(agent_id):
            return "pol0"

        rollout_worker = RolloutWorker(
            env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
            policy_config={
        config = {
            "multiagent": {
                "policies": policies,
                "policy_mapping_fn": policy_fn,
            },
            "_use_trajectory_view_api": True,
            "model": {
                "use_lstm": True,
                "_time_major": True,
                "max_seq_len": max_seq_len,
            },
        },

        rollout_worker_w_api = RolloutWorker(
            env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
            policy_config=dict(config, **{"_use_trajectory_view_api": True}),
            rollout_fragment_length=rollout_fragment_length,
            policy=policies,
            policy_mapping_fn=policy_fn,
            num_envs=1,
        )
        for i in range(100):
            pc = rollout_worker.sampler.sample_collector. \
                policy_sample_collectors["pol0"]
            sample_batch_offset_before = pc.sample_batch_offset
            buffers = pc.buffers
            result = rollout_worker.sample()
            pol_batch = result.policy_batches["pol0"]
        rollout_worker_wo_api = RolloutWorker(
            env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
            policy_config=dict(config, **{"_use_trajectory_view_api": False}),
            rollout_fragment_length=rollout_fragment_length,
            policy=policies,
            policy_mapping_fn=policy_fn,
            num_envs=1,
        )
        for iteration in range(20):
            result = rollout_worker_w_api.sample()
            check(result.count, rollout_fragment_length)
            pol_batch_w = result.policy_batches["pol0"]
            assert pol_batch_w.count >= rollout_fragment_length
            analyze_rnn_batch(pol_batch_w, max_seq_len)

            result = rollout_worker_wo_api.sample()
            pol_batch_wo = result.policy_batches["pol0"]
            check(pol_batch_w.data, pol_batch_wo.data)


def analyze_rnn_batch(batch, max_seq_len):
    count = batch.count

            self.assertTrue(result.count == 100)
            self.assertTrue(pol_batch.count >= 100)
            self.assertFalse(0 in pol_batch.seq_lens)
            # Check prev_reward/action, next_obs consistency.
            for t in range(max_seq_len):
                obs_t = pol_batch["obs"][t]
                r_t = pol_batch["rewards"][t]
                if t > 0:
                    next_obs_t_m_1 = pol_batch["new_obs"][t - 1]
                    self.assertTrue((obs_t == next_obs_t_m_1).all())
                if t < max_seq_len - 1:
                    prev_rewards_t_p_1 = pol_batch["prev_rewards"][t + 1]
                    self.assertTrue((r_t == prev_rewards_t_p_1).all())
    for idx in range(count):
        # If timestep tracked by batch, good.
        if "t" in batch:
            ts = batch["t"][idx]
        # Else, infer the timestep from the obs (index 3 holds it).
        else:
            ts = batch["obs"][idx][3]
        obs_t = batch["obs"][idx]
        a_t = batch["actions"][idx]
        r_t = batch["rewards"][idx]
        state_in_0 = batch["state_in_0"][idx]
        state_in_1 = batch["state_in_1"][idx]

            # Check the sanity of all the buffers in the underlying
            # PerPolicy collector.
            for sample_batch_slot, agent_slot in enumerate(
                    range(sample_batch_offset_before, pc.sample_batch_offset)):
                t_buf = buffers["t"][:, agent_slot]
                obs_buf = buffers["obs"][:, agent_slot]
                # Skip empty seqs at end (these won't be part of the batch
                # and have been copied to new agent-slots (even if seq-len=0)).
                if sample_batch_slot < len(pol_batch.seq_lens):
                    seq_len = pol_batch.seq_lens[sample_batch_slot]
                    # Make sure timesteps are always increasing within the seq.
                    assert all(t_buf[1] + j == n + 1
                               for j, n in enumerate(t_buf)
                               if j < seq_len and j != 0)
                    # Make sure all obs within seq are non-0.0.
                    assert all(
                        any(obs_buf[j] != 0.0) for j in range(1, seq_len + 1))
        # Check postprocessing outputs.
        if "postprocessed_column" in batch:
            postprocessed_col_t = batch["postprocessed_column"][idx]
            assert (obs_t == postprocessed_col_t / 2.0).all()

            # Check seq-lens.
            for agent_slot, seq_len in enumerate(pol_batch.seq_lens):
                if seq_len < max_seq_len - 1:
                    # At least in the beginning, the next slots should always
                    # be empty (once all agent slots have been used once, these
                    # may be filled with "old" values (from longer sequences)).
                    if i < 10:
                        self.assertTrue(
                            (pol_batch["obs"][seq_len +
                                              1][agent_slot] == 0.0).all())
                        print(end="")
                    self.assertFalse(
                        (pol_batch["obs"][seq_len][agent_slot] == 0.0).all())
        # Check state-in/out and next-obs values.
        if idx > 0:
            next_obs_t_m_1 = batch["new_obs"][idx - 1]
            state_out_0_t_m_1 = batch["state_out_0"][idx - 1]
            state_out_1_t_m_1 = batch["state_out_1"][idx - 1]
            # Same trajectory as for t-1 -> Should be able to match.
            if (batch[SampleBatch.AGENT_INDEX][idx] ==
                    batch[SampleBatch.AGENT_INDEX][idx - 1]
                    and batch[SampleBatch.EPS_ID][idx] ==
                    batch[SampleBatch.EPS_ID][idx - 1]):
                assert batch["unroll_id"][idx - 1] == batch["unroll_id"][idx]
                assert (obs_t == next_obs_t_m_1).all()
                assert (state_in_0 == state_out_0_t_m_1).all()
                assert (state_in_1 == state_out_1_t_m_1).all()
            # Different trajectory.
            else:
                assert batch["unroll_id"][idx - 1] != batch["unroll_id"][idx]
                assert not (obs_t == next_obs_t_m_1).all()
                assert not (state_in_0 == state_out_0_t_m_1).all()
                assert not (state_in_1 == state_out_1_t_m_1).all()
                # Check initial 0-internal states.
                if ts == 0:
                    assert (state_in_0 == 0.0).all()
                    assert (state_in_1 == 0.0).all()

        # Check initial 0-internal states (at ts=0).
        if ts == 0:
            assert (state_in_0 == 0.0).all()
            assert (state_in_1 == 0.0).all()

        # Check prev. a/r values.
        if idx < count - 1:
            prev_actions_t_p_1 = batch["prev_actions"][idx + 1]
            prev_rewards_t_p_1 = batch["prev_rewards"][idx + 1]
            # Same trajectory as for t+1 -> Should be able to match.
            if batch[SampleBatch.AGENT_INDEX][idx] == \
                    batch[SampleBatch.AGENT_INDEX][idx + 1] and \
                    batch[SampleBatch.EPS_ID][idx] == \
                    batch[SampleBatch.EPS_ID][idx + 1]:
                assert (a_t == prev_actions_t_p_1).all()
                assert r_t == prev_rewards_t_p_1
            # Different (new) trajectory. Assume t-1 (prev-a/r) to be
            # always 0.0s. [3]=ts
            elif ts == 0:
                assert (prev_actions_t_p_1 == 0).all()
                assert prev_rewards_t_p_1 == 0.0

    pad_batch_to_sequences_of_same_size(
        batch,
        max_seq_len=max_seq_len,
        shuffle=False,
        batch_divisibility_req=1)

    # Check after seq-len 0-padding.
    cursor = 0
    for i, seq_len in enumerate(batch["seq_lens"]):
        state_in_0 = batch["state_in_0"][i]
        state_in_1 = batch["state_in_1"][i]
        for j in range(seq_len):
            k = cursor + j
            ts = batch["t"][k]
            obs_t = batch["obs"][k]
            a_t = batch["actions"][k]
            r_t = batch["rewards"][k]

            # Check postprocessing outputs.
            if "postprocessed_column" in batch:
                postprocessed_col_t = batch["postprocessed_column"][k]
                assert (obs_t == postprocessed_col_t / 2.0).all()

            # Check state-in/out and next-obs values.
            if j > 0:
                next_obs_t_m_1 = batch["new_obs"][k - 1]
                # state_out_0_t_m_1 = batch["state_out_0"][k - 1]
                # state_out_1_t_m_1 = batch["state_out_1"][k - 1]
                # Always same trajectory as for t-1.
                assert batch["unroll_id"][k - 1] == batch["unroll_id"][k]
                assert (obs_t == next_obs_t_m_1).all()
                # assert (state_in_0 == state_out_0_t_m_1).all())
                # assert (state_in_1 == state_out_1_t_m_1).all())
            # Check initial 0-internal states.
            elif ts == 0:
                assert (state_in_0 == 0.0).all()
                assert (state_in_1 == 0.0).all()

        for j in range(seq_len, max_seq_len):
            k = cursor + j
            obs_t = batch["obs"][k]
            a_t = batch["actions"][k]
            r_t = batch["rewards"][k]
            assert (obs_t == 0.0).all()
            assert (a_t == 0.0).all()
            assert (r_t == 0.0).all()

        cursor += max_seq_len


if __name__ == "__main__":
@@ -22,7 +22,7 @@ parser.add_argument("--stop-reward", type=float, default=150)

if __name__ == "__main__":
    args = parser.parse_args()
    ray.init(local_mode=True)
    ray.init()

    ModelCatalog.register_custom_model(
        "bn_model", TorchBatchNormModel if args.torch else BatchNormModel)

@@ -50,7 +50,7 @@ if __name__ == "__main__":
        "episode_reward_mean": args.stop_reward,
    }

    results = tune.run(args.run, config=config, stop=stop)
    results = tune.run(args.run, config=config, stop=stop, verbose=1)

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)

@@ -108,7 +108,7 @@ if __name__ == "__main__":
        "episode_reward_mean": args.stop_reward,
    }

    results = tune.run("PPO", config=config, stop=stop)
    results = tune.run("PPO", config=config, stop=stop, verbose=1)

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)

@@ -9,7 +9,6 @@ For PyTorch / TF eager mode, use the --torch and --eager flags.

import argparse

import ray
from ray import tune
from ray.rllib.models import ModelCatalog
from ray.rllib.examples.env.simple_rpg import SimpleRPG

@@ -21,18 +20,13 @@ parser.add_argument(
    "--framework", choices=["tf", "tfe", "torch"], default="tf")

if __name__ == "__main__":
    ray.init(local_mode=True)
    args = parser.parse_args()
    if args.framework == "torch":
        ModelCatalog.register_custom_model("my_model", CustomTorchRPGModel)
    else:
        ModelCatalog.register_custom_model("my_model", CustomTFRPGModel)
    tune.run(
        "PG",
        stop={
            "timesteps_total": 1,
        },
        config={

    config = {
        "framework": args.framework,
        "env": SimpleRPG,
        "rollout_fragment_length": 1,

@@ -41,5 +35,10 @@ if __name__ == "__main__":
        "model": {
            "custom_model": "my_model",
        },
        },
    )
    }

    stop = {
        "timesteps_total": 1,
    }

    tune.run("PG", config=config, stop=stop, verbose=1)

@@ -55,7 +55,7 @@ if __name__ == "__main__":
        "episode_reward_mean": args.stop_reward,
    }

    results = tune.run(args.run, config=config, stop=stop)
    results = tune.run(args.run, config=config, stop=stop, verbose=1)

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)
rllib/examples/env/debug_counter_env.py
@ -29,7 +29,7 @@ class DebugCounterEnv(gym.Env):
|
|||
class MultiAgentDebugCounterEnv(MultiAgentEnv):
|
||||
def __init__(self, config):
|
||||
self.num_agents = config["num_agents"]
|
||||
self.p_done = config.get("p_done", 0.02)
|
||||
self.base_episode_len = config.get("base_episode_len", 103)
|
||||
# Actions are always:
|
||||
# (episodeID, envID) as floats.
|
||||
self.action_space = \
|
||||
|
@ -45,6 +45,7 @@ class MultiAgentDebugCounterEnv(MultiAgentEnv):
|
|||
self.dones = set()
|
||||
|
||||
def reset(self):
|
||||
self.timesteps = [0] * self.num_agents
|
||||
self.dones = set()
|
||||
return {
|
||||
i: np.array([i, 0.0, 0.0, 0.0], dtype=np.float32)
|
||||
|
@ -57,9 +58,8 @@ class MultiAgentDebugCounterEnv(MultiAgentEnv):
|
|||
self.timesteps[i] += 1
|
||||
obs[i] = np.array([i, action[0], action[1], self.timesteps[i]])
|
||||
rew[i] = self.timesteps[i] % 3
|
||||
done[i] = bool(
|
||||
np.random.choice(
|
||||
[True, False], p=[self.p_done, 1.0 - self.p_done]))
|
||||
done[i] = True if self.timesteps[i] > self.base_episode_len + i \
|
||||
else False
|
||||
if done[i]:
|
||||
self.dones.add(i)
|
||||
done["__all__"] = len(self.dones) == self.num_agents
|
||||
|
|
|
@ -96,7 +96,7 @@ if __name__ == "__main__":
|
|||
"framework": "torch" if args.torch else "tf",
|
||||
}
|
||||
|
||||
results = tune.run("PPO", stop=stop, config=config)
|
||||
results = tune.run("PPO", stop=stop, config=config, verbose=1)
|
||||
|
||||
if args.as_test:
|
||||
check_learning_achieved(results, args.stop_reward)
|
||||
|
|
|
@@ -73,7 +73,7 @@ class BinaryAutoregressiveDistribution(ActionDistribution):
        return a2_dist

    @staticmethod
    def required_model_output_shape(self, model_config):
    def required_model_output_shape(action_space, model_config):
        return 16  # controls model output feature vector size

@@ -143,5 +143,5 @@ class TorchBinaryAutoregressiveDistribution(TorchDistributionWrapper):
        return a2_dist

    @staticmethod
    def required_model_output_shape(self, model_config):
    def required_model_output_shape(action_space, model_config):
        return 16  # controls model output feature vector size
@@ -180,7 +180,7 @@ class TorchBatchNormModel(TorchModelV2, nn.Module):
    def forward(self, input_dict, state, seq_lens):
        # Set the correct train-mode for our hidden module (only important
        # b/c we have some batch-norm layers).
        self._hidden_layers.train(mode=input_dict["is_training"])
        self._hidden_layers.train(mode=input_dict.get("is_training", False))
        self._hidden_out = self._hidden_layers(input_dict["obs"])
        logits = self._logits(self._hidden_out)
        return logits, []
@@ -1,9 +1,11 @@
from gym.spaces import Box
import numpy as np

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf, try_import_torch

@@ -101,8 +103,18 @@ class TorchRNNModel(TorchRNN, nn.Module):
        # Holds the current "base" output (before logits layer).
        self._features = None

        # Add state-ins to this model's view.
        for i in range(2):
            self.inference_view_requirements["state_in_{}".format(i)] = \
                ViewRequirement(
                    "state_out_{}".format(i),
                    shift=-1,
                    space=Box(-1.0, 1.0, shape=(self.lstm_state_size,)))

    @override(ModelV2)
    def get_initial_state(self):
        # TODO: (sven): Get rid of `get_initial_state` once Trajectory
        # View API is supported across all of RLlib.
        # Place hidden states on same device as model.
        h = [
            self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
@@ -47,10 +47,7 @@ if __name__ == "__main__":
        "timesteps_total": args.stop_timesteps,
    }

    results = tune.run(
        "PG",
        stop=stop,
        config={
    config = {
        "env": "multi_agent_cartpole",
        "multiagent": {
            "policies": {

@@ -63,8 +60,9 @@ if __name__ == "__main__":
                lambda agent_id: ["pg_policy", "random"][agent_id % 2]),
        },
        "framework": "torch" if args.torch else "tf",
        },
    )
    }

    results = tune.run("PG", config=config, stop=stop, verbose=1)

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)
@@ -52,7 +52,7 @@ if __name__ == "__main__":
        "timesteps_total": args.stop_timesteps,
    }

    results = tune.run(args.run, config=config, stop=stop)
    results = tune.run(args.run, config=config, stop=stop, verbose=1)

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)
@@ -1,3 +1,4 @@
from gym.spaces import Box
import numpy as np

from ray.rllib.examples.policy.random_policy import RandomPolicy

@@ -13,8 +14,7 @@ class EpisodeEnvAwarePolicy(RandomPolicy):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.episode_id = None
        self.env_id = None
        self.state_space = Box(-1.0, 1.0, (1, ))

        class _fake_model:
            pass

@@ -22,15 +22,25 @@ class EpisodeEnvAwarePolicy(RandomPolicy):
        self.model = _fake_model()
        self.model.time_major = True
        self.model.inference_view_requirements = {
            SampleBatch.AGENT_INDEX: ViewRequirement(),
            SampleBatch.EPS_ID: ViewRequirement(),
            "env_id": ViewRequirement(),
            "t": ViewRequirement(),
            SampleBatch.OBS: ViewRequirement(),
            SampleBatch.PREV_ACTIONS: ViewRequirement(
                SampleBatch.ACTIONS, space=self.action_space, shift=-1),
            SampleBatch.PREV_REWARDS: ViewRequirement(
                SampleBatch.REWARDS, shift=-1),
        }
        self.training_view_requirements = dict(
        for i in range(2):
            self.model.inference_view_requirements["state_in_{}".format(i)] = \
                ViewRequirement(
                    "state_out_{}".format(i), shift=-1, space=self.state_space)
            self.model.inference_view_requirements[
                "state_out_{}".format(i)] = \
                ViewRequirement(space=self.state_space)

        self.view_requirements = dict(
            **{
                SampleBatch.NEXT_OBS: ViewRequirement(
                    SampleBatch.OBS, shift=1),

@@ -50,17 +60,23 @@ class EpisodeEnvAwarePolicy(RandomPolicy):
                            explore=None,
                            timestep=None,
                            **kwargs):
        self.episode_id = input_dict[SampleBatch.EPS_ID][0]
        self.env_id = input_dict["env_id"][0]
        # Always return (episodeID, envID)
        return [
            np.array([self.episode_id, self.env_id]) for _ in input_dict["obs"]
        ], [], {}
        ts = input_dict["t"]
        print(ts)
        # Always return [episodeID, envID] as actions.
        actions = np.array([[
            input_dict[SampleBatch.AGENT_INDEX][i],
            input_dict[SampleBatch.EPS_ID][i], input_dict["env_id"][i]
        ] for i, _ in enumerate(input_dict["obs"])])
        states = [
            np.array([[ts[i]] for i in range(len(input_dict["obs"]))])
            for _ in range(2)
        ]
        return actions, states, {}

    @override(Policy)
    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        sample_batch["postprocessed_column"] = sample_batch["obs"] + 1.0
        sample_batch["postprocessed_column"] = sample_batch["obs"] * 2.0
        return sample_batch
@@ -1,7 +1,10 @@
import gym
import numpy as np
import random

from ray.rllib.examples.env.rock_paper_scissors import RockPaperScissors
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.view_requirement import ViewRequirement


class AlwaysSameHeuristic(Policy):

@@ -10,6 +13,12 @@ class AlwaysSameHeuristic(Policy):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.exploration = self._create_exploration()
        self.view_requirements.update({
            "state_in_0": ViewRequirement(
                "state_out_0",
                shift=-1,
                space=gym.spaces.Box(0, 100, shape=(), dtype=np.int32))
        })

    def get_initial_state(self):
        return [

@@ -27,6 +36,9 @@ class AlwaysSameHeuristic(Policy):
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        if self.config["_use_trajectory_view_api"]:
            return state_batches[0][0], [s[0] for s in state_batches], {}
        else:
            return state_batches[0], state_batches, {}
@@ -38,7 +38,7 @@ def run_same_policy(args, stop):
        "framework": "torch" if args.torch else "tf",
    }

    results = tune.run("PG", config=config, stop=stop)
    results = tune.run("PG", config=config, stop=stop, verbose=1)

    if args.as_test:
        # Check vs 0.0 as we are playing a zero-sum game.
@@ -63,6 +63,8 @@ class ModelV2:
            SampleBatch.OBS: ViewRequirement(shift=0),
        }

    # TODO: (sven): Get rid of `get_initial_state` once Trajectory
    # View API is supported across all of RLlib.
    @PublicAPI
    def get_initial_state(self) -> List[np.ndarray]:
        """Get the initial recurrent state values for the model.
@@ -135,25 +135,20 @@ class LSTMWrapper(RecurrentNetwork, nn.Module):
            activation_fn=None,
            initializer=torch.nn.init.xavier_uniform_)

        self.inference_view_requirements.update(
            dict(
                **{
                    SampleBatch.OBS: ViewRequirement(shift=0),
                    SampleBatch.PREV_REWARDS: ViewRequirement(
                        SampleBatch.REWARDS, shift=-1),
                    SampleBatch.PREV_ACTIONS: ViewRequirement(
                        SampleBatch.ACTIONS, space=self.action_space,
                        shift=-1),
                }))
        # Add prev-a/r to this model's view, if required.
        if model_config["lstm_use_prev_action_reward"]:
            self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
                ViewRequirement(SampleBatch.REWARDS, shift=-1)
            self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
                ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
                                shift=-1)
        # Add state-ins to this model's view.
        for i in range(2):
            self.inference_view_requirements["state_in_{}".format(i)] = \
                ViewRequirement(
                    "state_out_{}".format(i),
                    shift=-1,
                    space=Box(-1.0, 1.0, shape=(self.cell_size,)))
            self.inference_view_requirements["state_out_{}".format(i)] = \
                ViewRequirement(
                    space=Box(-1.0, 1.0, shape=(self.cell_size,)))

    @override(RecurrentNetwork)
    def forward(self, input_dict, state, seq_lens):
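Not part of the diff: a minimal sketch of how a custom recurrent model could declare the same kind of state-in/state-out view requirements that the LSTM wrapper above sets up. The helper name, `cell_size` argument, and `num_states` count are illustrative assumptions; only `inference_view_requirements` and `ViewRequirement` come from the diff.

from gym.spaces import Box
from ray.rllib.policy.view_requirement import ViewRequirement

def add_recurrent_view_requirements(model, cell_size, num_states=2):
    """Declare that state_in_i at time t is state_out_i from time t-1."""
    for i in range(num_states):
        # Each state input is read from the previous timestep's state output.
        model.inference_view_requirements["state_in_{}".format(i)] = \
            ViewRequirement(
                "state_out_{}".format(i),
                shift=-1,
                space=Box(-1.0, 1.0, shape=(cell_size, )))
        # The state outputs themselves only need a space definition.
        model.inference_view_requirements["state_out_{}".format(i)] = \
            ViewRequirement(space=Box(-1.0, 1.0, shape=(cell_size, )))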
@@ -5,6 +5,7 @@ import tree
from typing import Dict, List, Optional

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.exploration.exploration import Exploration
from ray.rllib.utils.framework import try_import_torch

@@ -74,7 +75,14 @@ class Policy(metaclass=ABCMeta):
        # Child classes need to add their specific requirements here (usually
        # a combination of a Model's inference_view_- and the
        # Policy's loss function-requirements.
        self.training_view_requirements = {}
        self.view_requirements = {
            SampleBatch.OBS: ViewRequirement(),
            SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
            SampleBatch.REWARDS: ViewRequirement(),
            SampleBatch.DONES: ViewRequirement(),
            SampleBatch.EPS_ID: ViewRequirement(),
            SampleBatch.AGENT_INDEX: ViewRequirement(),
        }

    @abstractmethod
    @DeveloperAPI

@@ -255,7 +263,23 @@ class Policy(metaclass=ABCMeta):
                shape like
                {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
        """
        raise NotImplementedError
        # Default implementation just passes obs, prev-a/r, and states on to
        # `self.compute_actions()`.
        state_batches = [
            s.unsqueeze(0)
            if torch and isinstance(s, torch.Tensor) else np.expand_dims(s, 0)
            for k, s in input_dict.items() if k[:9] == "state_in_"
        ]
        return self.compute_actions(
            input_dict[SampleBatch.OBS],
            state_batches,
            prev_action_batch=input_dict.get(SampleBatch.PREV_ACTIONS),
            prev_reward_batch=input_dict.get(SampleBatch.PREV_REWARDS),
            info_batch=None,
            explore=explore,
            timestep=timestep,
            **kwargs,
        )

    @DeveloperAPI
    def compute_log_likelihoods(
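Not part of the diff: a small standalone sketch of the state-batch extraction pattern used in the new default implementation above, i.e. pulling all "state_in_*" entries out of an input dict and adding a leading batch dimension. The dict contents and shapes are made up for illustration.

import numpy as np

# An input dict as the sampler might hand it to the policy (no batch dim on
# the recurrent state entries yet).
example_input_dict = {
    "obs": np.array([[0.1, 0.2, 0.3]], dtype=np.float32),
    "state_in_0": np.zeros(4, dtype=np.float32),
    "state_in_1": np.zeros(4, dtype=np.float32),
}

# Same key-prefix check as in the default implementation: k[:9] == "state_in_".
state_batches = [
    np.expand_dims(s, 0) for k, s in example_input_dict.items()
    if k[:9] == "state_in_"
]
assert all(s.shape == (1, 4) for s in state_batches)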
@@ -35,7 +35,6 @@ def pad_batch_to_sequences_of_same_size(
        shuffle: bool = False,
        batch_divisibility_req: int = 1,
        feature_keys: Optional[List[str]] = None,
        _use_trajectory_view_api: bool = False,
):
    """Applies padding to `batch` so it's choppable into same-size sequences.

@@ -56,26 +55,7 @@ def pad_batch_to_sequences_of_same_size(
        feature_keys (Optional[List[str]]): An optional list of keys to apply
            sequence-chopping to. If None, use all keys in batch that are not
            "state_in/out_"-type keys.
        _use_trajectory_view_api (bool): Whether we are using the Trajectory
            View API to collect and process samples.
    """
    if _use_trajectory_view_api:
        if batch.time_major is not None:
            batch["seq_lens"] = torch.tensor(batch.seq_lens)
            t = 0 if batch.time_major else 1
            for col in batch.data.keys():
                # Cut time-dim from states.
                if "state_" in col[:6]:
                    batch[col] = batch[col][t]
                # Flatten all other data.
                else:
                    # Cut time-dim at `max_seq_len`.
                    if batch.time_major:
                        batch[col] = batch[col][:batch.max_seq_len]
                    batch[col] = batch[col].reshape((-1, ) +
                                                    batch[col].shape[2:])
        return

    if batch_divisibility_req > 1:
        meets_divisibility_reqs = (
            len(batch[SampleBatch.CUR_OBS]) % batch_divisibility_req == 0
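Not part of the diff: a tiny numpy illustration of the reshape used in the removed trajectory-view branch above, which flattens time-major [T, B, ...] data into [T*B, ...]. Array names and sizes are made up.

import numpy as np

# Time-major data: T=4 timesteps, B=2 sequences, 3 features per step.
col = np.arange(4 * 2 * 3, dtype=np.float32).reshape(4, 2, 3)

# Same operation as `batch[col].reshape((-1, ) + batch[col].shape[2:])`.
flat = col.reshape((-1, ) + col.shape[2:])
assert flat.shape == (8, 3)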
@@ -11,7 +11,6 @@ from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_torch

@@ -31,8 +30,6 @@ logger = logging.getLogger(__name__)
class TorchPolicy(Policy):
    """Template for a PyTorch policy and loss to use with RLlib.

    This is similar to TFPolicy, but for PyTorch.

    Attributes:
        observation_space (gym.Space): observation space of the policy.
        action_space (gym.Space): action space of the policy.

@@ -114,14 +111,7 @@ class TorchPolicy(Policy):
            self.device = torch.device("cpu")
        self.model = model.to(self.device)
        # Combine view_requirements for Model and Policy.
        self.training_view_requirements = dict(
            **{
                SampleBatch.ACTIONS: ViewRequirement(
                    space=self.action_space, shift=0),
                SampleBatch.REWARDS: ViewRequirement(shift=0),
                SampleBatch.DONES: ViewRequirement(shift=0),
            },
            **self.model.inference_view_requirements)
        self.view_requirements.update(self.model.inference_view_requirements)

        self.exploration = self._create_exploration()
        self.unwrapped_model = model  # used to support DistributedDataParallel

@@ -202,9 +192,11 @@ class TorchPolicy(Policy):
        with torch.no_grad():
            # Pass lazy (torch) tensor dict to Model as `input_dict`.
            input_dict = self._lazy_tensor_dict(input_dict)
            # Pack internal state inputs into (separate) list.
            state_batches = [
                input_dict[k] for k in input_dict.keys() if "state_" in k[:6]
                input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
            ]
            # Calculate RNN sequence lengths.
            seq_lens = np.array([1] * len(input_dict["obs"])) \
                if state_batches else None

@@ -217,7 +209,8 @@ class TorchPolicy(Policy):
            extra_fetches[SampleBatch.ACTION_PROB] = torch.exp(logp)
            extra_fetches[SampleBatch.ACTION_LOGP] = logp

        return actions, state_out, extra_fetches
        return convert_to_non_torch_type((actions, state_out,
                                          extra_fetches))

    def _compute_action_helper(self, input_dict, state_batches, seq_lens,
                               explore, timestep):

@@ -342,7 +335,6 @@ class TorchPolicy(Policy):
            max_seq_len=self.max_seq_len,
            shuffle=False,
            batch_divisibility_req=self.batch_divisibility_req,
            _use_trajectory_view_api=self.config["_use_trajectory_view_api"],
        )

        train_batch = self._lazy_tensor_dict(postprocessed_batch)

@@ -479,7 +471,7 @@ class TorchPolicy(Policy):
    @DeveloperAPI
    def get_initial_state(self) -> List[TensorType]:
        return [
            s.cpu().detach().numpy() for s in self.model.get_initial_state()
            s.detach().cpu().numpy() for s in self.model.get_initial_state()
        ]

    @override(Policy)
@@ -64,7 +64,7 @@ def build_torch_policy(
        apply_gradients_fn: Optional[Callable[
            [Policy, "torch.optim.Optimizer"], None]] = None,
        mixins: Optional[List[type]] = None,
        training_view_requirements_fn: Optional[Callable[[], Dict[
        view_requirements_fn: Optional[Callable[[], Dict[
            str, ViewRequirement]]] = None,
        get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None):
    """Helper function for creating a torch policy class at runtime.

@@ -159,7 +159,7 @@ def build_torch_policy(
        mixins (Optional[List[type]]): Optional list of any class mixins for
            the returned policy class. These mixins will be applied in order
            and will have higher precedence than the TorchPolicy class.
        training_view_requirements_fn (Callable[[],
        view_requirements_fn (Callable[[],
            Dict[str, ViewRequirement]]): An optional callable to retrieve
            additional train view requirements for this policy.
        get_batch_divisibility_req (Optional[Callable[[Policy], int]]):

@@ -226,9 +226,8 @@ def build_torch_policy(
            get_batch_divisibility_req=get_batch_divisibility_req,
        )

        if callable(training_view_requirements_fn):
            self.training_view_requirements.update(
                training_view_requirements_fn(self))
        if callable(view_requirements_fn):
            self.view_requirements.update(view_requirements_fn(self))

        if after_init:
            after_init(self, obs_space, action_space, config)
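Not part of the diff: a minimal sketch of what a `view_requirements_fn` passed to `build_torch_policy` could look like after the rename above. The policy name, the loss function, and the chosen view columns are placeholder assumptions; only the `view_requirements_fn=` keyword, the callable-taking-the-policy contract, and the ViewRequirement constructor come from the diff.

from typing import Dict

from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.policy.view_requirement import ViewRequirement


def my_view_requirements_fn(policy: Policy) -> Dict[str, ViewRequirement]:
    # Ask the collector for next-obs (shift=+1), but keep it out of the
    # final train batch since the placeholder loss below does not need it.
    return {
        SampleBatch.NEXT_OBS: ViewRequirement(
            SampleBatch.OBS, shift=1, used_for_training=False),
    }


def my_loss_fn(policy, model, dist_class, train_batch):
    # Placeholder policy-gradient-style loss for the sketch; a real policy
    # would compute its own objective from `train_batch`.
    logits, _ = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)
    return -action_dist.logp(train_batch[SampleBatch.ACTIONS]).mean()


MyTorchPolicy = build_torch_policy(
    name="MyTorchPolicy",
    loss_fn=my_loss_fn,
    view_requirements_fn=my_view_requirements_fn,
)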
@@ -29,7 +29,8 @@ class ViewRequirement:
    def __init__(self,
                 data_col: Optional[str] = None,
                 space: gym.Space = None,
                 shift: Union[int, List[int]] = 0):
                 shift: Union[int, List[int]] = 0,
                 used_for_training: bool = True):
        """Initializes a ViewRequirement object.

        Args:

@@ -46,8 +47,12 @@ class ViewRequirement:
                Example: For a view column "obs" in an Atari framestacking
                fashion, you can set `data_col="obs"` and
                `shift=[-3, -2, -1, 0]`.
            used_for_training (bool): Whether the data will be used for
                training. If False, the column will not be copied into the
                final train batch.
        """
        self.data_col = data_col
        self.space = space or gym.spaces.Box(
            float("-inf"), float("inf"), shape=())
        self.shift = shift
        self.used_for_training = used_for_training
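Not part of the diff: a small sketch of the constructor arguments documented above, including the framestacking-style shift list from the docstring and the new used_for_training flag. The dict keys and the observation space shape are illustrative only.

from gym.spaces import Box
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement

example_requirements = {
    # Plain current observation.
    SampleBatch.OBS: ViewRequirement(
        space=Box(-1.0, 1.0, shape=(4, ))),
    # Atari-style framestack: the last four observations, as per the
    # docstring's `shift=[-3, -2, -1, 0]` example.
    "framestacked_obs": ViewRequirement(
        data_col=SampleBatch.OBS, shift=[-3, -2, -1, 0]),
    # Next observation, needed only for postprocessing, so keep it out of
    # the final train batch.
    SampleBatch.NEXT_OBS: ViewRequirement(
        data_col=SampleBatch.OBS, shift=1, used_for_training=False),
}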
@@ -1,4 +1,3 @@
import numpy as np
import gym

from ray.rllib.utils.annotations import PublicAPI

@@ -17,14 +16,9 @@ class Repeated(gym.Space):
    """

    def __init__(self, child_space: gym.Space, max_len: int):
        self.np_random = np.random.RandomState()
        super().__init__()
        self.child_space = child_space
        self.max_len = max_len
        super().__init__()

    def seed(self, seed=None):
        self.np_random = np.random.RandomState()
        self.np_random.seed(seed)

    def sample(self):
        return [
@@ -34,7 +34,7 @@ def convert_to_non_torch_type(stats):
    def mapping(item):
        if isinstance(item, torch.Tensor):
            return item.cpu().item() if len(item.size()) == 0 else \
                item.cpu().detach().numpy()
                item.detach().cpu().numpy()
        else:
            return item