diff --git a/rllib/agents/a3c/a3c.py b/rllib/agents/a3c/a3c.py index 88e91bf82..52423f0ce 100644 --- a/rllib/agents/a3c/a3c.py +++ b/rllib/agents/a3c/a3c.py @@ -37,6 +37,9 @@ DEFAULT_CONFIG = with_common_config({ # Workers sample async. Note that this increases the effective # rollout_fragment_length by up to 5x due to async buffering of batches. "sample_async": True, + # Switch on Trajectory View API for A2/3C by default. + # NOTE: Only supported for PyTorch so far. + "_use_trajectory_view_api": True, }) # __sphinx_doc_end__ # yapf: enable diff --git a/rllib/agents/a3c/a3c_torch_policy.py b/rllib/agents/a3c/a3c_torch_policy.py index 8e29149ea..c91874b03 100644 --- a/rllib/agents/a3c/a3c_torch_policy.py +++ b/rllib/agents/a3c/a3c_torch_policy.py @@ -1,8 +1,13 @@ +import gym +from typing import Dict + import ray from ray.rllib.evaluation.postprocessing import compute_advantages, \ Postprocessing +from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy_template import build_torch_policy +from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils.framework import try_import_torch torch, nn = try_import_torch() @@ -82,6 +87,42 @@ class ValueNetworkMixin: return self.model.value_function()[0] +def view_requirements_fn(policy: Policy) -> Dict[str, ViewRequirement]: + """Function defining the view requirements for training/postprocessing. + + These go on top of the Policy's Model's own view requirements used for + the action computing forward passes. + + Args: + policy (Policy): The Policy that requires the returned + ViewRequirements. + + Returns: + Dict[str, ViewRequirement]: The Policy's view requirements. + """ + ret = { + # Next obs are needed for PPO postprocessing, but not in loss. + SampleBatch.NEXT_OBS: ViewRequirement( + SampleBatch.OBS, shift=1, used_for_training=False), + # Created during postprocessing. + Postprocessing.ADVANTAGES: ViewRequirement(shift=0), + Postprocessing.VALUE_TARGETS: ViewRequirement(shift=0), + # Needed for PPO's loss function. + SampleBatch.ACTION_DIST_INPUTS: ViewRequirement(shift=0), + SampleBatch.ACTION_LOGP: ViewRequirement(shift=0), + SampleBatch.VF_PREDS: ViewRequirement(shift=0), + } + # If policy is recurrent, have to add state_out for PPO postprocessing + # (calculating GAE from next-obs and last state-out). + if policy.is_recurrent(): + init_state = policy.get_initial_state() + for i, s in enumerate(init_state): + ret["state_out_{}".format(i)] = ViewRequirement( + space=gym.spaces.Box(-1.0, 1.0, shape=(s.shape[0], )), + used_for_training=False) + return ret + + A3CTorchPolicy = build_torch_policy( name="A3CTorchPolicy", get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, @@ -91,4 +132,6 @@ A3CTorchPolicy = build_torch_policy( extra_action_out_fn=model_value_predictions, extra_grad_process_fn=apply_grad_clipping, optimizer_fn=torch_optimizer, - mixins=[ValueNetworkMixin]) + mixins=[ValueNetworkMixin], + view_requirements_fn=view_requirements_fn, +) diff --git a/rllib/agents/pg/pg.py b/rllib/agents/pg/pg.py index d4e774875..4d7950459 100644 --- a/rllib/agents/pg/pg.py +++ b/rllib/agents/pg/pg.py @@ -8,6 +8,7 @@ See `pg_[tf|torch]_policy.py` for the definition of the policy loss. 
Detailed documentation: https://docs.ray.io/en/master/rllib-algorithms.html#pg """ +import logging from typing import Optional, Type from ray.rllib.agents.trainer import with_common_config @@ -17,6 +18,8 @@ from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy from ray.rllib.policy.policy import Policy from ray.rllib.utils.typing import TrainerConfigDict +logger = logging.getLogger(__name__) + # yapf: disable # __sphinx_doc_begin__ @@ -27,6 +30,9 @@ DEFAULT_CONFIG = with_common_config({ "num_workers": 0, # Learning rate. "lr": 0.0004, + # Switch on Trajectory View API for PG by default. + # NOTE: Only supported for PyTorch so far. + "_use_trajectory_view_api": True, }) # __sphinx_doc_end__ diff --git a/rllib/agents/pg/pg_torch_policy.py b/rllib/agents/pg/pg_torch_policy.py index 93cb2f4ac..a83be9bc1 100644 --- a/rllib/agents/pg/pg_torch_policy.py +++ b/rllib/agents/pg/pg_torch_policy.py @@ -5,6 +5,7 @@ PyTorch policy class used for PG. from typing import Dict, List, Type, Union import ray +from ray.rllib.agents.a3c.a3c_torch_policy import view_requirements_fn from ray.rllib.agents.pg.utils import post_process_advantages from ray.rllib.evaluation.postprocessing import Postprocessing from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper @@ -77,4 +78,6 @@ PGTorchPolicy = build_torch_policy( get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG, loss_fn=pg_torch_loss, stats_fn=pg_loss_stats, - postprocess_fn=post_process_advantages) + postprocess_fn=post_process_advantages, + view_requirements_fn=view_requirements_fn, +) diff --git a/rllib/agents/ppo/ddppo.py b/rllib/agents/ppo/ddppo.py index 77dd1768b..a1f16b8d9 100644 --- a/rllib/agents/ppo/ddppo.py +++ b/rllib/agents/ppo/ddppo.py @@ -72,6 +72,8 @@ DEFAULT_CONFIG = ppo.PPOTrainer.merge_trainer_configs( "truncate_episodes": True, # This is auto set based on sample batch size. "train_batch_size": -1, + # Trajectory View API not supported yet for DD-PPO. + "_use_trajectory_view_api": False, }, _allow_unknown_configs=True, ) diff --git a/rllib/agents/ppo/ppo.py b/rllib/agents/ppo/ppo.py index d7eee985d..949c7be63 100644 --- a/rllib/agents/ppo/ppo.py +++ b/rllib/agents/ppo/ppo.py @@ -89,6 +89,9 @@ DEFAULT_CONFIG = with_common_config({ # Whether to fake GPUs (using CPUs). # Set this to True for debugging on non-GPU machines (set `num_gpus` > 0). "_fake_gpus": False, + # Switch on Trajectory View API for PPO by default. + # NOTE: Only supported for PyTorch so far. + "_use_trajectory_view_api": True, }) # __sphinx_doc_end__ @@ -126,7 +129,8 @@ def validate_config(config: TrainerConfigDict) -> None: if config["batch_mode"] == "truncate_episodes" and not config["use_gae"]: raise ValueError( "Episode truncation is not supported without a value " - "function. Consider setting batch_mode=complete_episodes.") + "function (to estimate the return at the end of the truncated " + "trajectory). Consider setting batch_mode=complete_episodes.") # Multi-gpu not supported for PyTorch and tf-eager. 
if config["framework"] in ["tf2", "tfe", "torch"]: diff --git a/rllib/agents/ppo/ppo_torch_policy.py b/rllib/agents/ppo/ppo_torch_policy.py index 957878ed4..b556c36d7 100644 --- a/rllib/agents/ppo/ppo_torch_policy.py +++ b/rllib/agents/ppo/ppo_torch_policy.py @@ -7,7 +7,8 @@ import numpy as np from typing import Dict, List, Type, Union import ray -from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping +from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping, \ + view_requirements_fn from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \ setup_config from ray.rllib.evaluation.postprocessing import Postprocessing @@ -18,7 +19,6 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import EntropyCoeffSchedule, \ LearningRateSchedule from ray.rllib.policy.torch_policy_template import build_torch_policy -from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.torch_ops import convert_to_torch_tensor, \ explained_variance, sequence_mask @@ -255,34 +255,6 @@ def setup_mixins(policy: Policy, obs_space: gym.spaces.Space, LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) -def training_view_requirements_fn( - policy: Policy) -> Dict[str, ViewRequirement]: - """Function defining the view requirements for training the policy. - - These go on top of the Policy's Model's own view requirements used for - action computing forward passes. - - Args: - policy (Policy): The Policy that requires the returned - ViewRequirements. - - Returns: - Dict[str, ViewRequirement]: The Policy's view requirements. - """ - return { - # Next obs are needed for PPO postprocessing. - SampleBatch.NEXT_OBS: ViewRequirement(SampleBatch.OBS, shift=1), - # VF preds are needed for the loss. - SampleBatch.VF_PREDS: ViewRequirement(shift=0), - # Needed for postprocessing. - SampleBatch.ACTION_DIST_INPUTS: ViewRequirement(shift=0), - SampleBatch.ACTION_LOGP: ViewRequirement(shift=0), - # Created during postprocessing. - Postprocessing.ADVANTAGES: ViewRequirement(shift=0), - Postprocessing.VALUE_TARGETS: ViewRequirement(shift=0), - } - - # Build a child class of `TorchPolicy`, given the custom functions defined # above. PPOTorchPolicy = build_torch_policy( @@ -299,5 +271,5 @@ PPOTorchPolicy = build_torch_policy( LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin, ValueNetworkMixin ], - training_view_requirements_fn=training_view_requirements_fn, + view_requirements_fn=view_requirements_fn, ) diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 1c48e2dfb..e03c9514d 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -1047,9 +1047,10 @@ class Trainer(Trainable): def _validate_config(config: PartialTrainerConfigDict): if config.get("_use_trajectory_view_api") and \ config.get("framework") != "torch": - raise ValueError( + logger.info( "`_use_trajectory_view_api` only supported for PyTorch so " - "far!") + "far! Will run w/o.") + config["_use_trajectory_view_api"] = False elif not config.get("_use_trajectory_view_api") and \ config.get("model", {}).get("_time_major"): raise ValueError("`model._time_major` only supported " diff --git a/rllib/env/base_env.py b/rllib/env/base_env.py index 6c9544833..bc949ccfa 100644 --- a/rllib/env/base_env.py +++ b/rllib/env/base_env.py @@ -183,7 +183,8 @@ class BaseEnv: reset the entire Env (i.e. all sub-envs). Returns: - obs (dict|None): Resetted observation or None if not supported. 
+ Optional[MultiAgentDict]: Resetted (multi-agent) observation dict + or None if reset is not supported. """ return None diff --git a/rllib/evaluation/collectors/__init__.py b/rllib/evaluation/collectors/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/rllib/evaluation/sample_collector.py b/rllib/evaluation/collectors/sample_collector.py similarity index 67% rename from rllib/evaluation/sample_collector.py rename to rllib/evaluation/collectors/sample_collector.py index babf65e9d..7d6f00b52 100644 --- a/rllib/evaluation/sample_collector.py +++ b/rllib/evaluation/collectors/sample_collector.py @@ -1,9 +1,10 @@ from abc import abstractmethod, ABCMeta import logging -from typing import Dict, Optional +from typing import Dict, Union from ray.rllib.evaluation.episode import MultiAgentEpisode -from ray.rllib.utils.typing import AgentID, EpisodeID, PolicyID, \ +from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch +from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, \ TensorType logger = logging.getLogger(__name__) @@ -29,7 +30,7 @@ class _SampleCollector(metaclass=ABCMeta): """ @abstractmethod - def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID, + def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID, policy_id: PolicyID, init_obs: TensorType) -> None: """Adds an initial obs (after reset) to this collector. @@ -41,10 +42,11 @@ class _SampleCollector(metaclass=ABCMeta): called for that same agent/episode-pair. Args: - episode_id (EpisodeID): Unique id for the episode we are adding - values for. + episode (MultiAgentEpisode): The MultiAgentEpisode, for which we + are adding an Agent's initial observation. agent_id (AgentID): Unique id for the agent we are adding values for. + env_id (EnvID): The environment index (in a vectorized setup). policy_id (PolicyID): Unique id for policy controlling the agent. init_obs (TensorType): Initial observation (after env.reset()). @@ -52,7 +54,7 @@ class _SampleCollector(metaclass=ABCMeta): >>> obs = env.reset() >>> collector.add_init_obs(12345, 0, "pol0", obs) >>> obs, r, done, info = env.step(action) - >>> collector.add_action_reward_next_obs(12345, 0, "pol0", { + >>> collector.add_action_reward_next_obs(12345, 0, "pol0", False, { ... "action": action, "obs": obs, "reward": r, "done": done ... }) """ @@ -60,7 +62,8 @@ class _SampleCollector(metaclass=ABCMeta): @abstractmethod def add_action_reward_next_obs(self, episode_id: EpisodeID, - agent_id: AgentID, policy_id: PolicyID, + agent_id: AgentID, env_id: EnvID, + policy_id: PolicyID, agent_done: bool, values: Dict[str, TensorType]) -> None: """Add the given dictionary (row) of values to this collector. @@ -74,7 +77,10 @@ class _SampleCollector(metaclass=ABCMeta): values for. agent_id (AgentID): Unique id for the agent we are adding values for. + env_id (EnvID): The environment index (in a vectorized setup). policy_id (PolicyID): Unique id for policy controlling the agent. + agent_done (bool): Whether the given agent is done with its + trajectory (the multi-agent episode may still be ongoing). values (Dict[str, TensorType]): Row of values to add for this agent. This row must contain the keys SampleBatch.ACTION, REWARD, NEW_OBS, and DONE. 
@@ -83,12 +89,22 @@ class _SampleCollector(metaclass=ABCMeta): >>> obs = env.reset() >>> collector.add_init_obs(12345, 0, "pol0", obs) >>> obs, r, done, info = env.step(action) - >>> collector.add_action_reward_next_obs(12345, 0, "pol0", { + >>> collector.add_action_reward_next_obs(12345, 0, "pol0", False, { ... "action": action, "obs": obs, "reward": r, "done": done ... }) """ raise NotImplementedError + @abstractmethod + def episode_step(self, episode_id: EpisodeID) -> None: + """Increases the episode step counter (across all agents) by one. + + Args: + episode_id (EpisodeID): Unique id for the episode we are stepping + through (across all agents in that episode). + """ + raise NotImplementedError + @abstractmethod def total_env_steps(self) -> int: """Returns total number of steps taken in the env (sum of all agents). @@ -126,19 +142,11 @@ class _SampleCollector(metaclass=ABCMeta): raise NotImplementedError @abstractmethod - def has_non_postprocessed_data(self) -> bool: - """Returns whether there is pending, unprocessed data. - - Returns: - bool: True if there is at least some data that has not been - postprocessed yet. - """ - raise NotImplementedError - - @abstractmethod - def postprocess_trajectories_so_far( - self, episode: Optional[MultiAgentEpisode] = None) -> None: - """Apply postprocessing to unprocessed data (in one or all episodes). + def postprocess_episode(self, + episode: MultiAgentEpisode, + is_done: bool = False, + check_dones: bool = False) -> None: + """Postprocesses all agents' trajectories in a given episode. Generates (single-trajectory) SampleBatches for all Policies/Agents and calls Policy.postprocess_trajectory on each of these. Postprocessing @@ -148,38 +156,46 @@ class _SampleCollector(metaclass=ABCMeta): correctly added to the buffers. Args: - episode (Optional[MultiAgentEpisode]): The Episode object for which - to post-process data. If not provided, postprocess data for all - episodes. + episode (MultiAgentEpisode): The Episode object for which + to post-process data. + is_done (bool): Whether the given episode is actually terminated + (all agents are done). + check_dones (bool): Whether we need to check that all agents' + trajectories have dones=True at the end. """ raise NotImplementedError @abstractmethod - def check_missing_dones(self, episode_id: EpisodeID) -> None: - """Checks whether given episode is properly terminated with done=True. - - This applies to all agents in the episode. + def build_multi_agent_batch(self, env_steps: int) -> \ + Union[MultiAgentBatch, SampleBatch]: + """Builds a MultiAgentBatch of size=env_steps from the collected data. Args: - episode_id (EpisodeID): The episode ID to check for proper - termination. + env_steps (int): The sum of all env-steps (across all agents) taken + so far. - Raises: - ValueError: If `episode` has no done=True at the end. + Returns: + Union[MultiAgentBatch, SampleBatch]: Returns the accumulated + sample batches for each policy inside one MultiAgentBatch + object (or a simple SampleBatch if only one policy). """ raise NotImplementedError @abstractmethod - def get_multi_agent_batch_and_reset(self): - """Returns the accumulated sample batches for each policy. + def try_build_truncated_episode_multi_agent_batch(self) -> \ + Union[MultiAgentBatch, SampleBatch, None]: + """Tries to build an MA-batch, if `rollout_fragment_length` is reached. - Any unprocessed rows will be first postprocessed with a policy - postprocessor. The internal state of this builder will be reset to - start the next batch. 
+ Any unprocessed data will be first postprocessed with a policy + postprocessor. This is usually called to collect samples for policy training. + If not enough data has been collected yet (`rollout_fragment_length`), + returns None. Returns: - MultiAgentBatch: Returns the accumulated sample batches for each - policy inside one MultiAgentBatch object. + Union[MultiAgentBatch, SampleBatch, None]: Returns the accumulated + sample batches for each policy inside one MultiAgentBatch + object (or a simple SampleBatch if only one policy) or None + if `self.rollout_fragment_length` has not been reached yet. """ raise NotImplementedError diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py new file mode 100644 index 000000000..15a5b351f --- /dev/null +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -0,0 +1,589 @@ +import collections +import logging +import numpy as np +from typing import List, Any, Dict, Tuple, TYPE_CHECKING, Union + +from ray.rllib.env.base_env import _DUMMY_AGENT_ID +from ray.rllib.evaluation.collectors.sample_collector import _SampleCollector +from ray.rllib.evaluation.episode import MultiAgentEpisode +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch +from ray.rllib.policy.view_requirement import ViewRequirement +from ray.rllib.utils.annotations import override +from ray.rllib.utils.debug import summarize +from ray.rllib.utils.typing import AgentID, EpisodeID, EnvID, PolicyID, \ + TensorType +from ray.rllib.utils.framework import try_import_tf, try_import_torch +from ray.rllib.utils.torch_ops import convert_to_non_torch_type +from ray.util.debug import log_once + +_, tf, _ = try_import_tf() +torch, _ = try_import_torch() + +if TYPE_CHECKING: + from ray.rllib.agents.callbacks import DefaultCallbacks + +logger = logging.getLogger(__name__) + + +def to_float_np_array(v: List[Any]) -> np.ndarray: + if torch.is_tensor(v[0]): + raise ValueError + v = convert_to_non_torch_type(v) + arr = np.array(v) + if arr.dtype == np.float64: + return arr.astype(np.float32) # save some memory + return arr + + +class _AgentCollector: + """Collects samples for one agent in one trajectory (episode). + + The agent may be part of a multi-agent environment. Samples are stored in + lists including some possible automatic "shift" buffer at the beginning to + be able to save memory when storing things like NEXT_OBS, PREV_REWARDS, + etc.., which are specified using the trajectory view API. + """ + + _next_unroll_id = 0 # disambiguates unrolls within a single episode + + def __init__(self, shift_before: int = 0): + self.shift_before = max(shift_before, 1) + self.buffers: Dict[str, List] = {} + # The simple timestep count for this agent. Gets increased by one + # each time a (non-initial!) observation is added. + self.count = 0 + + def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID, + env_id: EnvID, init_obs: TensorType, + view_requirements: Dict[str, ViewRequirement]) -> None: + """Adds an initial observation (after reset) to the Agent's trajectory. + + Args: + episode_id (EpisodeID): Unique ID for the episode we are adding the + initial observation for. + agent_id (AgentID): Unique ID for the agent we are adding the + initial observation for. + env_id (EnvID): The environment index (in a vectorized setup). + init_obs (TensorType): The initial observation tensor (after + `env.reset()`). 
+ view_requirements (Dict[str, ViewRequirements]) + """ + if SampleBatch.OBS not in self.buffers: + self._build_buffers( + single_row={ + SampleBatch.OBS: init_obs, + SampleBatch.EPS_ID: episode_id, + SampleBatch.AGENT_INDEX: agent_id, + "env_id": env_id, + }) + self.buffers[SampleBatch.OBS].append(init_obs) + + def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \ + None: + """Adds the given dictionary (row) of values to the Agent's trajectory. + + Args: + values (Dict[str, TensorType]): Data dict (interpreted as a single + row) to be added to buffer. Must contain keys: + SampleBatch.ACTIONS, REWARDS, DONES, and NEXT_OBS. + """ + + assert SampleBatch.OBS not in values + values[SampleBatch.OBS] = values[SampleBatch.NEXT_OBS] + del values[SampleBatch.NEXT_OBS] + + for k, v in values.items(): + if k not in self.buffers: + self._build_buffers(single_row=values) + self.buffers[k].append(v) + self.count += 1 + + def build(self, view_requirements: Dict[str, ViewRequirement]) -> \ + SampleBatch: + """Builds a SampleBatch from the thus-far collected agent data. + + If the episode/trajectory has no DONE=True at the end, will copy + the necessary n timesteps at the end of the trajectory back to the + beginning of the buffers and wait for new samples coming in. + SampleBatches created by this method will be ready for postprocessing + by a Policy. + + Args: + view_requirements (Dict[str, ViewRequirement]: The view + requirements dict needed to build the SampleBatch from the raw + buffers (which may have data shifts as well as mappings from + view-col to data-col in them). + Returns: + SampleBatch: The built SampleBatch for this agent, ready to go into + postprocessing. + """ + + # TODO: measure performance gains when using a UsageTrackingDict + # instead of a SampleBatch for postprocessing (this would eliminate + # copies (for creating this SampleBatch) of many unused columns for + # no reason (not used by postprocessor)). + + batch_data = {} + np_data = {} + for view_col, view_req in view_requirements.items(): + # Create the batch of data from the different buffers. + data_col = view_req.data_col or view_col + # Some columns don't exist yet (get created during postprocessing). + # -> skip. + if data_col not in self.buffers: + continue + shift = view_req.shift - \ + (1 if data_col == SampleBatch.OBS else 0) + if data_col not in np_data: + np_data[data_col] = to_float_np_array(self.buffers[data_col]) + if shift == 0: + batch_data[view_col] = np_data[data_col][self.shift_before:] + else: + batch_data[view_col] = np_data[data_col][self.shift_before + + shift:shift] + batch = SampleBatch(batch_data) + + if SampleBatch.UNROLL_ID not in batch.data: + batch.data[SampleBatch.UNROLL_ID] = np.repeat( + _AgentCollector._next_unroll_id, batch.count) + _AgentCollector._next_unroll_id += 1 + + # This trajectory is continuing -> Copy data at the end (in the size of + # self.shift_before) to the beginning of buffers and erase everything + # else. + if not self.buffers[SampleBatch.DONES][-1]: + # Copy data to beginning of buffer and cut lists. + if self.shift_before > 0: + for k, data in self.buffers.items(): + self.buffers[k] = data[-self.shift_before:] + self.count = 0 + + return batch + + def _build_buffers(self, single_row: Dict[str, TensorType]) -> None: + """Builds the buffers for sample collection, given an example data row. + + Args: + single_row (Dict[str, TensorType]): A single row (keys=column + names) of data to base the buffers on. 
+ """ + for col, data in single_row.items(): + if col in self.buffers: + continue + shift = self.shift_before - (1 if col == SampleBatch.OBS else 0) + # Python primitive. + if isinstance(data, (int, float, bool, str)): + self.buffers[col] = [0 for _ in range(shift)] + # np.ndarray, torch.Tensor, or tf.Tensor. + else: + shape = data.shape + dtype = data.dtype + if torch and isinstance(data, torch.Tensor): + self.buffers[col] = \ + [torch.zeros(shape, dtype=dtype, device=data.device) + for _ in range(shift)] + elif tf and isinstance(data, tf.Tensor): + self.buffers[col] = \ + [tf.zeros(shape=shape, dtype=dtype) + for _ in range(shift)] + else: + self.buffers[col] = \ + [np.zeros(shape=shape, dtype=dtype) + for _ in range(shift)] + + +class _PolicyCollector: + """Collects already postprocessed (single agent) samples for one policy. + + Samples come in through already postprocessed SampleBatches, which + contain single episode/trajectory data for a single agent and are then + appended to this policy's buffers. + """ + + def __init__(self): + """Initializes a _PolicyCollector instance.""" + + self.buffers: Dict[str, List] = collections.defaultdict(list) + # The total timestep count for all agents that use this policy. + # NOTE: This is not an env-step count (across n agents). AgentA and + # agentB, both using this policy, acting in the same episode and both + # doing n steps would increase the count by 2*n. + self.count = 0 + + def add_postprocessed_batch_for_training( + self, batch: SampleBatch, + view_requirements: Dict[str, ViewRequirement]) -> None: + """Adds a postprocessed SampleBatch (single agent) to our buffers. + + Args: + batch (SampleBatch): A single agent (one trajectory) SampleBatch + to be added to the Policy's buffers. + view_requirements (Dict[str, ViewRequirement]: The view + requirements for the policy. This is so we know, whether a + view-column needs to be copied at all (not needed for + training). + """ + for view_col, data in batch.items(): + # Skip columns that are not used for training. + if view_col in view_requirements and \ + not view_requirements[view_col].used_for_training: + continue + self.buffers[view_col].extend(data) + # Add the agent's trajectory length to our count. + self.count += batch.count + + def build(self): + """Builds a SampleBatch for this policy from the collected data. + + Also resets all buffers for further sample collection for this policy. + + Returns: + SampleBatch: The SampleBatch with all thus-far collected data for + this policy. + """ + # Create batch from our buffers. + batch = SampleBatch(self.buffers) + assert SampleBatch.UNROLL_ID in batch.data + # Clear buffers for future samples. + self.buffers.clear() + # Reset count to 0. + self.count = 0 + return batch + + +class _SimpleListCollector(_SampleCollector): + """Util to build SampleBatches for each policy in a multi-agent env. + + Input data is per-agent, while output data is per-policy. There is an M:N + mapping between agents and policies. We retain one local batch builder + per agent. When an agent is done, then its local batch is appended into the + corresponding policy batch for the agent's policy. + """ + + def __init__(self, + policy_map: Dict[PolicyID, Policy], + clip_rewards: Union[bool, float], + callbacks: "DefaultCallbacks", + multiple_episodes_in_batch: bool = True, + rollout_fragment_length: int = 200): + """Initializes a _SimpleListCollector instance. + + Args: + policy_map (Dict[str, Policy]): Maps policy ids to policy + instances. 
+ clip_rewards (Union[bool, float]): Whether to clip rewards before + postprocessing (at +/-1.0) or the actual value to +/- clip. + callbacks (DefaultCallbacks): RLlib callbacks. + """ + + self.policy_map = policy_map + self.clip_rewards = clip_rewards + self.callbacks = callbacks + self.multiple_episodes_in_batch = multiple_episodes_in_batch + self.rollout_fragment_length = rollout_fragment_length + self.large_batch_threshold: int = max( + 1000, rollout_fragment_length * + 10) if rollout_fragment_length != float("inf") else 5000 + + # Build each Policies' single collector. + self.policy_collectors = { + pid: _PolicyCollector() + for pid in policy_map.keys() + } + self.policy_collectors_env_steps = 0 + # Whenever we observe a new episode+agent, add a new + # _SingleTrajectoryCollector. + self.agent_collectors: Dict[Tuple[EpisodeID, AgentID], + _AgentCollector] = {} + # Internal agent-key-to-policy map. + self.agent_key_to_policy = {} + + # Agents to collect data from for the next forward pass (per policy). + self.forward_pass_agent_keys = {pid: [] for pid in policy_map.keys()} + self.forward_pass_size = {pid: 0 for pid in policy_map.keys()} + + # Maps episode ID to _EpisodeRecord objects. + self.episode_steps: Dict[EpisodeID, int] = collections.defaultdict(int) + self.episodes: Dict[EpisodeID, MultiAgentEpisode] = {} + + @override(_SampleCollector) + def episode_step(self, episode_id: EpisodeID) -> None: + self.episode_steps[episode_id] += 1 + + env_steps = \ + self.policy_collectors_env_steps + self.episode_steps[episode_id] + if (env_steps > self.large_batch_threshold + and log_once("large_batch_warning")): + logger.warning( + "More than {} observations for {} env steps ".format( + env_steps, env_steps) + + "are buffered in the sampler. If this is more than you " + "expected, check that that you set a horizon on your " + "environment correctly and that it terminates at some point. " + "Note: In multi-agent environments, `rollout_fragment_length` " + "sets the batch size based on (across-agents) environment " + "steps, not the steps of individual agents, which can result " + "in unexpectedly large batches." + + ("Also, you may be in evaluation waiting for your Env to " + "terminate (batch_mode=`complete_episodes`). Make sure it " + "does at some point." + if not self.multiple_episodes_in_batch else "")) + + @override(_SampleCollector) + def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID, + env_id: EnvID, policy_id: PolicyID, + init_obs: TensorType) -> None: + # Make sure our mappings are up to date. + agent_key = (episode.episode_id, agent_id) + if agent_key not in self.agent_key_to_policy: + self.agent_key_to_policy[agent_key] = policy_id + else: + assert self.agent_key_to_policy[agent_key] == policy_id + policy = self.policy_map[policy_id] + view_reqs = policy.model.inference_view_requirements if \ + hasattr(policy, "model") else policy.view_requirements + + # Add initial obs to Trajectory. + assert agent_key not in self.agent_collectors + # TODO: determine exact shift-before based on the view-req shifts. 
+ self.agent_collectors[agent_key] = _AgentCollector() + self.agent_collectors[agent_key].add_init_obs( + episode_id=episode.episode_id, + agent_id=agent_id, + env_id=env_id, + init_obs=init_obs, + view_requirements=view_reqs) + + self.episodes[episode.episode_id] = episode + + self._add_to_next_inference_call(agent_key, env_id) + + @override(_SampleCollector) + def add_action_reward_next_obs(self, episode_id: EpisodeID, + agent_id: AgentID, env_id: EnvID, + policy_id: PolicyID, agent_done: bool, + values: Dict[str, TensorType]) -> None: + # Make sure, episode/agent already has some (at least init) data. + agent_key = (episode_id, agent_id) + assert self.agent_key_to_policy[agent_key] == policy_id + assert agent_key in self.agent_collectors + + # Include the current agent id for multi-agent algorithms. + if agent_id != _DUMMY_AGENT_ID: + values["agent_id"] = agent_id + + # Add action/reward/next-obs (and other data) to Trajectory. + self.agent_collectors[agent_key].add_action_reward_next_obs(values) + + if not agent_done: + self._add_to_next_inference_call(agent_key, env_id) + + @override(_SampleCollector) + def total_env_steps(self) -> int: + return sum(a.count for a in self.agent_collectors.values()) + + @override(_SampleCollector) + def get_inference_input_dict(self, policy_id: PolicyID) -> \ + Dict[str, TensorType]: + policy = self.policy_map[policy_id] + keys = self.forward_pass_agent_keys[policy_id] + buffers = {k: self.agent_collectors[k].buffers for k in keys} + view_reqs = policy.model.inference_view_requirements if \ + hasattr(policy, "model") else policy.view_requirements + + input_dict = {} + for view_col, view_req in view_reqs.items(): + # Create the batch of data from the different buffers. + data_col = view_req.data_col or view_col + time_indices = \ + view_req.shift - ( + 1 if data_col in [SampleBatch.OBS, "t", "env_id", + SampleBatch.EPS_ID, + SampleBatch.AGENT_INDEX] else 0) + data_list = [] + for k in keys: + if data_col not in buffers[k]: + self.agent_collectors[k]._build_buffers({ + data_col: view_req.space.sample() + }) + data_list.append(buffers[k][data_col][time_indices]) + input_dict[view_col] = np.array(data_list) + + self._reset_inference_calls(policy_id) + + return input_dict + + @override(_SampleCollector) + def postprocess_episode(self, + episode: MultiAgentEpisode, + is_done: bool = False, + check_dones: bool = False) -> None: + episode_id = episode.episode_id + + # TODO: (sven) Once we implement multi-agent communication channels, + # we have to resolve the restriction of only sending other agent + # batches from the same policy to the postprocess methods. + # Build SampleBatches for the given episode. + pre_batches = {} + for (eps_id, agent_id), collector in self.agent_collectors.items(): + # Build only if there is data and agent is part of given episode. + if collector.count == 0 or eps_id != episode_id: + continue + policy = self.policy_map[self.agent_key_to_policy[(eps_id, + agent_id)]] + pre_batch = collector.build(policy.view_requirements) + pre_batches[agent_id] = (policy, pre_batch) + + # Apply postprocessor. + post_batches = {} + if self.clip_rewards is True: + for _, (_, pre_batch) in pre_batches.items(): + pre_batch["rewards"] = np.sign(pre_batch["rewards"]) + elif self.clip_rewards: + for _, (_, pre_batch) in pre_batches.items(): + pre_batch["rewards"] = np.clip( + pre_batch["rewards"], + a_min=-self.clip_rewards, + a_max=self.clip_rewards) + + for agent_id, (_, pre_batch) in pre_batches.items(): + # Entire episode is said to be done. 
+ if is_done: + # Error if no DONE at end of this agent's trajectory. + if check_dones and not pre_batch[SampleBatch.DONES][-1]: + raise ValueError( + "Episode {} terminated for all agents, but we still " + "don't have a last observation for agent {} (policy " + "{}). ".format( + episode_id, agent_id, self.agent_key_to_policy[( + episode_id, agent_id)]) + + "Please ensure that you include the last observations " + "of all live agents when setting done[__all__] to " + "True. Alternatively, set no_done_at_end=True to " + "allow this.") + # If (only this?) agent is done, erase its buffer entirely. + if pre_batch[SampleBatch.DONES][-1]: + del self.agent_collectors[(episode_id, agent_id)] + + other_batches = pre_batches.copy() + del other_batches[agent_id] + policy = self.policy_map[self.agent_key_to_policy[(episode_id, + agent_id)]] + if any(pre_batch["dones"][:-1]) or len(set( + pre_batch["eps_id"])) > 1: + raise ValueError( + "Batches sent to postprocessing must only contain steps " + "from a single trajectory.", pre_batch) + # Call the Policy's Exploration's postprocess method. + post_batches[agent_id] = pre_batch + if getattr(policy, "exploration", None) is not None: + policy.exploration.postprocess_trajectory( + policy, post_batches[agent_id], + getattr(policy, "_sess", None)) + post_batches[agent_id] = policy.postprocess_trajectory( + post_batches[agent_id], other_batches, episode) + + if log_once("after_post"): + logger.info( + "Trajectory fragment after postprocess_trajectory():\n\n{}\n". + format(summarize(post_batches))) + + # Append into policy batches and reset. + from ray.rllib.evaluation.rollout_worker import get_global_worker + for agent_id, post_batch in sorted(post_batches.items()): + pid = self.agent_key_to_policy[(episode_id, agent_id)] + policy = self.policy_map[pid] + self.callbacks.on_postprocess_trajectory( + worker=get_global_worker(), + episode=episode, + agent_id=agent_id, + policy_id=pid, + policies=self.policy_map, + postprocessed_batch=post_batch, + original_batches=pre_batches) + # Add the postprocessed SampleBatch to the policy collectors for + # training. + self.policy_collectors[pid].add_postprocessed_batch_for_training( + post_batch, policy.view_requirements) + + env_steps = self.episode_steps[episode_id] + self.policy_collectors_env_steps += env_steps + + if is_done: + del self.episode_steps[episode_id] + del self.episodes[episode_id] + else: + self.episode_steps[episode_id] = 0 + + @override(_SampleCollector) + def build_multi_agent_batch(self, env_steps: int) -> \ + Union[MultiAgentBatch, SampleBatch]: + ma_batch = MultiAgentBatch.wrap_as_needed( + { + pid: collector.build() + for pid, collector in self.policy_collectors.items() + if collector.count > 0 + }, + env_steps=env_steps) + self.policy_collectors_env_steps = 0 + return ma_batch + + @override(_SampleCollector) + def try_build_truncated_episode_multi_agent_batch(self) -> \ + Union[MultiAgentBatch, SampleBatch, None]: + # Have something to loop through, even if there are currently no + # ongoing episodes. + episode_steps = self.episode_steps or {"_fake_id": 0} + # Loop through ongoing episodes and see whether their length plus + # what's already in the policy collectors reaches the fragment-len. + for episode_id, count in episode_steps.items(): + env_steps = self.policy_collectors_env_steps + count + # Reached the fragment-len -> We should build an MA-Batch. 
+ if env_steps >= self.rollout_fragment_length: + # If we reached the fragment-len only because of `episode_id` + # (still ongoing) -> postprocess `episode_id` first. + if self.policy_collectors_env_steps < \ + self.rollout_fragment_length: + self.postprocess_episode( + self.episodes[episode_id], is_done=False) + # Otherwise, create MA-batch only from what's already in our + # policy buffers (do not include `episode_id`'s data). + else: + env_steps = self.policy_collectors_env_steps + # Build the MA-batch and return. + ma_batch = self.build_multi_agent_batch(env_steps=env_steps) + return ma_batch + return None + + def _add_to_next_inference_call(self, agent_key: Tuple[EpisodeID, AgentID], + env_id: EnvID) -> None: + """Adds an Agent key (episode+agent IDs) to the next inference call. + + This makes sure that the agent's current data (in the trajectory) is + used for generating the next input_dict for a + `Policy.compute_actions()` call. + + Args: + agent_key (Tuple[EpisodeID, AgentID]: A unique agent key (across + vectorized environments). + env_id (EnvID): The environment index (in a vectorized setup). + """ + policy_id = self.agent_key_to_policy[agent_key] + idx = self.forward_pass_size[policy_id] + if idx == 0: + self.forward_pass_agent_keys[policy_id].clear() + self.forward_pass_agent_keys[policy_id].append(agent_key) + self.forward_pass_size[policy_id] += 1 + + def _reset_inference_calls(self, policy_id: PolicyID) -> None: + """Resets internal inference input-dict registries. + + Calling `self.get_inference_input_dict()` after this method is called + would return an empty input-dict. + + Args: + policy_id (PolicyID): The policy ID for which to reset the + inference pointers. + """ + self.forward_pass_size[policy_id] = 0 diff --git a/rllib/evaluation/multi_agent_sample_collector.py b/rllib/evaluation/multi_agent_sample_collector.py deleted file mode 100644 index 7c21b0bec..000000000 --- a/rllib/evaluation/multi_agent_sample_collector.py +++ /dev/null @@ -1,249 +0,0 @@ -import logging -from typing import Dict, Optional, TYPE_CHECKING - -from ray.rllib.env.base_env import _DUMMY_AGENT_ID -from ray.rllib.evaluation.episode import MultiAgentEpisode -from ray.rllib.evaluation.per_policy_sample_collector import \ - _PerPolicySampleCollector -from ray.rllib.evaluation.sample_collector import _SampleCollector -from ray.rllib.policy.policy import Policy -from ray.rllib.policy.sample_batch import MultiAgentBatch -from ray.rllib.utils import force_list -from ray.rllib.utils.annotations import override -from ray.rllib.utils.debug import summarize -from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, \ - TensorType -from ray.util.debug import log_once - -if TYPE_CHECKING: - from ray.rllib.agents.callbacks import DefaultCallbacks - -logger = logging.getLogger(__name__) - - -class _MultiAgentSampleCollector(_SampleCollector): - """Builds SampleBatches for each policy (and agent) in a multi-agent env. - - Note: This is an experimental class only used when - `config._use_trajectory_view_api` = True. - Once `_use_trajectory_view_api` becomes the default in configs: - This class will deprecate the `SampleBatchBuilder` class. - - Input data is collected in central per-policy buffers, which - efficiently pre-allocate memory (over n timesteps) and re-use the same - memory even for succeeding agents and episodes. 
- Input_dicts for action computations, SampleBatches for postprocessing, and - train_batch dicts are - if possible - created from the central per-policy - buffers via views to avoid copying of data). - """ - - def __init__( - self, - policy_map: Dict[PolicyID, Policy], - callbacks: "DefaultCallbacks", - # TODO: (sven) make `num_agents` flexibly grow in size. - num_agents: int = 100, - num_timesteps=None, - time_major: Optional[bool] = False): - """Initializes a _MultiAgentSampleCollector object. - - Args: - policy_map (Dict[PolicyID,Policy]): Maps policy ids to policy - instances. - callbacks (DefaultCallbacks): RLlib callbacks (configured in the - Trainer config dict). Used for trajectory postprocessing event. - num_agents (int): The max number of agent slots to pre-allocate - in the buffer. - num_timesteps (int): The max number of timesteps to pre-allocate - in the buffer. - time_major (Optional[bool]): Whether to preallocate buffers and - collect samples in time-major fashion (TxBx...). - """ - - self.policy_map = policy_map - self.callbacks = callbacks - if num_agents == float("inf") or num_agents is None: - num_agents = 1000 - self.num_agents = int(num_agents) - - # Collect SampleBatches per-policy in _PerPolicySampleCollectors. - self.policy_sample_collectors = {} - for pid, policy in policy_map.items(): - # Figure out max-shifts (before and after). - view_reqs = policy.training_view_requirements - max_shift_before = 0 - max_shift_after = 0 - for vr in view_reqs.values(): - shift = force_list(vr.shift) - if max_shift_before > shift[0]: - max_shift_before = shift[0] - if max_shift_after < shift[-1]: - max_shift_after = shift[-1] - # Figure out num_timesteps and num_agents. - kwargs = {"time_major": time_major} - if policy.is_recurrent(): - kwargs["num_timesteps"] = \ - policy.config["model"]["max_seq_len"] - kwargs["time_major"] = True - elif num_timesteps is not None: - kwargs["num_timesteps"] = num_timesteps - - self.policy_sample_collectors[pid] = _PerPolicySampleCollector( - num_agents=self.num_agents, - shift_before=-max_shift_before, - shift_after=max_shift_after, - **kwargs) - - # Internal agent-to-policy map. - self.agent_to_policy = {} - # Number of "inference" steps taken in the environment. - # Regardless of the number of agents involved in each of these steps. - self.count = 0 - - @override(_SampleCollector) - def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID, - env_id: EnvID, policy_id: PolicyID, - obs: TensorType) -> None: - # Make sure our mappings are up to date. - if agent_id not in self.agent_to_policy: - self.agent_to_policy[agent_id] = policy_id - else: - assert self.agent_to_policy[agent_id] == policy_id - - # Add initial obs to Trajectory. - self.policy_sample_collectors[policy_id].add_init_obs( - episode_id, agent_id, env_id, chunk_num=0, init_obs=obs) - - @override(_SampleCollector) - def add_action_reward_next_obs(self, episode_id: EpisodeID, - agent_id: AgentID, env_id: EnvID, - policy_id: PolicyID, agent_done: bool, - values: Dict[str, TensorType]) -> None: - assert policy_id in self.policy_sample_collectors - - # Make sure our mappings are up to date. - if agent_id not in self.agent_to_policy: - self.agent_to_policy[agent_id] = policy_id - else: - assert self.agent_to_policy[agent_id] == policy_id - - # Include the current agent id for multi-agent algorithms. - if agent_id != _DUMMY_AGENT_ID: - values["agent_id"] = agent_id - - # Add action/reward/next-obs (and other data) to Trajectory. 
- self.policy_sample_collectors[policy_id].add_action_reward_next_obs( - episode_id, agent_id, env_id, agent_done, values) - - @override(_SampleCollector) - def total_env_steps(self) -> int: - return sum(a.timesteps_since_last_reset - for a in self.policy_sample_collectors.values()) - - def total(self): - # TODO: (sven) deprecate; use `self.total_env_steps`, instead. - # Sampler is currently still using `total()`. - return self.total_env_steps() - - @override(_SampleCollector) - def get_inference_input_dict(self, policy_id: PolicyID) -> \ - Dict[str, TensorType]: - policy = self.policy_map[policy_id] - view_reqs = policy.model.inference_view_requirements - return self.policy_sample_collectors[ - policy_id].get_inference_input_dict(view_reqs) - - @override(_SampleCollector) - def has_non_postprocessed_data(self) -> bool: - return self.total_env_steps() > 0 - - @override(_SampleCollector) - def postprocess_trajectories_so_far( - self, episode: Optional[MultiAgentEpisode] = None) -> None: - # Loop through each per-policy collector and create a view (for each - # agent as SampleBatch) from its buffers for post-processing - all_agent_batches = {} - for pid, rc in self.policy_sample_collectors.items(): - policy = self.policy_map[pid] - view_reqs = policy.training_view_requirements - agent_batches = rc.get_postprocessing_sample_batches( - episode, view_reqs) - - for agent_key, batch in agent_batches.items(): - other_batches = None - if len(agent_batches) > 1: - other_batches = agent_batches.copy() - del other_batches[agent_key] - - agent_batches[agent_key] = policy.postprocess_trajectory( - batch, other_batches, episode) - # Call the Policy's Exploration's postprocess method. - if getattr(policy, "exploration", None) is not None: - agent_batches[ - agent_key] = policy.exploration.postprocess_trajectory( - policy, agent_batches[agent_key], - getattr(policy, "_sess", None)) - - # Add new columns' data to buffers. - for col in agent_batches[agent_key].new_columns: - data = agent_batches[agent_key].data[col] - rc._build_buffers({col: data[0]}) - timesteps = data.shape[0] - rc.buffers[col][rc.shift_before:rc.shift_before + - timesteps, rc.agent_key_to_slot[ - agent_key]] = data - - all_agent_batches.update(agent_batches) - - if log_once("after_post"): - logger.info("Trajectory fragment after postprocess_trajectory():" - "\n\n{}\n".format(summarize(all_agent_batches))) - - # Append into policy batches and reset - from ray.rllib.evaluation.rollout_worker import get_global_worker - for agent_key, batch in sorted(all_agent_batches.items()): - self.callbacks.on_postprocess_trajectory( - worker=get_global_worker(), - episode=episode, - agent_id=agent_key[0], - policy_id=self.agent_to_policy[agent_key[0]], - policies=self.policy_map, - postprocessed_batch=batch, - original_batches=None) # TODO: (sven) do we really need this? - - @override(_SampleCollector) - def check_missing_dones(self, episode_id: EpisodeID) -> None: - for pid, rc in self.policy_sample_collectors.items(): - for agent_key in rc.agent_key_to_slot.keys(): - # Only check for given episode and only for last chunk - # (all previous chunks for that agent in the episode are - # non-terminal). - if (agent_key[1] == episode_id - and rc.agent_key_to_chunk_num[agent_key[:2]] == - agent_key[2]): - t = rc.agent_key_to_timestep[agent_key] - 1 - b = rc.agent_key_to_slot[agent_key] - if not rc.buffers["dones"][t][b]: - raise ValueError( - "Episode {} terminated for all agents, but we " - "still don't have a last observation for " - "agent {} (policy {}). 
".format(agent_key[0], pid) - + "Please ensure that you include the last " - "observations of all live agents when setting " - "'__all__' done to True. Alternatively, set " - "no_done_at_end=True to allow this.") - - @override(_SampleCollector) - def get_multi_agent_batch_and_reset(self): - self.postprocess_trajectories_so_far() - policy_batches = {} - for pid, rc in self.policy_sample_collectors.items(): - policy = self.policy_map[pid] - view_reqs = policy.training_view_requirements - policy_batches[pid] = rc.get_train_sample_batch_and_reset( - view_reqs) - - ma_batch = MultiAgentBatch.wrap_as_needed(policy_batches, self.count) - # Reset our across-all-agents env step count. - self.count = 0 - return ma_batch diff --git a/rllib/evaluation/per_policy_sample_collector.py b/rllib/evaluation/per_policy_sample_collector.py deleted file mode 100644 index 3ef853ad5..000000000 --- a/rllib/evaluation/per_policy_sample_collector.py +++ /dev/null @@ -1,499 +0,0 @@ -import logging -import numpy as np -from typing import Dict, Optional - -from ray.rllib.evaluation.episode import MultiAgentEpisode -from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.view_requirement import ViewRequirement -from ray.rllib.utils.framework import try_import_tf, try_import_torch -from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, TensorType - -tf1, tf, tfv = try_import_tf() -torch, _ = try_import_torch() - -logger = logging.getLogger(__name__) - - -class _PerPolicySampleCollector: - """A class for efficiently collecting samples for a single (fixed) policy. - - Can be used by a _MultiAgentSampleCollector for its different policies. - """ - - def __init__(self, - num_agents: Optional[int] = None, - num_timesteps: Optional[int] = None, - time_major: bool = True, - shift_before: int = 0, - shift_after: int = 0): - """Initializes a _PerPolicySampleCollector object. - - Args: - num_agents (int): The max number of agent slots to pre-allocate - in the buffer. - num_timesteps (int): The max number of timesteps to pre-allocate - in the buffer. - time_major (Optional[bool]): Whether to preallocate buffers and - collect samples in time-major fashion (TxBx...). - shift_before (int): The additional number of time slots to - pre-allocate at the beginning of a time window (for possible - underlying data column shifts, e.g. PREV_ACTIONS). - shift_after (int): The additional number of time slots to - pre-allocate at the end of a time window (for possible - underlying data column shifts, e.g. NEXT_OBS). - """ - - self.num_agents = num_agents or 100 - self.num_timesteps = num_timesteps - self.time_major = time_major - # `shift_before must at least be 1 for the init obs timestep. - self.shift_before = max(shift_before, 1) - self.shift_after = shift_after - - # The offset on the agent dim to start the next SampleBatch build from. - self.sample_batch_offset = 0 - - # The actual underlying data-buffers. - self.buffers = {} - self.postprocessed_agents = [False] * self.num_agents - - # Next agent-slot to be used by a new agent/env combination. - self.agent_slot_cursor = 0 - # Maps agent/episode ID/chunk-num to an agent slot. - self.agent_key_to_slot = {} - # Maps agent/episode ID to the last chunk-num. - self.agent_key_to_chunk_num = {} - # Maps agent slot number to agent keys. - self.slot_to_agent_key = [None] * self.num_agents - # Maps agent/episode ID/chunk-num to a time step cursor. - self.agent_key_to_timestep = {} - - # Total timesteps taken in the env over all agents since last reset. 
- self.timesteps_since_last_reset = 0 - - # Indices (T,B) to pick from the buffers for the next forward pass. - self.forward_pass_indices = [[], []] - self.forward_pass_size = 0 - # Maps index from the forward pass batch to (agent_id, episode_id, - # env_id) tuple. - self.forward_pass_index_to_agent_info = {} - self.agent_key_to_forward_pass_index = {} - - def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID, - env_id: EnvID, chunk_num: int, - init_obs: TensorType) -> None: - """Adds a single initial observation (after env.reset()) to the buffer. - - Args: - episode_id (EpisodeID): Unique ID for the episode we are adding the - initial observation for. - agent_id (AgentID): Unique ID for the agent we are adding the - initial observation for. - env_id (EnvID): The env ID to which `init_obs` belongs. - chunk_num (int): The time-chunk number (0-based). Some episodes - may last for longer than self.num_timesteps and therefore - have to be chopped into chunks. - init_obs (TensorType): Initial observation (after env.reset()). - """ - agent_key = (agent_id, episode_id, chunk_num) - agent_slot = self.agent_slot_cursor - self.agent_key_to_slot[agent_key] = agent_slot - self.agent_key_to_chunk_num[agent_key[:2]] = chunk_num - self.slot_to_agent_key[agent_slot] = agent_key - self._next_agent_slot() - - if SampleBatch.OBS not in self.buffers: - self._build_buffers( - single_row={ - SampleBatch.OBS: init_obs, - SampleBatch.EPS_ID: episode_id, - SampleBatch.AGENT_INDEX: agent_id, - "env_id": env_id, - }) - if self.time_major: - self.buffers[SampleBatch.OBS][self.shift_before-1, agent_slot] = \ - init_obs - else: - self.buffers[SampleBatch.OBS][agent_slot, self.shift_before-1] = \ - init_obs - self.agent_key_to_timestep[agent_key] = self.shift_before - - self._add_to_next_inference_call(agent_key, env_id, agent_slot, - self.shift_before - 1) - - def add_action_reward_next_obs( - self, episode_id: EpisodeID, agent_id: AgentID, env_id: EnvID, - agent_done: bool, values: Dict[str, TensorType]) -> None: - """Add the given dictionary (row) of values to this batch. - - Args: - episode_id (EpisodeID): Unique ID for the episode we are adding the - values for. - agent_id (AgentID): Unique ID for the agent we are adding the - values for. - env_id (EnvID): The env ID to which the given data belongs. - agent_done (bool): Whether next obs should not be used for an - upcoming inference call. Default: False = next-obs should be - used for upcoming inference. - values (Dict[str, TensorType]): Data dict (interpreted as a single - row) to be added to buffer. Must contain keys: - SampleBatch.ACTIONS, REWARDS, DONES, and NEXT_OBS. - """ - assert (SampleBatch.ACTIONS in values and SampleBatch.REWARDS in values - and SampleBatch.NEXT_OBS in values - and SampleBatch.DONES in values) - - assert SampleBatch.OBS not in values - values[SampleBatch.OBS] = values[SampleBatch.NEXT_OBS] - del values[SampleBatch.NEXT_OBS] - - chunk_num = self.agent_key_to_chunk_num[(agent_id, episode_id)] - agent_key = (agent_id, episode_id, chunk_num) - agent_slot = self.agent_key_to_slot[agent_key] - ts = self.agent_key_to_timestep[agent_key] - for k, v in values.items(): - if k not in self.buffers: - self._build_buffers(single_row=values) - if self.time_major: - self.buffers[k][ts, agent_slot] = v - else: - self.buffers[k][agent_slot, ts] = v - self.agent_key_to_timestep[agent_key] += 1 - - # Time-axis is "full" -> Cut-over to new chunk (only if not DONE). 
- if self.agent_key_to_timestep[ - agent_key] - self.shift_before == self.num_timesteps and \ - not values[SampleBatch.DONES]: - self._new_chunk_from(agent_slot, agent_key, - self.agent_key_to_timestep[agent_key]) - - self.timesteps_since_last_reset += 1 - - if not agent_done: - self._add_to_next_inference_call(agent_key, env_id, agent_slot, ts) - - def get_inference_input_dict(self, view_reqs: Dict[str, ViewRequirement] - ) -> Dict[str, TensorType]: - """Returns an input_dict for an (inference) forward pass. - - The input_dict can then be used for action computations inside a - Policy via `Policy.compute_actions_from_input_dict()`. - - Args: - view_reqs (Dict[str, ViewRequirement]): The view requirements - dict to use. - - Returns: - Dict[str, TensorType]: The input_dict to be passed into the ModelV2 - for inference/training. - - Examples: - >>> obs, r, done, info = env.step(action) - >>> collector.add_action_reward_next_obs(12345, 0, "pol0", { - ... "action": action, "obs": obs, "reward": r, "done": done - ... }) - >>> input_dict = collector.get_inference_input_dict(policy.model) - >>> action = policy.compute_actions_from_input_dict(input_dict) - >>> # repeat - """ - input_dict = {} - for view_col, view_req in view_reqs.items(): - # Create the batch of data from the different buffers. - data_col = view_req.data_col or view_col - if data_col not in self.buffers: - self._build_buffers({data_col: view_req.space.sample()}) - - indices = self.forward_pass_indices - if self.time_major: - input_dict[view_col] = self.buffers[data_col][indices] - else: - if isinstance(view_req.shift, (list, tuple)): - time_indices = \ - np.array(view_req.shift) + np.array(indices[0]) - input_dict[view_col] = self.buffers[data_col][indices[1], - time_indices] - else: - input_dict[view_col] = \ - self.buffers[data_col][indices[1], indices[0]] - - self._reset_inference_call() - - return input_dict - - def get_postprocessing_sample_batches( - self, - episode: MultiAgentEpisode, - view_reqs: Dict[str, ViewRequirement]) -> \ - Dict[AgentID, SampleBatch]: - """Returns a SampleBatch object ready for postprocessing. - - Args: - episode (MultiAgentEpisode): The MultiAgentEpisode object to - get the to-be-postprocessed SampleBatches for. - view_reqs (Dict[str, ViewRequirement]): The view requirements dict - to use for creating the SampleBatch from our buffers. - - Returns: - Dict[AgentID, SampleBatch]: The sample batch objects to be passed - to `Policy.postprocess_trajectory()`. - """ - # Loop through all agents and create a SampleBatch - # (as "view"; no copying). - - # Construct the SampleBatch-dict. - sample_batch_data = {} - - range_ = self.agent_slot_cursor - self.sample_batch_offset - if range_ < 0: - range_ = self.num_agents + range_ - for i in range(range_): - agent_slot = self.sample_batch_offset + i - if agent_slot >= self.num_agents: - agent_slot = agent_slot % self.num_agents - # Do not postprocess the same slot twice. - if self.postprocessed_agents[agent_slot]: - continue - agent_key = self.slot_to_agent_key[agent_slot] - # Skip other episodes (if episode provided). - if episode and agent_key[1] != episode.episode_id: - continue - end = self.agent_key_to_timestep[agent_key] - # Do not build any empty SampleBatches. 
- if end == self.shift_before: - continue - self.postprocessed_agents[agent_slot] = True - - assert agent_key not in sample_batch_data - sample_batch_data[agent_key] = {} - batch = sample_batch_data[agent_key] - - for view_col, view_req in view_reqs.items(): - data_col = view_req.data_col or view_col - # Skip columns that will only get added through postprocessing - # (these may not even exist yet). - if data_col not in self.buffers: - continue - - shift = view_req.shift - if data_col == SampleBatch.OBS: - shift -= 1 - - batch[view_col] = self.buffers[data_col][ - self.shift_before + shift:end + shift, agent_slot] - - batches = {} - for agent_key, data in sample_batch_data.items(): - batches[agent_key] = SampleBatch(data) - return batches - - def get_train_sample_batch_and_reset(self, view_reqs) -> SampleBatch: - """Returns the accumulated sample batche for this policy. - - This is usually called to collect samples for policy training. - - Returns: - SampleBatch: Returns the accumulated sample batch for this - policy. - """ - seq_lens_w_0s = [ - self.agent_key_to_timestep[k] - self.shift_before - for k in self.slot_to_agent_key if k is not None - ] - # We have an agent-axis buffer "rollover" (new SampleBatch will be - # built from last n agent records plus first m agent records in - # buffer). - if self.agent_slot_cursor < self.sample_batch_offset: - rollover = -(self.num_agents - self.sample_batch_offset) - seq_lens_w_0s = seq_lens_w_0s[rollover:] + seq_lens_w_0s[:rollover] - first_zero_len = len(seq_lens_w_0s) - if seq_lens_w_0s[-1] == 0: - first_zero_len = seq_lens_w_0s.index(0) - # Assert that all zeros lie at the end of the seq_lens array. - assert all(seq_lens_w_0s[i] == 0 - for i in range(first_zero_len, len(seq_lens_w_0s))) - - t_start = self.shift_before - t_end = t_start + self.num_timesteps - - # The agent_slot cursor that points to the newest agent-slot that - # actually already has at least 1 timestep of data (thus it excludes - # just-rolled over chunks (which only have the initial obs in them)). - valid_agent_cursor = \ - (self.agent_slot_cursor - - (len(seq_lens_w_0s) - first_zero_len)) % self.num_agents - - # Construct the view dict. - view = {} - for view_col, view_req in view_reqs.items(): - data_col = view_req.data_col or view_col - assert data_col in self.buffers - # For OBS, indices must be shifted by -1. - shift = view_req.shift - shift += 0 if data_col != SampleBatch.OBS else -1 - # If agent_slot has been rolled-over to beginning, we have to copy - # here. - if valid_agent_cursor < self.sample_batch_offset: - time_slice = self.buffers[data_col][t_start + shift:t_end + - shift] - one_ = time_slice[:, self.sample_batch_offset:] - two_ = time_slice[:, :valid_agent_cursor] - if torch and isinstance(time_slice, torch.Tensor): - view[view_col] = torch.cat([one_, two_], dim=1) - else: - view[view_col] = np.concatenate([one_, two_], axis=1) - else: - view[view_col] = \ - self.buffers[data_col][ - t_start + shift:t_end + shift, - self.sample_batch_offset:valid_agent_cursor] - - # Copy all still ongoing trajectories to new agent slots - # (including the ones that just started (are seq_len=0)). - new_chunk_args = [] - for i, seq_len in enumerate(seq_lens_w_0s): - if seq_len < self.num_timesteps: - agent_slot = (self.sample_batch_offset + i) % self.num_agents - if not self.buffers[SampleBatch. 
- DONES][seq_len - 1 + - self.shift_before][agent_slot]: - agent_key = self.slot_to_agent_key[agent_slot] - new_chunk_args.append( - (agent_slot, agent_key, - self.agent_key_to_timestep[agent_key])) - # Cut out all 0 seq-lens. - seq_lens = seq_lens_w_0s[:first_zero_len] - batch = SampleBatch( - view, _seq_lens=np.array(seq_lens), _time_major=self.time_major) - - # Reset everything for new data. - self.postprocessed_agents = [False] * self.num_agents - self.agent_key_to_slot.clear() - self.agent_key_to_chunk_num.clear() - self.slot_to_agent_key = [None] * self.num_agents - self.agent_key_to_timestep.clear() - self.timesteps_since_last_reset = 0 - self.forward_pass_size = 0 - self.sample_batch_offset = self.agent_slot_cursor - - for args in new_chunk_args: - self._new_chunk_from(*args) - - return batch - - def _build_buffers(self, single_row: Dict[str, TensorType]) -> None: - """Builds the internal data buffers based on a single given row. - - This may be called several times in the lifetime of this instance - to add new columns to the buffer. Columns in `single_row` that already - exist in the buffer will be ignored. - - Args: - single_row (Dict[str, TensorType]): A single datarow with one or - more columns (str as key, np.ndarray|tensor as data) to be used - as template to build the pre-allocated buffer. - """ - time_size = self.num_timesteps + self.shift_before + self.shift_after - for col, data in single_row.items(): - if col in self.buffers: - continue - base_shape = (time_size, self.num_agents) if self.time_major else \ - (self.num_agents, time_size) - # Python primitive -> np.array. - if isinstance(data, (int, float, bool)): - t_ = type(data) - dtype = np.float32 if t_ == float else \ - np.int32 if type(data) == int else np.bool_ - self.buffers[col] = np.zeros(shape=base_shape, dtype=dtype) - # np.ndarray, torch.Tensor, or tf.Tensor. - else: - shape = base_shape + data.shape - dtype = data.dtype - if torch and isinstance(data, torch.Tensor): - self.buffers[col] = torch.zeros( - *shape, dtype=dtype, device=data.device) - elif tf and isinstance(data, tf.Tensor): - self.buffers[col] = tf.zeros(shape=shape, dtype=dtype) - else: - self.buffers[col] = np.zeros(shape=shape, dtype=dtype) - - def _next_agent_slot(self): - """Starts a new agent slot at the end of the agent-axis. - - Also makes sure, the new slot is not taken yet. - """ - self.agent_slot_cursor += 1 - if self.agent_slot_cursor >= self.num_agents: - self.agent_slot_cursor = 0 - # Just make sure, there is space in our buffer. - assert self.slot_to_agent_key[self.agent_slot_cursor] is None - - def _new_chunk_from(self, agent_slot, agent_key, timestep): - """Creates a new time-window (chunk) given an agent. - - The agent may already have an unfinished episode going on (in a - previous chunk). The end of that previous chunk will be copied to the - beginning of the new one for proper data-shift handling (e.g. - PREV_ACTIONS/REWARDS). - - Args: - agent_slot (int): The agent to start a new chunk for (from an - ongoing episode (chunk)). - agent_key (Tuple[AgentID, EpisodeID, int]): The internal key to - identify an active agent in some episode. - timestep (int): The timestep in the old chunk being continued. - """ - new_agent_slot = self.agent_slot_cursor - # Increase chunk num by 1. - new_agent_key = agent_key[:2] + (agent_key[2] + 1, ) - # Copy relevant timesteps at end of old chunk into new one. 
- if self.time_major: - for k in self.buffers.keys(): - self.buffers[k][0:self.shift_before, new_agent_slot] = \ - self.buffers[k][ - timestep - self.shift_before:timestep, agent_slot] - else: - for k in self.buffers.keys(): - self.buffers[k][new_agent_slot, 0:self.shift_before] = \ - self.buffers[k][ - agent_slot, timestep - self.shift_before:timestep] - - self.agent_key_to_slot[new_agent_key] = new_agent_slot - self.agent_key_to_chunk_num[new_agent_key[:2]] = new_agent_key[2] - self.slot_to_agent_key[new_agent_slot] = new_agent_key - self._next_agent_slot() - self.agent_key_to_timestep[new_agent_key] = self.shift_before - - def _add_to_next_inference_call(self, agent_key, env_id, agent_slot, - timestep): - """Registers given T and B (agent_slot) for get_inference_input_dict. - - Calling `get_inference_input_dict` will produce an input_dict (for - Policy.compute_actions_from_input_dict) with all registered agent/time - indices and then automatically reset the registry. - - Args: - agent_key (Tuple[AgentID, EpisodeID, int]): The internal key to - identify an active agent in some episode. - env_id (EnvID): The env ID of the given agent. - agent_slot (int): The agent_slot to register (B axis). - timestep (int): The timestep to register (T axis). - """ - idx = self.forward_pass_size - self.forward_pass_index_to_agent_info[idx] = (agent_key[0], - agent_key[1], env_id) - self.agent_key_to_forward_pass_index[agent_key[:2]] = idx - if self.forward_pass_size == 0: - self.forward_pass_indices[0].clear() - self.forward_pass_indices[1].clear() - self.forward_pass_indices[0].append(timestep) - self.forward_pass_indices[1].append(agent_slot) - self.forward_pass_size += 1 - - def _reset_inference_call(self): - """Resets indices for the next inference call. - - After calling this, new calls to `add_init_obs()` and - `add_action_reward_next_obs()` will count for the next input_dict - returned by `get_inference_input_dict()`. 
- """ - self.forward_pass_size = 0 diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index 52702c180..9782f303f 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -9,13 +9,14 @@ from typing import Any, Callable, Dict, List, Iterable, Optional, Set, Tuple,\ TYPE_CHECKING, Union from ray.util.debug import log_once +from ray.rllib.evaluation.collectors.sample_collector import \ + _SampleCollector +from ray.rllib.evaluation.collectors.simple_list_collector import \ + _SimpleListCollector from ray.rllib.evaluation.episode import MultiAgentEpisode -from ray.rllib.evaluation.multi_agent_sample_collector import \ - _MultiAgentSampleCollector from ray.rllib.evaluation.rollout_metrics import RolloutMetrics from ray.rllib.evaluation.sample_batch_builder import \ MultiAgentSampleBatchBuilder -from ray.rllib.evaluation.sample_collector import _SampleCollector from ray.rllib.policy.policy import clip_action, Policy from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.models.preprocessors import Preprocessor @@ -188,8 +189,9 @@ class SyncSampler(SamplerInput): self.extra_batches = queue.Queue() self.perf_stats = _PerfStats() if _use_trajectory_view_api: - self.sample_collector = _MultiAgentSampleCollector( - policies, callbacks) + self.sample_collector = _SimpleListCollector( + policies, clip_rewards, callbacks, multiple_episodes_in_batch, + rollout_fragment_length) else: self.sample_collector = None @@ -333,8 +335,9 @@ class AsyncSampler(threading.Thread, SamplerInput): self.observation_fn = observation_fn self._use_trajectory_view_api = _use_trajectory_view_api if _use_trajectory_view_api: - self.sample_collector = _MultiAgentSampleCollector( - policies, callbacks) + self.sample_collector = _SimpleListCollector( + policies, clip_rewards, callbacks, multiple_episodes_in_batch, + rollout_fragment_length) else: self.sample_collector = None @@ -537,7 +540,6 @@ def _env_runner( active_episodes: Dict[str, MultiAgentEpisode] = \ NewEpisodeDefaultDict(new_episode) - eval_results = None while True: perf_stats.iters += 1 @@ -564,7 +566,6 @@ def _env_runner( base_env=base_env, policies=policies, active_episodes=active_episodes, - prev_policy_outputs=eval_results, unfiltered_obs=unfiltered_obs, rewards=rewards, dones=dones, @@ -572,13 +573,11 @@ def _env_runner( horizon=horizon, preprocessors=preprocessors, obs_filters=obs_filters, - rollout_fragment_length=rollout_fragment_length, multiple_episodes_in_batch=multiple_episodes_in_batch, callbacks=callbacks, soft_horizon=soft_horizon, no_done_at_end=no_done_at_end, observation_fn=observation_fn, - perf_stats=perf_stats, _sample_collector=_sample_collector, ) else: @@ -601,7 +600,6 @@ def _env_runner( soft_horizon=soft_horizon, no_done_at_end=no_done_at_end, observation_fn=observation_fn, - perf_stats=perf_stats, ) perf_stats.raw_obs_processing_time += time.time() - t1 for o in outputs: @@ -669,7 +667,6 @@ def _process_observations( soft_horizon: bool, no_done_at_end: bool, observation_fn: "ObservationFunction", - perf_stats: _PerfStats, ) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[ RolloutMetrics, SampleBatchType]]]: """Record new data from the environment and prepare for policy evaluation. @@ -682,8 +679,6 @@ def _process_observations( SampleBatchBuilder object for recycling. active_episodes (Dict[str, MultiAgentEpisode]): Mapping from episode ID to currently ongoing MultiAgentEpisode object. 
- prev_policy_outputs (Dict[str,List]): The prev policy output dict - (by policy-id -> List[action, state outs, extra fetches]). unfiltered_obs (dict): Doubly keyed dict of env-ids -> agent ids -> unfiltered observation tensor, returned by a `BaseEnv.poll()` call. @@ -862,6 +857,7 @@ def _process_observations( # and add it to "outputs". if (all_agents_done and not multiple_episodes_in_batch) or \ batch_builder.count >= rollout_fragment_length: + batch_builder.postprocess_batch_so_far(episode) outputs.append(batch_builder.build_and_reset(episode)) # Make sure postprocessor stays within one episode. elif all_agents_done: @@ -887,10 +883,14 @@ def _process_observations( episode=episode, env_index=env_id, ) + # Horizon hit and we have a soft horizon (no hard env reset). if hit_horizon and soft_horizon: episode.soft_reset() resetted_obs: Dict[AgentID, EnvObsType] = agent_obs + # Env actually ended OR horizon hit and no soft horizon -> + # Try hard env-reset. else: + # Remove episode from active ones. del active_episodes[env_id] resetted_obs: Dict[AgentID, EnvObsType] = base_env.try_reset( env_id) @@ -939,8 +939,6 @@ def _process_observations_w_trajectory_view_api( base_env: BaseEnv, policies: Dict[PolicyID, Policy], active_episodes: Dict[str, MultiAgentEpisode], - prev_policy_outputs: Dict[PolicyID, Tuple[TensorStructType, StateBatch, - dict]], unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]], rewards: Dict[EnvID, Dict[AgentID, float]], dones: Dict[EnvID, Dict[AgentID, bool]], @@ -948,13 +946,11 @@ def _process_observations_w_trajectory_view_api( horizon: int, preprocessors: Dict[PolicyID, Preprocessor], obs_filters: Dict[PolicyID, Filter], - rollout_fragment_length: int, multiple_episodes_in_batch: bool, callbacks: "DefaultCallbacks", soft_horizon: bool, no_done_at_end: bool, observation_fn: "ObservationFunction", - perf_stats: _PerfStats, _sample_collector: _SampleCollector, ) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[ RolloutMetrics, SampleBatchType]]]: @@ -964,41 +960,20 @@ def _process_observations_w_trajectory_view_api( # Output objects. active_envs: Set[EnvID] = set() - to_eval: Set[PolicyID] = set() + to_eval: Dict[PolicyID, List[PolicyEvalData]] = defaultdict(list) outputs: List[Union[RolloutMetrics, SampleBatchType]] = [] - large_batch_threshold: int = max(1000, rollout_fragment_length * 10) if \ - rollout_fragment_length != float("inf") else 5000 - - # For each environment. + # For each (vectorized) sub-environment. # type: EnvID, Dict[AgentID, EnvObsType] - for env_id, agent_obs in unfiltered_obs.items(): + for env_id, all_agents_obs in unfiltered_obs.items(): is_new_episode: bool = env_id not in active_episodes episode: MultiAgentEpisode = active_episodes[env_id] if not is_new_episode: + _sample_collector.episode_step(episode.episode_id) episode.length += 1 - _sample_collector.count += 1 episode._add_agent_rewards(rewards[env_id]) - if (_sample_collector.total_env_steps() > large_batch_threshold - and log_once("large_batch_warning")): - logger.warning( - "More than {} observations for {} env steps ".format( - _sample_collector.total_env_steps(), - _sample_collector.count) + - "are buffered in the sampler. If this is more than you " - "expected, check that that you set a horizon on your " - "environment correctly and that it terminates at some point. 
" - "Note: In multi-agent environments, `rollout_fragment_length` " - "sets the batch size based on (across-agents) environment " - "steps, not the steps of individual agents, which can result " - "in unexpectedly large batches." + - ("Also, you may be in evaluation waiting for your Env to " - "terminate (batch_mode=`complete_episodes`). Make sure it " - "does at some point." - if not multiple_episodes_in_batch else "")) - # Check episode termination conditions. if dones[env_id]["__all__"] or episode.length >= horizon: hit_horizon = (episode.length >= horizon @@ -1023,19 +998,19 @@ def _process_observations_w_trajectory_view_api( # Custom observation function is applied before preprocessing. if observation_fn: - agent_obs: Dict[AgentID, EnvObsType] = observation_fn( - agent_obs=agent_obs, + all_agents_obs: Dict[AgentID, EnvObsType] = observation_fn( + agent_obs=all_agents_obs, worker=worker, base_env=base_env, policies=policies, episode=episode) - if not isinstance(agent_obs, dict): + if not isinstance(all_agents_obs, dict): raise ValueError( "observe() must return a dict of agent observations") # For each agent in the environment. # type: AgentID, EnvObsType - for agent_id, raw_obs in agent_obs.items(): + for agent_id, raw_obs in all_agents_obs.items(): assert agent_id != "__all__" policy_id: PolicyID = episode.policy_for(agent_id) prep_obs: EnvObsType = _get_or_raise(preprocessors, @@ -1058,38 +1033,41 @@ def _process_observations_w_trajectory_view_api( # Record transition info if applicable. if last_observation is None: - _sample_collector.add_init_obs(episode.episode_id, agent_id, - env_id, policy_id, filtered_obs) + _sample_collector.add_init_obs(episode, agent_id, env_id, + policy_id, filtered_obs) else: - rc = _sample_collector.policy_sample_collectors[policy_id] - eval_idx = rc.agent_key_to_forward_pass_index[( - agent_id, episode.episode_id)] + # Add actions, rewards, next-obs to collectors. values_dict = { "t": episode.length - 1, "eps_id": episode.episode_id, + "env_id": env_id, "agent_index": episode._agent_index(agent_id), # Action (slot 0) taken at timestep t. - "actions": prev_policy_outputs[policy_id][0][eval_idx], + "actions": episode.last_action_for(agent_id), # Reward received after taking a at timestep t. "rewards": rewards[env_id][agent_id], - # After taking a, did we reach terminal? + # After taking action=a, did we reach terminal? "dones": (False if (no_done_at_end or (hit_horizon and soft_horizon)) else agent_done), # Next observation. "new_obs": filtered_obs, } - # TODO: (sven) add env infos to buffers as well. - for k, v in prev_policy_outputs[policy_id][2].items(): - values_dict[k] = v[eval_idx] - for i, v in enumerate(prev_policy_outputs[policy_id][1]): - values_dict["state_out_{}".format(i)] = v[eval_idx] + # Add extra-action-fetches to collectors. 
+ values_dict.update(**episode.last_pi_info_for(agent_id)) _sample_collector.add_action_reward_next_obs( episode.episode_id, agent_id, env_id, policy_id, agent_done, values_dict) if not agent_done: - to_eval.add(policy_id) + item = PolicyEvalData( + env_id, agent_id, filtered_obs, infos[env_id].get( + agent_id, {}), None if last_observation is None else + episode.rnn_state_for(agent_id), None + if last_observation is None else + episode.last_action_for(agent_id), + rewards[env_id][agent_id] or 0.0) + to_eval[policy_id].append(item) # Invoke the step callback after the step is logged to the episode callbacks.on_episode_step( @@ -1098,36 +1076,22 @@ def _process_observations_w_trajectory_view_api( episode=episode, env_index=env_id) - # Cut the batch if ... - # - all-agents-done and not packing multiple episodes into one - # (batch_mode="complete_episodes") - # - or if we've exceeded the rollout_fragment_length. - if _sample_collector.has_non_postprocessed_data(): - # Sanity check, whether all agents have done=True, if done[__all__] - # is True. - if dones[env_id]["__all__"] and not no_done_at_end: - _sample_collector.check_missing_dones( - episode_id=episode.episode_id) - - # Reached end of episode and we are not allowed to pack the - # next episode into the same SampleBatch -> Build the SampleBatch - # and add it to "outputs". - if (all_agents_done and not multiple_episodes_in_batch) or \ - _sample_collector.count >= rollout_fragment_length: - # TODO: (sven) Case: rollout_fragment_length reached: Do not - # store any data in `episode` anymore - # (useless for get_view_requirements when t<<-1, e.g. - # attention), but keep last episode data around in - # SampleBatchBuilder - # to be able to still reference into it - # should a model require this. - outputs.append(_sample_collector.get_multi_agent_batch_and_reset()) + # Episode is done for all agents + # (dones[__all__] == True or hit horizon). # Make sure postprocessor stays within one episode. - elif all_agents_done: - _sample_collector.postprocess_trajectories_so_far(episode) - - # Episode is done. if all_agents_done: + is_done = dones[env_id]["__all__"] + check_dones = is_done and not no_done_at_end + _sample_collector.postprocess_episode( + episode, is_done=is_done, check_dones=check_dones) + # We are not allowed to pack the next episode into the same + # SampleBatch (batch_mode=complete_episodes) -> Build the + # MultiAgentBatch from a single episode and add it to "outputs". + if not multiple_episodes_in_batch: + ma_sample_batch = \ + _sample_collector.build_multi_agent_batch(episode.length) + outputs.append(ma_sample_batch) + # Call each policy's Exploration.on_episode_end method. for p in policies.values(): if getattr(p, "exploration", None) is not None: @@ -1144,44 +1108,56 @@ def _process_observations_w_trajectory_view_api( episode=episode, env_index=env_id, ) + # Horizon hit and we have a soft horizon (no hard env reset). if hit_horizon and soft_horizon: episode.soft_reset() - resetted_obs: Dict[AgentID, EnvObsType] = agent_obs + resetted_obs: Dict[AgentID, EnvObsType] = all_agents_obs else: del active_episodes[env_id] resetted_obs: Dict[AgentID, EnvObsType] = base_env.try_reset( env_id) + # Reset not supported, drop this env from the ready list. if resetted_obs is None: - # Reset not supported, drop this env from the ready list. if horizon != float("inf"): raise ValueError( "Setting episode horizon requires reset() support " "from the environment.") + # Creates a new episode if this is not async return. 
+ # If reset is async, we will get its result in some future poll. elif resetted_obs != ASYNC_RESET_RETURN: - # Creates a new episode if this is not async return. - # If reset is async, we will get its result in some future poll - episode: MultiAgentEpisode = active_episodes[env_id] + new_episode: MultiAgentEpisode = active_episodes[env_id] if observation_fn: resetted_obs: Dict[AgentID, EnvObsType] = observation_fn( agent_obs=resetted_obs, worker=worker, base_env=base_env, policies=policies, - episode=episode) + episode=new_episode) # type: AgentID, EnvObsType for agent_id, raw_obs in resetted_obs.items(): - policy_id: PolicyID = episode.policy_for(agent_id) + policy_id: PolicyID = new_episode.policy_for(agent_id) prep_obs: EnvObsType = _get_or_raise( preprocessors, policy_id).transform(raw_obs) filtered_obs: EnvObsType = _get_or_raise( obs_filters, policy_id)(prep_obs) - episode._set_last_observation(agent_id, filtered_obs) + new_episode._set_last_observation(agent_id, filtered_obs) # Add initial obs to buffer. - _sample_collector.add_init_obs(episode.episode_id, - agent_id, env_id, policy_id, - filtered_obs) - to_eval.add(policy_id) + _sample_collector.add_init_obs( + new_episode, agent_id, env_id, policy_id, filtered_obs) + + item = PolicyEvalData( + env_id, agent_id, filtered_obs, + episode.last_info_for(agent_id) or {}, + episode.rnn_state_for(agent_id), None, 0.0) + to_eval[policy_id].append(item) + + # Try to build something. + if multiple_episodes_in_batch: + sample_batch = \ + _sample_collector.try_build_truncated_episode_multi_agent_batch() + if sample_batch is not None: + outputs.append(sample_batch) return active_envs, to_eval, outputs @@ -1306,7 +1282,7 @@ def _do_policy_eval_w_trajectory_view_api( logger.info("Inputs to compute_actions():\n\n{}\n".format( summarize(to_eval))) - for policy_id in to_eval: + for policy_id in to_eval.keys(): policy: Policy = _get_or_raise(policies, policy_id) input_dict = _sample_collector.get_inference_input_dict(policy_id) eval_results[policy_id] = \ @@ -1373,7 +1349,7 @@ def _process_policy_eval_results( actions_to_send[env_id] = {} # at minimum send empty dict # type: PolicyID, List[PolicyEvalData] - for policy_id in to_eval: + for policy_id, eval_data in to_eval.items(): actions: TensorStructType = eval_results[policy_id][0] actions = convert_to_numpy(actions) @@ -1385,10 +1361,11 @@ def _process_policy_eval_results( if isinstance(actions, list): actions = np.array(actions) - # Add RNN state info. - eval_data = None - if not _use_trajectory_view_api: - eval_data = to_eval[policy_id] + # Store RNN state ins/outs and extra-action fetches to episode. + if _use_trajectory_view_api: + for f_i, column in enumerate(rnn_out_cols): + pi_info_cols["state_out_{}".format(f_i)] = column + else: rnn_in_cols: StateBatch = _to_column_format( [t.rnn_state for t in eval_data]) @@ -1413,27 +1390,19 @@ def _process_policy_eval_results( else: clipped_action = action - # Trajectory View API: Do not store data directly in episode - # (entire episode is stored in Trajectory and kept until - # end of episode). 
- if _use_trajectory_view_api: - agent_id, episode_id, env_id = \ - _sample_collector.policy_sample_collectors[ - policy_id].forward_pass_index_to_agent_info[i] + env_id: int = eval_data[i].env_id + agent_id: AgentID = eval_data[i].agent_id + episode: MultiAgentEpisode = active_episodes[env_id] + episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols]) + episode._set_last_pi_info( + agent_id, {k: v[i] + for k, v in pi_info_cols.items()}) + if env_id in off_policy_actions and \ + agent_id in off_policy_actions[env_id]: + episode._set_last_action(agent_id, + off_policy_actions[env_id][agent_id]) else: - env_id: int = eval_data[i].env_id - agent_id: AgentID = eval_data[i].agent_id - episode: MultiAgentEpisode = active_episodes[env_id] - episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols]) - episode._set_last_pi_info( - agent_id, {k: v[i] - for k, v in pi_info_cols.items()}) - if env_id in off_policy_actions and \ - agent_id in off_policy_actions[env_id]: - episode._set_last_action( - agent_id, off_policy_actions[env_id][agent_id]) - else: - episode._set_last_action(agent_id, action) + episode._set_last_action(agent_id, action) assert agent_id not in actions_to_send[env_id] actions_to_send[env_id][agent_id] = clipped_action diff --git a/rllib/evaluation/tests/test_trajectory_view_api.py b/rllib/evaluation/tests/test_trajectory_view_api.py index 4a1822da3..0125b3f59 100644 --- a/rllib/evaluation/tests/test_trajectory_view_api.py +++ b/rllib/evaluation/tests/test_trajectory_view_api.py @@ -9,8 +9,9 @@ from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.examples.policy.episode_env_aware_policy import \ EpisodeEnvAwarePolicy +from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils.test_utils import framework_iterator +from ray.rllib.utils.test_utils import framework_iterator, check class TestTrajectoryViewAPI(unittest.TestCase): @@ -30,9 +31,9 @@ class TestTrajectoryViewAPI(unittest.TestCase): trainer = ppo.PPOTrainer(config, env="CartPole-v0") policy = trainer.get_policy() view_req_model = policy.model.inference_view_requirements - view_req_policy = policy.training_view_requirements - assert len(view_req_model) == 1 - assert len(view_req_policy) == 10 + view_req_policy = policy.view_requirements + assert len(view_req_model) == 1, view_req_model + assert len(view_req_policy) == 11, view_req_policy for key in [ SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS, SampleBatch.DONES, SampleBatch.NEXT_OBS, @@ -62,9 +63,9 @@ class TestTrajectoryViewAPI(unittest.TestCase): trainer = ppo.PPOTrainer(config, env="CartPole-v0") policy = trainer.get_policy() view_req_model = policy.model.inference_view_requirements - view_req_policy = policy.training_view_requirements - assert len(view_req_model) == 7 # obs, prev_a, prev_r, 4xstates - assert len(view_req_policy) == 16 + view_req_policy = policy.view_requirements + assert len(view_req_model) == 7, view_req_model + assert len(view_req_policy) == 17, view_req_policy for key in [ SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS, SampleBatch.DONES, SampleBatch.NEXT_OBS, @@ -90,7 +91,7 @@ class TestTrajectoryViewAPI(unittest.TestCase): assert view_req_policy[key].shift == 1 trainer.stop() - def test_traj_view_lstm_performance(self): + def test_traj_view_simple_performance(self): """Test whether PPOTrainer runs faster w/ 
`_use_trajectory_view_api`. """ config = copy.deepcopy(ppo.DEFAULT_CONFIG) @@ -102,17 +103,15 @@ class TestTrajectoryViewAPI(unittest.TestCase): from ray.tune import register_env register_env("ma_env", lambda c: RandomMultiAgentEnv({ "num_agents": 2, - "p_done": 0.01, + "p_done": 0.0, + "max_episode_len": 104, "action_space": action_space, "observation_space": obs_space })) config["num_workers"] = 3 config["num_envs_per_worker"] = 8 - config["num_sgd_iter"] = 6 - config["model"]["use_lstm"] = True - config["model"]["lstm_use_prev_action_reward"] = True - config["model"]["max_seq_len"] = 100 + config["num_sgd_iter"] = 1 # Put less weight on training. policies = { "pol0": (None, obs_space, action_space, {}), @@ -125,72 +124,80 @@ class TestTrajectoryViewAPI(unittest.TestCase): "policies": policies, "policy_mapping_fn": policy_fn, } - num_iterations = 1 + num_iterations = 2 # Only works in torch so far. for _ in framework_iterator(config, frameworks="torch"): - print("w/ traj. view API (and time-major)") + print("w/ traj. view API") config["_use_trajectory_view_api"] = True - config["model"]["_time_major"] = True trainer = ppo.PPOTrainer(config=config, env="ma_env") learn_time_w = 0.0 - sampler_perf = {} + sampler_perf_w = {} start = time.time() for i in range(num_iterations): out = trainer.train() + ts = out["timesteps_total"] sampler_perf_ = out["sampler_perf"] - sampler_perf = { - k: sampler_perf.get(k, 0.0) + sampler_perf_[k] + sampler_perf_w = { + k: + sampler_perf_w.get(k, 0.0) + (sampler_perf_[k] * 1000 / ts) for k, v in sampler_perf_.items() } - delta = out["timers"]["learn_time_ms"] / 1000 + delta = out["timers"]["learn_time_ms"] / ts learn_time_w += delta print("{}={}s".format(i, delta)) - sampler_perf = { - k: sampler_perf[k] / (num_iterations if "mean_" in k else 1) - for k, v in sampler_perf.items() + sampler_perf_w = { + k: sampler_perf_w[k] / (num_iterations if "mean_" in k else 1) + for k, v in sampler_perf_w.items() } duration_w = time.time() - start print("Duration: {}s " "sampler-perf.={} learn-time/iter={}s".format( - duration_w, sampler_perf, learn_time_w / num_iterations)) + duration_w, sampler_perf_w, + learn_time_w / num_iterations)) trainer.stop() - print("w/o traj. view API (and w/o time-major)") + print("w/o traj. view API") config["_use_trajectory_view_api"] = False - config["model"]["_time_major"] = False trainer = ppo.PPOTrainer(config=config, env="ma_env") learn_time_wo = 0.0 - sampler_perf = {} + sampler_perf_wo = {} start = time.time() for i in range(num_iterations): out = trainer.train() + ts = out["timesteps_total"] sampler_perf_ = out["sampler_perf"] - sampler_perf = { - k: sampler_perf.get(k, 0.0) + sampler_perf_[k] + sampler_perf_wo = { + k: sampler_perf_wo.get(k, 0.0) + + (sampler_perf_[k] * 1000 / ts) for k, v in sampler_perf_.items() } - delta = out["timers"]["learn_time_ms"] / 1000 + delta = out["timers"]["learn_time_ms"] / ts learn_time_wo += delta print("{}={}s".format(i, delta)) - sampler_perf = { - k: sampler_perf[k] / (num_iterations if "mean_" in k else 1) - for k, v in sampler_perf.items() + sampler_perf_wo = { + k: sampler_perf_wo[k] / (num_iterations if "mean_" in k else 1) + for k, v in sampler_perf_wo.items() } duration_wo = time.time() - start print("Duration: {}s " "sampler-perf.={} learn-time/iter={}s".format( - duration_wo, sampler_perf, + duration_wo, sampler_perf_wo, learn_time_wo / num_iterations)) trainer.stop() - # Assert `_use_trajectory_view_api` is much faster. + # Assert `_use_trajectory_view_api` is faster. 
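+ # (Both the per-1000-timesteps sampler stats and the overall
+ # wall-clock duration should come out lower with the new API.)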
+ self.assertLess(sampler_perf_w["mean_raw_obs_processing_ms"], + sampler_perf_wo["mean_raw_obs_processing_ms"]) + self.assertLess(sampler_perf_w["mean_action_processing_ms"], + sampler_perf_wo["mean_action_processing_ms"]) self.assertLess(duration_w, duration_wo) - self.assertLess(learn_time_w, learn_time_wo * 0.6) def test_traj_view_lstm_functionality(self): - action_space = Box(-float("inf"), float("inf"), shape=(2, )) + action_space = Box(-float("inf"), float("inf"), shape=(3, )) obs_space = Box(float("-inf"), float("inf"), (4, )) max_seq_len = 50 + rollout_fragment_length = 200 + assert rollout_fragment_length % max_seq_len == 0 policies = { "pol0": (EpisodeEnvAwarePolicy, obs_space, action_space, {}), } @@ -198,77 +205,162 @@ class TestTrajectoryViewAPI(unittest.TestCase): def policy_fn(agent_id): return "pol0" - rollout_worker = RolloutWorker( - env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}), - policy_config={ - "multiagent": { - "policies": policies, - "policy_mapping_fn": policy_fn, - }, - "_use_trajectory_view_api": True, - "model": { - "use_lstm": True, - "_time_major": True, - "max_seq_len": max_seq_len, - }, + config = { + "multiagent": { + "policies": policies, + "policy_mapping_fn": policy_fn, }, + "model": { + "use_lstm": True, + "max_seq_len": max_seq_len, + }, + }, + + rollout_worker_w_api = RolloutWorker( + env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}), + policy_config=dict(config, **{"_use_trajectory_view_api": True}), + rollout_fragment_length=rollout_fragment_length, policy=policies, policy_mapping_fn=policy_fn, num_envs=1, ) - for i in range(100): - pc = rollout_worker.sampler.sample_collector. \ - policy_sample_collectors["pol0"] - sample_batch_offset_before = pc.sample_batch_offset - buffers = pc.buffers - result = rollout_worker.sample() - pol_batch = result.policy_batches["pol0"] + rollout_worker_wo_api = RolloutWorker( + env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}), + policy_config=dict(config, **{"_use_trajectory_view_api": False}), + rollout_fragment_length=rollout_fragment_length, + policy=policies, + policy_mapping_fn=policy_fn, + num_envs=1, + ) + for iteration in range(20): + result = rollout_worker_w_api.sample() + check(result.count, rollout_fragment_length) + pol_batch_w = result.policy_batches["pol0"] + assert pol_batch_w.count >= rollout_fragment_length + analyze_rnn_batch(pol_batch_w, max_seq_len) - self.assertTrue(result.count == 100) - self.assertTrue(pol_batch.count >= 100) - self.assertFalse(0 in pol_batch.seq_lens) - # Check prev_reward/action, next_obs consistency. - for t in range(max_seq_len): - obs_t = pol_batch["obs"][t] - r_t = pol_batch["rewards"][t] - if t > 0: - next_obs_t_m_1 = pol_batch["new_obs"][t - 1] - self.assertTrue((obs_t == next_obs_t_m_1).all()) - if t < max_seq_len - 1: - prev_rewards_t_p_1 = pol_batch["prev_rewards"][t + 1] - self.assertTrue((r_t == prev_rewards_t_p_1).all()) + result = rollout_worker_wo_api.sample() + pol_batch_wo = result.policy_batches["pol0"] + check(pol_batch_w.data, pol_batch_wo.data) - # Check the sanity of all the buffers in the un underlying - # PerPolicy collector. - for sample_batch_slot, agent_slot in enumerate( - range(sample_batch_offset_before, pc.sample_batch_offset)): - t_buf = buffers["t"][:, agent_slot] - obs_buf = buffers["obs"][:, agent_slot] - # Skip empty seqs at end (these won't be part of the batch - # and have been copied to new agent-slots (even if seq-len=0)). 
- if sample_batch_slot < len(pol_batch.seq_lens): - seq_len = pol_batch.seq_lens[sample_batch_slot] - # Make sure timesteps are always increasing within the seq. - assert all(t_buf[1] + j == n + 1 - for j, n in enumerate(t_buf) - if j < seq_len and j != 0) - # Make sure all obs within seq are non-0.0. - assert all( - any(obs_buf[j] != 0.0) for j in range(1, seq_len + 1)) - # Check seq-lens. - for agent_slot, seq_len in enumerate(pol_batch.seq_lens): - if seq_len < max_seq_len - 1: - # At least in the beginning, the next slots should always - # be empty (once all agent slots have been used once, these - # may be filled with "old" values (from longer sequences)). - if i < 10: - self.assertTrue( - (pol_batch["obs"][seq_len + - 1][agent_slot] == 0.0).all()) - print(end="") - self.assertFalse( - (pol_batch["obs"][seq_len][agent_slot] == 0.0).all()) +def analyze_rnn_batch(batch, max_seq_len): + count = batch.count + + # Check prev_reward/action, next_obs consistency. + for idx in range(count): + # If timestep tracked by batch, good. + if "t" in batch: + ts = batch["t"][idx] + # Else, ts + else: + ts = batch["obs"][idx][3] + obs_t = batch["obs"][idx] + a_t = batch["actions"][idx] + r_t = batch["rewards"][idx] + state_in_0 = batch["state_in_0"][idx] + state_in_1 = batch["state_in_1"][idx] + + # Check postprocessing outputs. + if "postprocessed_column" in batch: + postprocessed_col_t = batch["postprocessed_column"][idx] + assert (obs_t == postprocessed_col_t / 2.0).all() + + # Check state-in/out and next-obs values. + if idx > 0: + next_obs_t_m_1 = batch["new_obs"][idx - 1] + state_out_0_t_m_1 = batch["state_out_0"][idx - 1] + state_out_1_t_m_1 = batch["state_out_1"][idx - 1] + # Same trajectory as for t-1 -> Should be able to match. + if (batch[SampleBatch.AGENT_INDEX][idx] == + batch[SampleBatch.AGENT_INDEX][idx - 1] + and batch[SampleBatch.EPS_ID][idx] == + batch[SampleBatch.EPS_ID][idx - 1]): + assert batch["unroll_id"][idx - 1] == batch["unroll_id"][idx] + assert (obs_t == next_obs_t_m_1).all() + assert (state_in_0 == state_out_0_t_m_1).all() + assert (state_in_1 == state_out_1_t_m_1).all() + # Different trajectory. + else: + assert batch["unroll_id"][idx - 1] != batch["unroll_id"][idx] + assert not (obs_t == next_obs_t_m_1).all() + assert not (state_in_0 == state_out_0_t_m_1).all() + assert not (state_in_1 == state_out_1_t_m_1).all() + # Check initial 0-internal states. + if ts == 0: + assert (state_in_0 == 0.0).all() + assert (state_in_1 == 0.0).all() + + # Check initial 0-internal states (at ts=0). + if ts == 0: + assert (state_in_0 == 0.0).all() + assert (state_in_1 == 0.0).all() + + # Check prev. a/r values. + if idx < count - 1: + prev_actions_t_p_1 = batch["prev_actions"][idx + 1] + prev_rewards_t_p_1 = batch["prev_rewards"][idx + 1] + # Same trajectory as for t+1 -> Should be able to match. + if batch[SampleBatch.AGENT_INDEX][idx] == \ + batch[SampleBatch.AGENT_INDEX][idx + 1] and \ + batch[SampleBatch.EPS_ID][idx] == \ + batch[SampleBatch.EPS_ID][idx + 1]: + assert (a_t == prev_actions_t_p_1).all() + assert r_t == prev_rewards_t_p_1 + # Different (new) trajectory. Assume t-1 (prev-a/r) to be + # always 0.0s. [3]=ts + elif ts == 0: + assert (prev_actions_t_p_1 == 0).all() + assert prev_rewards_t_p_1 == 0.0 + + pad_batch_to_sequences_of_same_size( + batch, + max_seq_len=max_seq_len, + shuffle=False, + batch_divisibility_req=1) + + # Check after seq-len 0-padding. 
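+ # (Walk the padded batch in blocks of `max_seq_len`: the first `seq_len`
+ # rows of each block must remain consistent, while the padded remainder
+ # must be all zeros.)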
+ cursor = 0 + for i, seq_len in enumerate(batch["seq_lens"]): + state_in_0 = batch["state_in_0"][i] + state_in_1 = batch["state_in_1"][i] + for j in range(seq_len): + k = cursor + j + ts = batch["t"][k] + obs_t = batch["obs"][k] + a_t = batch["actions"][k] + r_t = batch["rewards"][k] + + # Check postprocessing outputs. + if "postprocessed_column" in batch: + postprocessed_col_t = batch["postprocessed_column"][k] + assert (obs_t == postprocessed_col_t / 2.0).all() + + # Check state-in/out and next-obs values. + if j > 0: + next_obs_t_m_1 = batch["new_obs"][k - 1] + # state_out_0_t_m_1 = batch["state_out_0"][k - 1] + # state_out_1_t_m_1 = batch["state_out_1"][k - 1] + # Always same trajectory as for t-1. + assert batch["unroll_id"][k - 1] == batch["unroll_id"][k] + assert (obs_t == next_obs_t_m_1).all() + # assert (state_in_0 == state_out_0_t_m_1).all()) + # assert (state_in_1 == state_out_1_t_m_1).all()) + # Check initial 0-internal states. + elif ts == 0: + assert (state_in_0 == 0.0).all() + assert (state_in_1 == 0.0).all() + + for j in range(seq_len, max_seq_len): + k = cursor + j + obs_t = batch["obs"][k] + a_t = batch["actions"][k] + r_t = batch["rewards"][k] + assert (obs_t == 0.0).all() + assert (a_t == 0.0).all() + assert (r_t == 0.0).all() + + cursor += max_seq_len if __name__ == "__main__": diff --git a/rllib/examples/batch_norm_model.py b/rllib/examples/batch_norm_model.py index 5159a166f..adf6f147d 100644 --- a/rllib/examples/batch_norm_model.py +++ b/rllib/examples/batch_norm_model.py @@ -22,7 +22,7 @@ parser.add_argument("--stop-reward", type=float, default=150) if __name__ == "__main__": args = parser.parse_args() - ray.init(local_mode=True) + ray.init() ModelCatalog.register_custom_model( "bn_model", TorchBatchNormModel if args.torch else BatchNormModel) diff --git a/rllib/examples/cartpole_lstm.py b/rllib/examples/cartpole_lstm.py index 2ff611eb5..c290c29e9 100644 --- a/rllib/examples/cartpole_lstm.py +++ b/rllib/examples/cartpole_lstm.py @@ -50,7 +50,7 @@ if __name__ == "__main__": "episode_reward_mean": args.stop_reward, } - results = tune.run(args.run, config=config, stop=stop) + results = tune.run(args.run, config=config, stop=stop, verbose=1) if args.as_test: check_learning_achieved(results, args.stop_reward) diff --git a/rllib/examples/centralized_critic_2.py b/rllib/examples/centralized_critic_2.py index 86337df9e..2644d245a 100644 --- a/rllib/examples/centralized_critic_2.py +++ b/rllib/examples/centralized_critic_2.py @@ -108,7 +108,7 @@ if __name__ == "__main__": "episode_reward_mean": args.stop_reward, } - results = tune.run("PPO", config=config, stop=stop) + results = tune.run("PPO", config=config, stop=stop, verbose=1) if args.as_test: check_learning_achieved(results, args.stop_reward) diff --git a/rllib/examples/complex_struct_space.py b/rllib/examples/complex_struct_space.py index a7563d824..927c951ec 100644 --- a/rllib/examples/complex_struct_space.py +++ b/rllib/examples/complex_struct_space.py @@ -9,7 +9,6 @@ For PyTorch / TF eager mode, use the --torch and --eager flags. 
import argparse -import ray from ray import tune from ray.rllib.models import ModelCatalog from ray.rllib.examples.env.simple_rpg import SimpleRPG @@ -21,25 +20,25 @@ parser.add_argument( "--framework", choices=["tf", "tfe", "torch"], default="tf") if __name__ == "__main__": - ray.init(local_mode=True) args = parser.parse_args() if args.framework == "torch": ModelCatalog.register_custom_model("my_model", CustomTorchRPGModel) else: ModelCatalog.register_custom_model("my_model", CustomTFRPGModel) - tune.run( - "PG", - stop={ - "timesteps_total": 1, + + config = { + "framework": args.framework, + "env": SimpleRPG, + "rollout_fragment_length": 1, + "train_batch_size": 2, + "num_workers": 0, + "model": { + "custom_model": "my_model", }, - config={ - "framework": args.framework, - "env": SimpleRPG, - "rollout_fragment_length": 1, - "train_batch_size": 2, - "num_workers": 0, - "model": { - "custom_model": "my_model", - }, - }, - ) + } + + stop = { + "timesteps_total": 1, + } + + tune.run("PG", config=config, stop=stop, verbose=1) diff --git a/rllib/examples/custom_rnn_model.py b/rllib/examples/custom_rnn_model.py index a414a2486..cf87c5c22 100644 --- a/rllib/examples/custom_rnn_model.py +++ b/rllib/examples/custom_rnn_model.py @@ -55,7 +55,7 @@ if __name__ == "__main__": "episode_reward_mean": args.stop_reward, } - results = tune.run(args.run, config=config, stop=stop) + results = tune.run(args.run, config=config, stop=stop, verbose=1) if args.as_test: check_learning_achieved(results, args.stop_reward) diff --git a/rllib/examples/env/debug_counter_env.py b/rllib/examples/env/debug_counter_env.py index 0d2adcba8..c14d49951 100644 --- a/rllib/examples/env/debug_counter_env.py +++ b/rllib/examples/env/debug_counter_env.py @@ -29,7 +29,7 @@ class DebugCounterEnv(gym.Env): class MultiAgentDebugCounterEnv(MultiAgentEnv): def __init__(self, config): self.num_agents = config["num_agents"] - self.p_done = config.get("p_done", 0.02) + self.base_episode_len = config.get("base_episode_len", 103) # Actions are always: # (episodeID, envID) as floats. 
self.action_space = \ @@ -45,6 +45,7 @@ class MultiAgentDebugCounterEnv(MultiAgentEnv): self.dones = set() def reset(self): + self.timesteps = [0] * self.num_agents self.dones = set() return { i: np.array([i, 0.0, 0.0, 0.0], dtype=np.float32) @@ -57,9 +58,8 @@ class MultiAgentDebugCounterEnv(MultiAgentEnv): self.timesteps[i] += 1 obs[i] = np.array([i, action[0], action[1], self.timesteps[i]]) rew[i] = self.timesteps[i] % 3 - done[i] = bool( - np.random.choice( - [True, False], p=[self.p_done, 1.0 - self.p_done])) + done[i] = True if self.timesteps[i] > self.base_episode_len + i \ + else False if done[i]: self.dones.add(i) done["__all__"] = len(self.dones) == self.num_agents diff --git a/rllib/examples/hierarchical_training.py b/rllib/examples/hierarchical_training.py index f6fddeabc..41cad795b 100644 --- a/rllib/examples/hierarchical_training.py +++ b/rllib/examples/hierarchical_training.py @@ -96,7 +96,7 @@ if __name__ == "__main__": "framework": "torch" if args.torch else "tf", } - results = tune.run("PPO", stop=stop, config=config) + results = tune.run("PPO", stop=stop, config=config, verbose=1) if args.as_test: check_learning_achieved(results, args.stop_reward) diff --git a/rllib/examples/models/autoregressive_action_dist.py b/rllib/examples/models/autoregressive_action_dist.py index 929a7d782..37caa68d6 100644 --- a/rllib/examples/models/autoregressive_action_dist.py +++ b/rllib/examples/models/autoregressive_action_dist.py @@ -73,7 +73,7 @@ class BinaryAutoregressiveDistribution(ActionDistribution): return a2_dist @staticmethod - def required_model_output_shape(self, model_config): + def required_model_output_shape(action_space, model_config): return 16 # controls model output feature vector size @@ -143,5 +143,5 @@ class TorchBinaryAutoregressiveDistribution(TorchDistributionWrapper): return a2_dist @staticmethod - def required_model_output_shape(self, model_config): + def required_model_output_shape(action_space, model_config): return 16 # controls model output feature vector size diff --git a/rllib/examples/models/batch_norm_model.py b/rllib/examples/models/batch_norm_model.py index 5091415ec..7d77ebc07 100644 --- a/rllib/examples/models/batch_norm_model.py +++ b/rllib/examples/models/batch_norm_model.py @@ -180,7 +180,7 @@ class TorchBatchNormModel(TorchModelV2, nn.Module): def forward(self, input_dict, state, seq_lens): # Set the correct train-mode for our hidden module (only important # b/c we have some batch-norm layers). - self._hidden_layers.train(mode=input_dict["is_training"]) + self._hidden_layers.train(mode=input_dict.get("is_training", False)) self._hidden_out = self._hidden_layers(input_dict["obs"]) logits = self._logits(self._hidden_out) return logits, [] diff --git a/rllib/examples/models/rnn_model.py b/rllib/examples/models/rnn_model.py index 4b3d3db9e..5a9dab8ce 100644 --- a/rllib/examples/models/rnn_model.py +++ b/rllib/examples/models/rnn_model.py @@ -1,9 +1,11 @@ +from gym.spaces import Box import numpy as np from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.preprocessors import get_preprocessor from ray.rllib.models.tf.recurrent_net import RecurrentNetwork from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN +from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf, try_import_torch @@ -101,8 +103,18 @@ class TorchRNNModel(TorchRNN, nn.Module): # Holds the current "base" output (before logits layer). 
self._features = None + # Add state-ins to this model's view. + for i in range(2): + self.inference_view_requirements["state_in_{}".format(i)] = \ + ViewRequirement( + "state_out_{}".format(i), + shift=-1, + space=Box(-1.0, 1.0, shape=(self.lstm_state_size,))) + @override(ModelV2) def get_initial_state(self): + # TODO: (sven): Get rid of `get_initial_state` once Trajectory + # View API is supported across all of RLlib. # Place hidden states on same device as model. h = [ self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0), diff --git a/rllib/examples/multi_agent_custom_policy.py b/rllib/examples/multi_agent_custom_policy.py index d0f540549..4a6164848 100644 --- a/rllib/examples/multi_agent_custom_policy.py +++ b/rllib/examples/multi_agent_custom_policy.py @@ -47,24 +47,22 @@ if __name__ == "__main__": "timesteps_total": args.stop_timesteps, } - results = tune.run( - "PG", - stop=stop, - config={ - "env": "multi_agent_cartpole", - "multiagent": { - "policies": { - "pg_policy": (None, obs_space, act_space, { - "framework": "torch" if args.torch else "tf", - }), - "random": (RandomPolicy, obs_space, act_space, {}), - }, - "policy_mapping_fn": ( - lambda agent_id: ["pg_policy", "random"][agent_id % 2]), + config = { + "env": "multi_agent_cartpole", + "multiagent": { + "policies": { + "pg_policy": (None, obs_space, act_space, { + "framework": "torch" if args.torch else "tf", + }), + "random": (RandomPolicy, obs_space, act_space, {}), }, - "framework": "torch" if args.torch else "tf", + "policy_mapping_fn": ( + lambda agent_id: ["pg_policy", "random"][agent_id % 2]), }, - ) + "framework": "torch" if args.torch else "tf", + } + + results = tune.run("PG", config=config, stop=stop, verbose=1) if args.as_test: check_learning_achieved(results, args.stop_reward) diff --git a/rllib/examples/nested_action_spaces.py b/rllib/examples/nested_action_spaces.py index 63be59e25..a5d02130e 100644 --- a/rllib/examples/nested_action_spaces.py +++ b/rllib/examples/nested_action_spaces.py @@ -52,7 +52,7 @@ if __name__ == "__main__": "timesteps_total": args.stop_timesteps, } - results = tune.run(args.run, config=config, stop=stop) + results = tune.run(args.run, config=config, stop=stop, verbose=1) if args.as_test: check_learning_achieved(results, args.stop_reward) diff --git a/rllib/examples/policy/episode_env_aware_policy.py b/rllib/examples/policy/episode_env_aware_policy.py index 59018b856..44605cbd8 100644 --- a/rllib/examples/policy/episode_env_aware_policy.py +++ b/rllib/examples/policy/episode_env_aware_policy.py @@ -1,3 +1,4 @@ +from gym.spaces import Box import numpy as np from ray.rllib.examples.policy.random_policy import RandomPolicy @@ -13,8 +14,7 @@ class EpisodeEnvAwarePolicy(RandomPolicy): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.episode_id = None - self.env_id = None + self.state_space = Box(-1.0, 1.0, (1, )) class _fake_model: pass @@ -22,15 +22,25 @@ class EpisodeEnvAwarePolicy(RandomPolicy): self.model = _fake_model() self.model.time_major = True self.model.inference_view_requirements = { + SampleBatch.AGENT_INDEX: ViewRequirement(), SampleBatch.EPS_ID: ViewRequirement(), "env_id": ViewRequirement(), + "t": ViewRequirement(), SampleBatch.OBS: ViewRequirement(), SampleBatch.PREV_ACTIONS: ViewRequirement( SampleBatch.ACTIONS, space=self.action_space, shift=-1), SampleBatch.PREV_REWARDS: ViewRequirement( SampleBatch.REWARDS, shift=-1), } - self.training_view_requirements = dict( + for i in range(2): + 
self.model.inference_view_requirements["state_in_{}".format(i)] = \ + ViewRequirement( + "state_out_{}".format(i), shift=-1, space=self.state_space) + self.model.inference_view_requirements[ + "state_out_{}".format(i)] = \ + ViewRequirement(space=self.state_space) + + self.view_requirements = dict( **{ SampleBatch.NEXT_OBS: ViewRequirement( SampleBatch.OBS, shift=1), @@ -50,17 +60,23 @@ class EpisodeEnvAwarePolicy(RandomPolicy): explore=None, timestep=None, **kwargs): - self.episode_id = input_dict[SampleBatch.EPS_ID][0] - self.env_id = input_dict["env_id"][0] - # Always return (episodeID, envID) - return [ - np.array([self.episode_id, self.env_id]) for _ in input_dict["obs"] - ], [], {} + ts = input_dict["t"] + print(ts) + # Always return [episodeID, envID] as actions. + actions = np.array([[ + input_dict[SampleBatch.AGENT_INDEX][i], + input_dict[SampleBatch.EPS_ID][i], input_dict["env_id"][i] + ] for i, _ in enumerate(input_dict["obs"])]) + states = [ + np.array([[ts[i]] for i in range(len(input_dict["obs"]))]) + for _ in range(2) + ] + return actions, states, {} @override(Policy) def postprocess_trajectory(self, sample_batch, other_agent_batches=None, episode=None): - sample_batch["postprocessed_column"] = sample_batch["obs"] + 1.0 + sample_batch["postprocessed_column"] = sample_batch["obs"] * 2.0 return sample_batch diff --git a/rllib/examples/policy/rock_paper_scissors_dummies.py b/rllib/examples/policy/rock_paper_scissors_dummies.py index 2eb051b09..95e3f62a1 100644 --- a/rllib/examples/policy/rock_paper_scissors_dummies.py +++ b/rllib/examples/policy/rock_paper_scissors_dummies.py @@ -1,7 +1,10 @@ +import gym +import numpy as np import random from ray.rllib.examples.env.rock_paper_scissors import RockPaperScissors from ray.rllib.policy.policy import Policy +from ray.rllib.policy.view_requirement import ViewRequirement class AlwaysSameHeuristic(Policy): @@ -10,6 +13,12 @@ class AlwaysSameHeuristic(Policy): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.exploration = self._create_exploration() + self.view_requirements.update({ + "state_in_0": ViewRequirement( + "state_out_0", + shift=-1, + space=gym.spaces.Box(0, 100, shape=(), dtype=np.int32)) + }) def get_initial_state(self): return [ @@ -27,7 +36,10 @@ class AlwaysSameHeuristic(Policy): info_batch=None, episodes=None, **kwargs): - return state_batches[0], state_batches, {} + if self.config["_use_trajectory_view_api"]: + return state_batches[0][0], [s[0] for s in state_batches], {} + else: + return state_batches[0], state_batches, {} class BeatLastHeuristic(Policy): diff --git a/rllib/examples/rock_paper_scissors_multiagent.py b/rllib/examples/rock_paper_scissors_multiagent.py index f9a22a596..13144681b 100644 --- a/rllib/examples/rock_paper_scissors_multiagent.py +++ b/rllib/examples/rock_paper_scissors_multiagent.py @@ -38,7 +38,7 @@ def run_same_policy(args, stop): "framework": "torch" if args.torch else "tf", } - results = tune.run("PG", config=config, stop=stop) + results = tune.run("PG", config=config, stop=stop, verbose=1) if args.as_test: # Check vs 0.0 as we are playing a zero-sum game. diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py index b87904117..6eed95f4f 100644 --- a/rllib/models/modelv2.py +++ b/rllib/models/modelv2.py @@ -63,6 +63,8 @@ class ModelV2: SampleBatch.OBS: ViewRequirement(shift=0), } + # TODO: (sven): Get rid of `get_initial_state` once Trajectory + # View API is supported across all of RLlib. 
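+ # (With the Trajectory View API, initial internal states can instead be
+ # declared as "state_in_*" ViewRequirements with shift=-1, as done for
+ # the LSTM models elsewhere in this patch.)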
@PublicAPI def get_initial_state(self) -> List[np.ndarray]: """Get the initial recurrent state values for the model. diff --git a/rllib/models/torch/recurrent_net.py b/rllib/models/torch/recurrent_net.py index 3e37d60bf..f1c962c4b 100644 --- a/rllib/models/torch/recurrent_net.py +++ b/rllib/models/torch/recurrent_net.py @@ -135,25 +135,20 @@ class LSTMWrapper(RecurrentNetwork, nn.Module): activation_fn=None, initializer=torch.nn.init.xavier_uniform_) - self.inference_view_requirements.update( - dict( - **{ - SampleBatch.OBS: ViewRequirement(shift=0), - SampleBatch.PREV_REWARDS: ViewRequirement( - SampleBatch.REWARDS, shift=-1), - SampleBatch.PREV_ACTIONS: ViewRequirement( - SampleBatch.ACTIONS, space=self.action_space, - shift=-1), - })) + # Add prev-a/r to this model's view, if required. + if model_config["lstm_use_prev_action_reward"]: + self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \ + ViewRequirement(SampleBatch.REWARDS, shift=-1) + self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \ + ViewRequirement(SampleBatch.ACTIONS, space=self.action_space, + shift=-1) + # Add state-ins to this model's view. for i in range(2): self.inference_view_requirements["state_in_{}".format(i)] = \ ViewRequirement( "state_out_{}".format(i), shift=-1, space=Box(-1.0, 1.0, shape=(self.cell_size,))) - self.inference_view_requirements["state_out_{}".format(i)] = \ - ViewRequirement( - space=Box(-1.0, 1.0, shape=(self.cell_size,))) @override(RecurrentNetwork) def forward(self, input_dict, state, seq_lens): diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index 0414216f7..2d9b5d4da 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -5,6 +5,7 @@ import tree from typing import Dict, List, Optional from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.exploration.exploration import Exploration from ray.rllib.utils.framework import try_import_torch @@ -74,7 +75,14 @@ class Policy(metaclass=ABCMeta): # Child classes need to add their specific requirements here (usually # a combination of a Model's inference_view_- and the # Policy's loss function-requirements. - self.training_view_requirements = {} + self.view_requirements = { + SampleBatch.OBS: ViewRequirement(), + SampleBatch.ACTIONS: ViewRequirement(space=self.action_space), + SampleBatch.REWARDS: ViewRequirement(), + SampleBatch.DONES: ViewRequirement(), + SampleBatch.EPS_ID: ViewRequirement(), + SampleBatch.AGENT_INDEX: ViewRequirement(), + } @abstractmethod @DeveloperAPI @@ -255,7 +263,23 @@ class Policy(metaclass=ABCMeta): shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}. """ - raise NotImplementedError + # Default implementation just passes obs, prev-a/r, and states on to + # `self.compute_actions()`. 
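+ # Pull any "state_in_*" columns out of the input_dict and expand each
+ # along a new leading axis before forwarding them to `compute_actions()`.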
diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py
index 0414216f7..2d9b5d4da 100644
--- a/rllib/policy/policy.py
+++ b/rllib/policy/policy.py
@@ -5,6 +5,7 @@ import tree
 from typing import Dict, List, Optional
 
 from ray.rllib.policy.sample_batch import SampleBatch
+from ray.rllib.policy.view_requirement import ViewRequirement
 from ray.rllib.utils.annotations import DeveloperAPI
 from ray.rllib.utils.exploration.exploration import Exploration
 from ray.rllib.utils.framework import try_import_torch
@@ -74,7 +75,14 @@ class Policy(metaclass=ABCMeta):
         # Child classes need to add their specific requirements here (usually
         # a combination of a Model's inference_view_- and the
         # Policy's loss function-requirements.
-        self.training_view_requirements = {}
+        self.view_requirements = {
+            SampleBatch.OBS: ViewRequirement(),
+            SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
+            SampleBatch.REWARDS: ViewRequirement(),
+            SampleBatch.DONES: ViewRequirement(),
+            SampleBatch.EPS_ID: ViewRequirement(),
+            SampleBatch.AGENT_INDEX: ViewRequirement(),
+        }
 
     @abstractmethod
     @DeveloperAPI
@@ -255,7 +263,23 @@ class Policy(metaclass=ABCMeta):
                 shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
         """
-        raise NotImplementedError
+        # Default implementation just passes obs, prev-a/r, and states on to
+        # `self.compute_actions()`.
+        state_batches = [
+            s.unsqueeze(0)
+            if torch and isinstance(s, torch.Tensor) else np.expand_dims(s, 0)
+            for k, s in input_dict.items() if k[:9] == "state_in_"
+        ]
+        return self.compute_actions(
+            input_dict[SampleBatch.OBS],
+            state_batches,
+            prev_action_batch=input_dict.get(SampleBatch.PREV_ACTIONS),
+            prev_reward_batch=input_dict.get(SampleBatch.PREV_REWARDS),
+            info_batch=None,
+            explore=explore,
+            timestep=timestep,
+            **kwargs,
+        )
 
     @DeveloperAPI
     def compute_log_likelihoods(
diff --git a/rllib/policy/rnn_sequencing.py b/rllib/policy/rnn_sequencing.py
index c0b999974..bcf3664e6 100644
--- a/rllib/policy/rnn_sequencing.py
+++ b/rllib/policy/rnn_sequencing.py
@@ -35,7 +35,6 @@ def pad_batch_to_sequences_of_same_size(
         shuffle: bool = False,
         batch_divisibility_req: int = 1,
         feature_keys: Optional[List[str]] = None,
-        _use_trajectory_view_api: bool = False,
 ):
     """Applies padding to `batch` so it's choppable into same-size sequences.
 
@@ -56,26 +55,7 @@ def pad_batch_to_sequences_of_same_size(
         feature_keys (Optional[List[str]]): An optional list of keys to
             apply sequence-chopping to. If None, use all keys in batch that
             are not "state_in/out_"-type keys.
-        _use_trajectory_view_api (bool): Whether we are using the Trajectory
-            View API to collect and process samples.
     """
-    if _use_trajectory_view_api:
-        if batch.time_major is not None:
-            batch["seq_lens"] = torch.tensor(batch.seq_lens)
-            t = 0 if batch.time_major else 1
-            for col in batch.data.keys():
-                # Cut time-dim from states.
-                if "state_" in col[:6]:
-                    batch[col] = batch[col][t]
-                # Flatten all other data.
-                else:
-                    # Cut time-dim at `max_seq_len`.
-                    if batch.time_major:
-                        batch[col] = batch[col][:batch.max_seq_len]
-                    batch[col] = batch[col].reshape((-1, ) +
-                                                    batch[col].shape[2:])
-        return
-
     if batch_divisibility_req > 1:
         meets_divisibility_reqs = (
             len(batch[SampleBatch.CUR_OBS]) % batch_divisibility_req == 0
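To make the new Policy.compute_actions_from_input_dict() default concrete: the dict it expects is keyed by SampleBatch column names plus any "state_in_..." entries. The values below are hypothetical and only sketch the expected layout:

import numpy as np

from ray.rllib.policy.sample_batch import SampleBatch

input_dict = {
    SampleBatch.OBS: np.array([[0.1, 0.2, 0.3, 0.4]]),
    SampleBatch.PREV_ACTIONS: np.array([0]),
    SampleBatch.PREV_REWARDS: np.array([0.0]),
    # One entry per RNN state slot; keys must start with "state_in_".
    "state_in_0": np.zeros(4, dtype=np.float32),
}
# policy.compute_actions_from_input_dict(input_dict, explore=True) adds a
# leading dim to each state entry and forwards everything on to
# policy.compute_actions().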
diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py
index c7dc5861d..69d7ce880 100644
--- a/rllib/policy/torch_policy.py
+++ b/rllib/policy/torch_policy.py
@@ -11,7 +11,6 @@ from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
 from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
-from ray.rllib.policy.view_requirement import ViewRequirement
 from ray.rllib.utils import force_list
 from ray.rllib.utils.annotations import override, DeveloperAPI
 from ray.rllib.utils.framework import try_import_torch
@@ -31,8 +30,6 @@ logger = logging.getLogger(__name__)
 class TorchPolicy(Policy):
     """Template for a PyTorch policy and loss to use with RLlib.
 
-    This is similar to TFPolicy, but for PyTorch.
-
     Attributes:
         observation_space (gym.Space): observation space of the policy.
         action_space (gym.Space): action space of the policy.
@@ -114,14 +111,7 @@ class TorchPolicy(Policy):
             self.device = torch.device("cpu")
         self.model = model.to(self.device)
         # Combine view_requirements for Model and Policy.
-        self.training_view_requirements = dict(
-            **{
-                SampleBatch.ACTIONS: ViewRequirement(
-                    space=self.action_space, shift=0),
-                SampleBatch.REWARDS: ViewRequirement(shift=0),
-                SampleBatch.DONES: ViewRequirement(shift=0),
-            },
-            **self.model.inference_view_requirements)
+        self.view_requirements.update(self.model.inference_view_requirements)
 
         self.exploration = self._create_exploration()
         self.unwrapped_model = model  # used to support DistributedDataParallel
@@ -202,9 +192,11 @@ class TorchPolicy(Policy):
         with torch.no_grad():
             # Pass lazy (torch) tensor dict to Model as `input_dict`.
             input_dict = self._lazy_tensor_dict(input_dict)
+            # Pack internal state inputs into (separate) list.
             state_batches = [
-                input_dict[k] for k in input_dict.keys() if "state_" in k[:6]
+                input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
             ]
+            # Calculate RNN sequence lengths.
             seq_lens = np.array([1] * len(input_dict["obs"])) \
                 if state_batches else None
 
@@ -217,7 +209,8 @@ class TorchPolicy(Policy):
                 extra_fetches[SampleBatch.ACTION_PROB] = torch.exp(logp)
                 extra_fetches[SampleBatch.ACTION_LOGP] = logp
 
-            return actions, state_out, extra_fetches
+            return convert_to_non_torch_type((actions, state_out,
+                                              extra_fetches))
 
     def _compute_action_helper(self, input_dict, state_batches, seq_lens,
                                explore, timestep):
@@ -342,7 +335,6 @@ class TorchPolicy(Policy):
             max_seq_len=self.max_seq_len,
             shuffle=False,
             batch_divisibility_req=self.batch_divisibility_req,
-            _use_trajectory_view_api=self.config["_use_trajectory_view_api"],
         )
 
         train_batch = self._lazy_tensor_dict(postprocessed_batch)
@@ -479,7 +471,7 @@ class TorchPolicy(Policy):
     @DeveloperAPI
     def get_initial_state(self) -> List[TensorType]:
         return [
-            s.cpu().detach().numpy() for s in self.model.get_initial_state()
+            s.detach().cpu().numpy() for s in self.model.get_initial_state()
         ]
 
     @override(Policy)
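The changed return statement above relies on convert_to_non_torch_type() (whose own fix appears at the end of this diff). A quick sketch of the behavior being depended on, assuming torch is installed:

import torch

from ray.rllib.utils.torch_ops import convert_to_non_torch_type

actions = torch.tensor([[1, 2], [3, 4]])
extra_fetches = {"logp": torch.tensor(-0.7)}
out = convert_to_non_torch_type((actions, [], extra_fetches))
# The nested structure is preserved: out[0] is now a numpy array and
# out[2]["logp"] a plain Python float (0-d tensors become scalars).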
diff --git a/rllib/policy/torch_policy_template.py b/rllib/policy/torch_policy_template.py
index 3e4426106..308f63e54 100644
--- a/rllib/policy/torch_policy_template.py
+++ b/rllib/policy/torch_policy_template.py
@@ -64,7 +64,7 @@ def build_torch_policy(
         apply_gradients_fn: Optional[Callable[
             [Policy, "torch.optim.Optimizer"], None]] = None,
         mixins: Optional[List[type]] = None,
-        training_view_requirements_fn: Optional[Callable[[], Dict[
+        view_requirements_fn: Optional[Callable[[], Dict[
             str, ViewRequirement]]] = None,
         get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None):
     """Helper function for creating a torch policy class at runtime.
@@ -159,7 +159,7 @@ def build_torch_policy(
         mixins (Optional[List[type]]): Optional list of any class mixins for
             the returned policy class. These mixins will be applied in order
             and will have higher precedence than the TorchPolicy class.
-        training_view_requirements_fn (Callable[[],
+        view_requirements_fn (Callable[[],
             Dict[str, ViewRequirement]]): An optional callable to retrieve
             additional train view requirements for this policy.
         get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
@@ -226,9 +226,8 @@ def build_torch_policy(
             get_batch_divisibility_req=get_batch_divisibility_req,
         )
 
-        if callable(training_view_requirements_fn):
-            self.training_view_requirements.update(
-                training_view_requirements_fn(self))
+        if callable(view_requirements_fn):
+            self.view_requirements.update(view_requirements_fn(self))
 
         if after_init:
             after_init(self, obs_space, action_space, config)
diff --git a/rllib/policy/view_requirement.py b/rllib/policy/view_requirement.py
index df79b6340..3264b759b 100644
--- a/rllib/policy/view_requirement.py
+++ b/rllib/policy/view_requirement.py
@@ -29,7 +29,8 @@ class ViewRequirement:
     def __init__(self,
                  data_col: Optional[str] = None,
                  space: gym.Space = None,
-                 shift: Union[int, List[int]] = 0):
+                 shift: Union[int, List[int]] = 0,
+                 used_for_training: bool = True):
         """Initializes a ViewRequirement object.
 
         Args:
@@ -46,8 +47,12 @@ class ViewRequirement:
                 Example: For a view column "obs" in an Atari framestacking
                 fashion, you can set `data_col="obs"` and
                 `shift=[-3, -2, -1, 0]`.
+            used_for_training (bool): Whether the data will be used for
+                training. If False, the column will not be copied into the
+                final train batch.
         """
         self.data_col = data_col
         self.space = space or gym.spaces.Box(
             float("-inf"), float("inf"), shape=())
         self.shift = shift
+        self.used_for_training = used_for_training
diff --git a/rllib/utils/spaces/repeated.py b/rllib/utils/spaces/repeated.py
index 4ba367ef3..cef0cad0b 100644
--- a/rllib/utils/spaces/repeated.py
+++ b/rllib/utils/spaces/repeated.py
@@ -1,4 +1,3 @@
-import numpy as np
 import gym
 
 from ray.rllib.utils.annotations import PublicAPI
@@ -17,14 +16,9 @@ class Repeated(gym.Space):
     """
 
     def __init__(self, child_space: gym.Space, max_len: int):
-        self.np_random = np.random.RandomState()
+        super().__init__()
         self.child_space = child_space
         self.max_len = max_len
-        super().__init__()
-
-    def seed(self, seed=None):
-        self.np_random = np.random.RandomState()
-        self.np_random.seed(seed)
 
     def sample(self):
         return [
diff --git a/rllib/utils/torch_ops.py b/rllib/utils/torch_ops.py
index aa5c30516..5d513f743 100644
--- a/rllib/utils/torch_ops.py
+++ b/rllib/utils/torch_ops.py
@@ -34,7 +34,7 @@ def convert_to_non_torch_type(stats):
 
     def mapping(item):
         if isinstance(item, torch.Tensor):
             return item.cpu().item() if len(item.size()) == 0 else \
-                item.cpu().detach().numpy()
+                item.detach().cpu().numpy()
         else:
             return item
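Putting the renamed argument and the new used_for_training flag together, a custom policy built via build_torch_policy() could look like the following. This is a sketch only; the policy name and the extra view requirement are hypothetical, while the loss and default config are borrowed from RLlib's PG policy to keep it short:

from typing import Dict

from ray.rllib.agents.pg.pg import DEFAULT_CONFIG
from ray.rllib.agents.pg.pg_torch_policy import pg_torch_loss
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.policy.view_requirement import ViewRequirement


def extra_view_requirements(policy: Policy) -> Dict[str, ViewRequirement]:
    # NEXT_OBS is only needed during postprocessing, so keep it out of the
    # final train batch via used_for_training=False.
    return {
        SampleBatch.NEXT_OBS: ViewRequirement(
            SampleBatch.OBS, shift=1, used_for_training=False),
    }


MyPGTorchPolicy = build_torch_policy(
    name="MyPGTorchPolicy",
    get_default_config=lambda: DEFAULT_CONFIG,
    loss_fn=pg_torch_loss,
    view_requirements_fn=extra_view_requirements,
)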