[RLlib; Docs overhaul] Docstring cleanup: Evaluation (#19783)

Sven Mika, 2021-10-29 12:03:56 +02:00 (committed by GitHub)
parent f2773267c7
commit 9c73871da0
30 changed files with 1059 additions and 705 deletions


@@ -5,20 +5,20 @@ import gym

 import ray
 from ray.rllib.agents.ppo.ppo_tf_policy import ValueNetworkMixin
 from ray.rllib.policy.sample_batch import SampleBatch
+from ray.rllib.evaluation.episode import Episode
 from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
     Postprocessing
-from ray.rllib.policy.tf_policy_template import build_tf_policy
+from ray.rllib.models.action_dist import ActionDistribution
+from ray.rllib.models.modelv2 import ModelV2
+from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.tf_policy import LearningRateSchedule, \
     EntropyCoeffSchedule
+from ray.rllib.policy.tf_policy_template import build_tf_policy
 from ray.rllib.utils.annotations import Deprecated
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.tf_ops import explained_variance
-from ray.rllib.policy.policy import Policy
 from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \
     PolicyID, LocalOptimizer, ModelGradients
-from ray.rllib.models.action_dist import ActionDistribution
-from ray.rllib.models.modelv2 import ModelV2
-from ray.rllib.evaluation import MultiAgentEpisode

 tf1, tf, tfv = try_import_tf()

@@ -31,7 +31,7 @@ def postprocess_advantages(
         policy: Policy,
         sample_batch: SampleBatch,
         other_agent_batches: Optional[Dict[PolicyID, SampleBatch]] = None,
-        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
+        episode: Optional[Episode] = None) -> SampleBatch:
     return compute_gae_for_sample_batch(policy, sample_batch,
                                         other_agent_batches, episode)


@@ -3,7 +3,7 @@ from typing import Optional, Dict

 import ray
 from ray.rllib.agents.ppo.ppo_torch_policy import ValueNetworkMixin
-from ray.rllib.evaluation.episode import MultiAgentEpisode
+from ray.rllib.evaluation.episode import Episode
 from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
     Postprocessing
 from ray.rllib.models.action_dist import ActionDistribution

@@ -30,7 +30,7 @@ def add_advantages(
         policy: Policy,
         sample_batch: SampleBatch,
         other_agent_batches: Optional[Dict[PolicyID, SampleBatch]] = None,
-        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
+        episode: Optional[Episode] = None) -> SampleBatch:
     return compute_gae_for_sample_batch(policy, sample_batch,
                                         other_agent_batches, episode)


@@ -5,7 +5,7 @@ from typing import Dict, Optional, TYPE_CHECKING

 from ray.rllib.env import BaseEnv
 from ray.rllib.policy import Policy
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.evaluation import MultiAgentEpisode
+from ray.rllib.evaluation.episode import Episode
 from ray.rllib.utils.annotations import PublicAPI
 from ray.rllib.utils.deprecation import deprecation_warning
 from ray.rllib.utils.typing import AgentID, PolicyID

@@ -35,28 +35,23 @@ class DefaultCallbacks:
                 "a class extending rllib.agents.callbacks.DefaultCallbacks")
         self.legacy_callbacks = legacy_callbacks_dict or {}

-    def on_episode_start(self,
-                         *,
-                         worker: "RolloutWorker",
-                         base_env: BaseEnv,
-                         policies: Dict[PolicyID, Policy],
-                         episode: MultiAgentEpisode,
-                         env_index: Optional[int] = None,
+    def on_episode_start(self, *, worker: "RolloutWorker", base_env: BaseEnv,
+                         policies: Dict[PolicyID, Policy], episode: Episode,
                          **kwargs) -> None:
         """Callback run on the rollout worker before each episode starts.

         Args:
             worker: Reference to the current rollout worker.
             base_env: BaseEnv running the episode. The underlying
-                sub environment objects can be received by calling
+                sub environment objects can be retrieved by calling
                 `base_env.get_sub_environments()`.
             policies: Mapping of policy id to policy objects. In single
                 agent mode there will only be a single "default" policy.
-            episode (MultiAgentEpisode): Episode object which contains episode
+            episode: Episode object which contains the episode's
                 state. You can use the `episode.user_data` dict to store
                 temporary data, and `episode.custom_metrics` to store custom
                 metrics for the episode.
-            env_index (EnvID): Obsoleted: The ID of the environment, which the
+            env_index: Obsoleted: The ID of the environment, which the
                 episode belongs to.
             kwargs: Forward compatibility placeholder.
         """

@@ -73,20 +68,19 @@ class DefaultCallbacks:
                         worker: "RolloutWorker",
                         base_env: BaseEnv,
                         policies: Optional[Dict[PolicyID, Policy]] = None,
-                        episode: MultiAgentEpisode,
-                        env_index: Optional[int] = None,
+                        episode: Episode,
                         **kwargs) -> None:
         """Runs on each episode step.

         Args:
             worker (RolloutWorker): Reference to the current rollout worker.
             base_env (BaseEnv): BaseEnv running the episode. The underlying
-                sub environment objects can be gotten by calling
+                sub environment objects can be retrieved by calling
                 `base_env.get_sub_environments()`.
             policies (Optional[Dict[PolicyID, Policy]]): Mapping of policy id
                 to policy objects. In single agent mode there will only be a
                 single "default_policy".
-            episode (MultiAgentEpisode): Episode object which contains episode
+            episode (Episode): Episode object which contains episode
                 state. You can use the `episode.user_data` dict to store
                 temporary data, and `episode.custom_metrics` to store custom
                 metrics for the episode.

@@ -101,13 +95,8 @@ class DefaultCallbacks:
             "episode": episode
         })

-    def on_episode_end(self,
-                       *,
-                       worker: "RolloutWorker",
-                       base_env: BaseEnv,
-                       policies: Dict[PolicyID, Policy],
-                       episode: MultiAgentEpisode,
-                       env_index: Optional[int] = None,
+    def on_episode_end(self, *, worker: "RolloutWorker", base_env: BaseEnv,
+                       policies: Dict[PolicyID, Policy], episode: Episode,
                        **kwargs) -> None:
         """Runs when an episode is done.

@@ -119,7 +108,7 @@ class DefaultCallbacks:
             policies (Dict[PolicyID, Policy]): Mapping of policy id to policy
                 objects. In single agent mode there will only be a single
                 "default_policy".
-            episode (MultiAgentEpisode): Episode object which contains episode
+            episode (Episode): Episode object which contains episode
                 state. You can use the `episode.user_data` dict to store
                 temporary data, and `episode.custom_metrics` to store custom
                 metrics for the episode.

@@ -136,7 +125,7 @@ class DefaultCallbacks:
         })

     def on_postprocess_trajectory(
-            self, *, worker: "RolloutWorker", episode: MultiAgentEpisode,
+            self, *, worker: "RolloutWorker", episode: Episode,
             agent_id: AgentID, policy_id: PolicyID,
             policies: Dict[PolicyID, Policy], postprocessed_batch: SampleBatch,
             original_batches: Dict[AgentID, SampleBatch], **kwargs) -> None:

@@ -148,7 +137,7 @@ class DefaultCallbacks:
         Args:
             worker (RolloutWorker): Reference to the current rollout worker.
-            episode (MultiAgentEpisode): Episode object.
+            episode (Episode): Episode object.
             agent_id (str): Id of the current agent.
             policy_id (str): Id of the current policy for the agent.
             policies (dict): Mapping of policy id to policy objects. In single

@@ -253,7 +242,7 @@ class MemoryTrackingCallbacks(DefaultCallbacks):
                          worker: "RolloutWorker",
                          base_env: BaseEnv,
                          policies: Dict[PolicyID, Policy],
-                         episode: MultiAgentEpisode,
+                         episode: Episode,
                          env_index: Optional[int] = None,
                          **kwargs) -> None:
         snapshot = tracemalloc.take_snapshot()

@@ -311,7 +300,7 @@ class MultiCallbacks(DefaultCallbacks):
                          worker: "RolloutWorker",
                          base_env: BaseEnv,
                          policies: Dict[PolicyID, Policy],
-                         episode: MultiAgentEpisode,
+                         episode: Episode,
                          env_index: Optional[int] = None,
                          **kwargs) -> None:
         for callback in self._callback_list:

@@ -328,7 +317,7 @@ class MultiCallbacks(DefaultCallbacks):
                          worker: "RolloutWorker",
                          base_env: BaseEnv,
                          policies: Optional[Dict[PolicyID, Policy]] = None,
-                         episode: MultiAgentEpisode,
+                         episode: Episode,
                          env_index: Optional[int] = None,
                          **kwargs) -> None:
         for callback in self._callback_list:

@@ -345,7 +334,7 @@ class MultiCallbacks(DefaultCallbacks):
                          worker: "RolloutWorker",
                          base_env: BaseEnv,
                          policies: Dict[PolicyID, Policy],
-                         episode: MultiAgentEpisode,
+                         episode: Episode,
                          env_index: Optional[int] = None,
                          **kwargs) -> None:
         for callback in self._callback_list:

@@ -358,7 +347,7 @@ class MultiCallbacks(DefaultCallbacks):
             **kwargs)

     def on_postprocess_trajectory(
-            self, *, worker: "RolloutWorker", episode: MultiAgentEpisode,
+            self, *, worker: "RolloutWorker", episode: Episode,
             agent_id: AgentID, policy_id: PolicyID,
             policies: Dict[PolicyID, Policy], postprocessed_batch: SampleBatch,
             original_batches: Dict[AgentID, SampleBatch], **kwargs) -> None:

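For reference, a minimal sketch of how the cleaned-up callback signatures above are typically used. The callback class, the `num_steps` metric, and the logic are made up for illustration; only the method signatures and the `episode.user_data` / `episode.custom_metrics` fields come from the diff above.

```python
from typing import Dict, Optional

from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env import BaseEnv
from ray.rllib.evaluation.episode import Episode
from ray.rllib.policy import Policy
from ray.rllib.utils.typing import PolicyID


class MyCallbacks(DefaultCallbacks):
    def on_episode_start(self, *, worker, base_env: BaseEnv,
                         policies: Dict[PolicyID, Policy], episode: Episode,
                         **kwargs) -> None:
        # Temporary, per-episode storage.
        episode.user_data["num_steps"] = 0

    def on_episode_step(self, *, worker, base_env: BaseEnv,
                        policies: Optional[Dict[PolicyID, Policy]] = None,
                        episode: Episode, **kwargs) -> None:
        episode.user_data["num_steps"] += 1

    def on_episode_end(self, *, worker, base_env: BaseEnv,
                       policies: Dict[PolicyID, Policy], episode: Episode,
                       **kwargs) -> None:
        # Custom metrics end up in the trainer's result dict.
        episode.custom_metrics["num_steps"] = episode.user_data["num_steps"]


# Usage (sketch): pass the class via the trainer config.
# config = {"callbacks": MyCallbacks}
```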

@@ -5,7 +5,7 @@ from typing import Dict, Optional

 import ray
 from ray.rllib.agents.dreamer.utils import FreezeParameters
-from ray.rllib.evaluation import MultiAgentEpisode
+from ray.rllib.evaluation.episode import Episode
 from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.policy_template import build_policy_class

@@ -247,7 +247,7 @@ def preprocess_episode(
         policy: Policy,
         sample_batch: SampleBatch,
         other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
-        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
+        episode: Optional[Episode] = None) -> SampleBatch:
     """Batch format should be in the form of (s_t, a_(t-1), r_(t-1))
     When t=0, the resetted obs is paired with action and reward of 0.
     """


@@ -59,7 +59,7 @@ def postprocess_advantages(
         other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
             dict of AgentIDs mapping to other agents' trajectory data (from the
             same episode). NOTE: The other agents use the same policy.
-        episode (Optional[MultiAgentEpisode]): Optional multi-agent episode
+        episode (Optional[Episode]): Optional multi-agent episode
             object in which the agents operated.

     Returns:


@@ -1,6 +1,6 @@
 from typing import List, Optional

-from ray.rllib.evaluation.episode import MultiAgentEpisode
+from ray.rllib.evaluation.episode import Episode
 from ray.rllib.evaluation.postprocessing import compute_advantages
 from ray.rllib.policy import Policy
 from ray.rllib.policy.sample_batch import SampleBatch

@@ -10,7 +10,7 @@ def post_process_advantages(
         policy: Policy,
         sample_batch: SampleBatch,
         other_agent_batches: Optional[List[SampleBatch]] = None,
-        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
+        episode: Optional[Episode] = None) -> SampleBatch:
     """Adds the "advantages" column to `sample_batch`.

     Args:

@@ -18,7 +18,7 @@ def post_process_advantages(
         sample_batch (SampleBatch): The actual sample batch to post-process.
         other_agent_batches (Optional[List[SampleBatch]]): Optional list of
             other agents' SampleBatch objects.
-        episode (MultiAgentEpisode): The multi-agent episode object, from which
+        episode (Episode): The multi-agent episode object, from which
             `sample_batch` was generated.

     Returns:


@@ -13,7 +13,7 @@ from typing import Dict, List, Optional, Type, Union
 from ray.rllib.agents.impala import vtrace_tf as vtrace
 from ray.rllib.agents.impala.vtrace_tf_policy import _make_time_major, \
     clip_gradients, choose_optimizer
-from ray.rllib.evaluation.episode import MultiAgentEpisode
+from ray.rllib.evaluation.episode import Episode
 from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
     Postprocessing
 from ray.rllib.models.tf.tf_action_dist import Categorical

@@ -325,7 +325,7 @@ def postprocess_trajectory(
         policy: Policy,
         sample_batch: SampleBatch,
         other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
-        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
+        episode: Optional[Episode] = None) -> SampleBatch:
     """Postprocesses a trajectory and returns the processed trajectory.

     The trajectory contains only data from one episode and from one agent.

@@ -343,7 +343,7 @@ def postprocess_trajectory(
         other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
             dict of AgentIDs mapping to other agents' trajectory data (from the
             same episode). NOTE: The other agents use the same policy.
-        episode (Optional[MultiAgentEpisode]): Optional multi-agent episode
+        episode (Optional[Episode]): Optional multi-agent episode
             object in which the agents operated.

     Returns:


@@ -7,7 +7,7 @@ import logging
 from typing import Dict, List, Optional, Type, Union

 import ray
-from ray.rllib.evaluation.episode import MultiAgentEpisode
+from ray.rllib.evaluation.episode import Episode
 from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
     Postprocessing
 from ray.rllib.models.modelv2 import ModelV2

@@ -357,7 +357,7 @@ def postprocess_ppo_gae(
         policy: Policy,
         sample_batch: SampleBatch,
         other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
-        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
+        episode: Optional[Episode] = None) -> SampleBatch:
     return compute_gae_for_sample_batch(policy, sample_batch,
                                         other_agent_batches, episode)


@@ -16,7 +16,7 @@ from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \
     PRIO_WEIGHTS
 from ray.rllib.agents.sac.sac_tf_model import SACTFModel
 from ray.rllib.agents.sac.sac_torch_model import SACTorchModel
-from ray.rllib.evaluation.episode import MultiAgentEpisode
+from ray.rllib.evaluation.episode import Episode
 from ray.rllib.models import ModelCatalog, MODEL_DEFAULTS
 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.models.tf.tf_action_dist import Beta, Categorical, \

@@ -106,7 +106,7 @@ def postprocess_trajectory(
         policy: Policy,
         sample_batch: SampleBatch,
         other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
-        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
+        episode: Optional[Episode] = None) -> SampleBatch:
     """Postprocesses a trajectory and returns the processed trajectory.

     The trajectory contains only data from one episode and from one agent.

@@ -124,7 +124,7 @@ def postprocess_trajectory(
         other_agent_batches (Optional[Dict[AgentID, SampleBatch]]): Optional
             dict of AgentIDs mapping to other agents' trajectory data (from the
             same episode). NOTE: The other agents use the same policy.
-        episode (Optional[MultiAgentEpisode]): Optional multi-agent episode
+        episode (Optional[Episode]): Optional multi-agent episode
            object in which the agents operated.

     Returns:


@@ -19,7 +19,7 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from ray.rllib.env.utils import gym_env_creator
 from ray.rllib.evaluation.collectors.simple_list_collector import \
     SimpleListCollector
-from ray.rllib.evaluation.episode import MultiAgentEpisode
+from ray.rllib.evaluation.episode import Episode
 from ray.rllib.evaluation.metrics import collect_metrics
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
 from ray.rllib.evaluation.worker_set import WorkerSet

@@ -1086,7 +1086,7 @@ class Trainer(Trainable):
             full_fetch: bool = False,
             explore: Optional[bool] = None,
             timestep: Optional[int] = None,
-            episode: Optional[MultiAgentEpisode] = None,
+            episode: Optional[Episode] = None,
             unsquash_action: Optional[bool] = None,
             clip_action: Optional[bool] = None,

@@ -1240,7 +1240,7 @@ class Trainer(Trainable):
             full_fetch: bool = False,
             explore: Optional[bool] = None,
             timestep: Optional[int] = None,
-            episodes: Optional[List[MultiAgentEpisode]] = None,
+            episodes: Optional[List[Episode]] = None,
             unsquash_actions: Optional[bool] = None,
             clip_actions: Optional[bool] = None,
             # Deprecated.

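The `episode`/`episodes` arguments in the two Trainer signatures above are only needed when computing actions inside an ongoing sampler episode; for plain inference the call is usually as simple as the following sketch (the env name and config values are placeholders, not part of this commit):

```python
import gym
import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()
trainer = PPOTrainer(config={"framework": "tf", "num_workers": 0},
                     env="CartPole-v0")

env = gym.make("CartPole-v0")
obs = env.reset()
# `episode=None` (the default) is fine outside of an ongoing sampler episode.
action = trainer.compute_single_action(obs, explore=False)
```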

@@ -1,4 +1,4 @@
-from ray.rllib.evaluation.episode import MultiAgentEpisode
+from ray.rllib.evaluation.episode import Episode, MultiAgentEpisode
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
 from ray.rllib.evaluation.sample_batch_builder import (
     SampleBatchBuilder, MultiAgentSampleBatchBuilder)

@@ -17,5 +17,6 @@ __all__ = [
     "AsyncSampler",
     "compute_advantages",
     "collect_metrics",
-    "MultiAgentEpisode",
+    "Episode",
+    "MultiAgentEpisode",  # Deprecated -> Use `Episode` instead.
 ]

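Because both names stay exported from `ray.rllib.evaluation`, downstream code can migrate gradually; a short sketch of old vs. new imports (behavior of the deprecation alias is described in the episode.py diff further below):

```python
# New, preferred import:
from ray.rllib.evaluation import Episode

# Old import still resolves for now, but points at the deprecated
# MultiAgentEpisode alias and should be replaced by `Episode`.
from ray.rllib.evaluation import MultiAgentEpisode
```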

@@ -2,7 +2,7 @@ from abc import abstractmethod, ABCMeta
 import logging
 from typing import Dict, List, Optional, TYPE_CHECKING, Union

-from ray.rllib.evaluation.episode import MultiAgentEpisode
+from ray.rllib.evaluation.episode import Episode
 from ray.rllib.policy.policy_map import PolicyMap
 from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
 from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, \

@@ -57,7 +57,7 @@ class SampleCollector(metaclass=ABCMeta):
         self.count_steps_by = count_steps_by

     @abstractmethod
-    def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
+    def add_init_obs(self, episode: Episode, agent_id: AgentID,
                      policy_id: PolicyID, t: int,
                      init_obs: TensorType) -> None:
         """Adds an initial obs (after reset) to this collector.

@@ -70,7 +70,7 @@ class SampleCollector(metaclass=ABCMeta):
         called for that same agent/episode-pair.

         Args:
-            episode (MultiAgentEpisode): The MultiAgentEpisode, for which we
+            episode (Episode): The Episode, for which we
                 are adding an Agent's initial observation.
             agent_id (AgentID): Unique id for the agent we are adding
                 values for.

@@ -126,11 +126,11 @@ class SampleCollector(metaclass=ABCMeta):
         raise NotImplementedError

     @abstractmethod
-    def episode_step(self, episode: MultiAgentEpisode) -> None:
+    def episode_step(self, episode: Episode) -> None:
         """Increases the episode step counter (across all agents) by one.

         Args:
-            episode (MultiAgentEpisode): Episode we are stepping through.
+            episode (Episode): Episode we are stepping through.
                 Useful for handling counting b/c it is called once across
                 all agents that are inside this episode.
         """

@@ -200,7 +200,7 @@ class SampleCollector(metaclass=ABCMeta):
     @abstractmethod
     def postprocess_episode(self,
-                            episode: MultiAgentEpisode,
+                            episode: Episode,
                             is_done: bool = False,
                             check_dones: bool = False,
                             build: bool = False) -> Optional[MultiAgentBatch]:

@@ -214,7 +214,7 @@ class SampleCollector(metaclass=ABCMeta):
             correctly added to the buffers.

         Args:
-            episode (MultiAgentEpisode): The Episode object for which
+            episode (Episode): The Episode object for which
                 to post-process data.
             is_done (bool): Whether the given episode is actually terminated
                 (all agents are done OR we hit a hard horizon). If True, the


@@ -8,7 +8,7 @@ from typing import Any, Dict, List, Tuple, TYPE_CHECKING, Union

 from ray.rllib.env.base_env import _DUMMY_AGENT_ID
 from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
-from ray.rllib.evaluation.episode import MultiAgentEpisode
+from ray.rllib.evaluation.episode import Episode
 from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.policy_map import PolicyMap
 from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch

@@ -505,11 +505,11 @@ class SimpleListCollector(SampleCollector):
         # Maps episode ID to the (non-built) individual agent steps in this
         # episode.
         self.agent_steps: Dict[EpisodeID, int] = collections.defaultdict(int)
-        # Maps episode ID to MultiAgentEpisode.
-        self.episodes: Dict[EpisodeID, MultiAgentEpisode] = {}
+        # Maps episode ID to Episode.
+        self.episodes: Dict[EpisodeID, Episode] = {}

     @override(SampleCollector)
-    def episode_step(self, episode: MultiAgentEpisode) -> None:
+    def episode_step(self, episode: Episode) -> None:
         episode_id = episode.episode_id
         # In the rase case that an "empty" step is taken at the beginning of
         # the episode (none of the agents has an observation in the obs-dict

@@ -550,8 +550,8 @@ class SimpleListCollector(SampleCollector):
                 if not self.multiple_episodes_in_batch else ""))

     @override(SampleCollector)
-    def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
-                     env_id: EnvID, policy_id: PolicyID, t: int,
+    def add_init_obs(self, episode: Episode, agent_id: AgentID, env_id: EnvID,
+                     policy_id: PolicyID, t: int,
                      init_obs: TensorType) -> None:
         # Make sure our mappings are up to date.
         agent_key = (episode.episode_id, agent_id)

@@ -707,7 +707,7 @@ class SimpleListCollector(SampleCollector):
     @override(SampleCollector)
     def postprocess_episode(
             self,
-            episode: MultiAgentEpisode,
+            episode: Episode,
             is_done: bool = False,
             check_dones: bool = False,
             build: bool = False) -> Union[None, SampleBatch, MultiAgentBatch]:

@@ -834,7 +834,7 @@ class SimpleListCollector(SampleCollector):
         if build:
             return self._build_multi_agent_batch(episode)

-    def _build_multi_agent_batch(self, episode: MultiAgentEpisode) -> \
+    def _build_multi_agent_batch(self, episode: Episode) -> \
             Union[MultiAgentBatch, SampleBatch]:
         ma_batch = {}


@@ -1,11 +1,11 @@
 from collections import defaultdict
 import numpy as np
 import random
-from typing import List, Dict, Callable, Any, TYPE_CHECKING
+from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING

 from ray.rllib.env.base_env import _DUMMY_AGENT_ID
 from ray.rllib.policy.policy_map import PolicyMap
-from ray.rllib.utils.annotations import DeveloperAPI
+from ray.rllib.utils.annotations import Deprecated, DeveloperAPI
 from ray.rllib.utils.deprecation import deprecation_warning
 from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray
 from ray.rllib.utils.typing import SampleBatchType, AgentID, PolicyID, \

@@ -19,7 +19,7 @@ if TYPE_CHECKING:

 @DeveloperAPI
-class MultiAgentEpisode:
+class Episode:
     """Tracks the current state of a (possibly multi-agent) episode.

     Attributes:
@@ -53,15 +53,28 @@
     def __init__(
             self,
             policies: PolicyMap,
-            policy_mapping_fn: Callable[
-                [AgentID, "MultiAgentEpisode", "RolloutWorker"], PolicyID],
+            policy_mapping_fn: Callable[[AgentID, "Episode", "RolloutWorker"],
+                                        PolicyID],
             batch_builder_factory: Callable[[],
                                             "MultiAgentSampleBatchBuilder"],
             extra_batch_callback: Callable[[SampleBatchType], None],
             env_id: EnvID,
             *,
-            worker: "RolloutWorker" = None,
+            worker: Optional["RolloutWorker"] = None,
     ):
+        """Initializes an Episode instance.
+
+        Args:
+            policies: The PolicyMap object (mapping PolicyIDs to Policy
+                objects) to use for determining, which policy is used for
+                which agent.
+            policy_mapping_fn: The mapping function mapping AgentIDs to
+                PolicyIDs.
+            batch_builder_factory:
+            extra_batch_callback:
+            env_id: The environment's ID in which this episode runs.
+            worker: The RolloutWorker instance, in which this episode runs.
+        """
         self.new_batch_builder: Callable[
             [], "MultiAgentSampleBatchBuilder"] = batch_builder_factory
         self.add_extra_batch: Callable[[SampleBatchType],

@@ -80,9 +93,8 @@ class MultiAgentEpisode:
         self.media: Dict[str, Any] = {}
         self.policy_map: PolicyMap = policies
         self._policies = self.policy_map  # backward compatibility
-        self.policy_mapping_fn: Callable[[
-            AgentID, "MultiAgentEpisode", "RolloutWorker"
-        ], PolicyID] = policy_mapping_fn
+        self.policy_mapping_fn: Callable[[AgentID, "Episode", "RolloutWorker"],
+                                         PolicyID] = policy_mapping_fn
         self._next_agent_index: int = 0
         self._agent_to_index: Dict[AgentID, int] = {}
         self._agent_to_policy: Dict[AgentID, PolicyID] = {}

@@ -92,21 +104,11 @@ class MultiAgentEpisode:
         self._agent_to_last_done: Dict[AgentID, bool] = {}
         self._agent_to_last_info: Dict[AgentID, EnvInfoDict] = {}
         self._agent_to_last_action: Dict[AgentID, EnvActionType] = {}
-        self._agent_to_last_pi_info: Dict[AgentID, dict] = {}
+        self._agent_to_last_extra_action_outs: Dict[AgentID, dict] = {}
         self._agent_to_prev_action: Dict[AgentID, EnvActionType] = {}
         self._agent_reward_history: Dict[AgentID, List[int]] = defaultdict(
             list)

-    # TODO: Deprecated.
-    @property
-    def _policy_mapping_fn(self):
-        deprecation_warning(
-            old="MultiAgentEpisode._policy_mapping_fn",
-            new="MultiAgentEpisode.policy_mapping_fn",
-            error=False,
-        )
-        return self.policy_mapping_fn
-
     @DeveloperAPI
     def soft_reset(self) -> None:
         """Clears rewards and metrics, but retains RNN and other state.
@@ -126,15 +128,17 @@
         If the agent is new, the policy mapping fn will be called to bind the
         agent to a policy for the duration of the entire episode (even if the
-        policy_mapping_fn is changed in the meantime).
+        policy_mapping_fn is changed in the meantime!).

         Args:
-            agent_id (AgentID): The agent ID to lookup the policy ID for.
+            agent_id: The agent ID to lookup the policy ID for.

         Returns:
-            PolicyID: The policy ID for the specified agent.
+            The policy ID for the specified agent.
         """
+        # Perform a new policy_mapping_fn lookup and bind AgentID for the
+        # duration of this episode to the returned PolicyID.
         if agent_id not in self._agent_to_policy:
             # Try new API: pass in agent_id and episode as named args.
             # New signature should be: (agent_id, episode, worker, **kwargs)

@@ -153,8 +157,11 @@
                     self.policy_mapping_fn(agent_id)
                 else:
                     raise e
+        # Use already determined PolicyID.
         else:
             policy_id = self._agent_to_policy[agent_id]
+        # PolicyID not found in policy map -> Error.
         if policy_id not in self.policy_map:
             raise KeyError("policy_mapping_fn returned invalid policy id "
                            f"'{policy_id}'!")
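To illustrate the new-style mapping-function signature that the lookup above expects, here is a minimal sketch; the policy IDs are hypothetical names that would have to exist in the multiagent `policies` config:

```python
def policy_mapping_fn(agent_id, episode, worker, **kwargs):
    # New signature: (agent_id, episode, worker, **kwargs).
    # `episode` is the (renamed) Episode object; the returned PolicyID is
    # bound to `agent_id` for the rest of this episode.
    return "policy_0" if str(agent_id).endswith("0") else "policy_1"
```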
@@ -162,33 +169,70 @@
     @DeveloperAPI
     def last_observation_for(
-            self, agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvObsType:
-        """Returns the last observation for the specified agent."""
+            self, agent_id: AgentID = _DUMMY_AGENT_ID) -> Optional[EnvObsType]:
+        """Returns the last observation for the specified AgentID.
+
+        Args:
+            agent_id: The agent's ID to get the last observation for.
+
+        Returns:
+            Last observation the specified AgentID has seen. None in case
+            the agent has never made any observations in the episode.
+        """
         return self._agent_to_last_obs.get(agent_id)

     @DeveloperAPI
-    def last_raw_obs_for(self,
-                         agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvObsType:
-        """Returns the last un-preprocessed obs for the specified agent."""
+    def last_raw_obs_for(
+            self, agent_id: AgentID = _DUMMY_AGENT_ID) -> Optional[EnvObsType]:
+        """Returns the last un-preprocessed obs for the specified AgentID.
+
+        Args:
+            agent_id: The agent's ID to get the last un-preprocessed
+                observation for.
+
+        Returns:
+            Last un-preprocessed observation the specified AgentID has seen.
+            None in case the agent has never made any observations in the
+            episode.
+        """
         return self._agent_to_last_raw_obs.get(agent_id)

     @DeveloperAPI
-    def last_info_for(self,
-                      agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvInfoDict:
-        """Returns the last info for the specified agent."""
+    def last_info_for(self, agent_id: AgentID = _DUMMY_AGENT_ID
+                      ) -> Optional[EnvInfoDict]:
+        """Returns the last info for the specified AgentID.
+
+        Args:
+            agent_id: The agent's ID to get the last info for.
+
+        Returns:
+            Last info dict the specified AgentID has seen.
+            None in case the agent has never made any observations in the
+            episode.
+        """
         return self._agent_to_last_info.get(agent_id)

     @DeveloperAPI
     def last_action_for(self,
                         agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvActionType:
-        """Returns the last action for the specified agent, or zeros."""
+        """Returns the last action for the specified AgentID, or zeros.
+
+        The "last" action is the most recent one taken by the agent.
+
+        Args:
+            agent_id: The agent's ID to get the last action for.
+
+        Returns:
+            Last action the specified AgentID has executed.
+            Zeros in case the agent has never performed any actions in the
+            episode.
+        """
+        # Agent has already taken at least one action in the episode.
         if agent_id in self._agent_to_last_action:
             return flatten_to_single_ndarray(
                 self._agent_to_last_action[agent_id])
+        # Agent has not acted yet, return all zeros.
         else:
             policy_id = self.policy_for(agent_id)
             policy = self.policy_map[policy_id]
@@ -200,29 +244,84 @@ class MultiAgentEpisode:
     @DeveloperAPI
     def prev_action_for(self,
                         agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvActionType:
-        """Returns the previous action for the specified agent."""
+        """Returns the previous action for the specified agent, or zeros.
+
+        The "previous" action is the one taken one timestep before the
+        most recent action taken by the agent.
+
+        Args:
+            agent_id: The agent's ID to get the previous action for.
+
+        Returns:
+            Previous action the specified AgentID has executed.
+            Zero in case the agent has never performed any actions (or only
+            one) in the episode.
+        """
+        # We are at t > 1 -> There has been a previous action by this agent.
         if agent_id in self._agent_to_prev_action:
             return flatten_to_single_ndarray(
                 self._agent_to_prev_action[agent_id])
+        # We're at t <= 1, so return all zeros.
         else:
-            # We're at t=0, so return all zeros.
             return np.zeros_like(self.last_action_for(agent_id))

     @DeveloperAPI
-    def prev_reward_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> float:
-        """Returns the previous reward for the specified agent."""
+    def last_reward_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> float:
+        """Returns the last reward for the specified agent, or zero.
+
+        The "last" reward is the one received most recently by the agent.
+
+        Args:
+            agent_id: The agent's ID to get the last reward for.
+
+        Returns:
+            Last reward for the specified AgentID.
+            Zero in case the agent has never performed any actions
+            (and thus received rewards) in the episode.
+        """
         history = self._agent_reward_history[agent_id]
+        # We are at t > 0 -> Return previously received reward.
+        if len(history) >= 1:
+            return history[-1]
+        # We're at t=0, so there is no previous reward, just return zero.
+        else:
+            return 0.0
+
+    @DeveloperAPI
+    def prev_reward_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> float:
+        """Returns the previous reward for the specified agent, or zero.
+
+        The "previous" reward is the one received one timestep before the
+        most recently received reward of the agent.
+
+        Args:
+            agent_id: The agent's ID to get the previous reward for.
+
+        Returns:
+            Previous reward for the specified AgentID.
+            Zero in case the agent has never performed any actions (or only
+            one) in the episode.
+        """
+        history = self._agent_reward_history[agent_id]
+        # We are at t > 1 -> Return reward prior to most recent (last) one.
         if len(history) >= 2:
             return history[-2]
+        # We're at t <= 1, so there is no previous reward, just return zero.
         else:
-            # We're at t=0, so there is no previous reward, just return zero.
             return 0.0
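A small usage sketch of how the new `last_reward_for()` relates to `prev_reward_for()` and the action accessors, e.g. from inside a callback; the agent ID is a placeholder:

```python
# Inside e.g. `on_episode_step(..., episode, ...)`:
a_t = episode.last_action_for("agent_0")    # most recent action (or zeros)
a_tm1 = episode.prev_action_for("agent_0")  # action one step earlier (or zeros)
r_t = episode.last_reward_for("agent_0")    # most recent reward (or 0.0)
r_tm1 = episode.prev_reward_for("agent_0")  # reward one step earlier (or 0.0)
```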
     @DeveloperAPI
     def rnn_state_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> List[Any]:
-        """Returns the last RNN state for the specified agent."""
+        """Returns the last RNN state for the specified agent.
+
+        Args:
+            agent_id: The agent's ID to get the most recent RNN state for.
+
+        Returns:
+            Most recent RNN state of the specified AgentID.
+        """
         if agent_id not in self._agent_to_rnn_state:
             policy_id = self.policy_for(agent_id)

@@ -232,24 +331,44 @@ class MultiAgentEpisode:
     @DeveloperAPI
     def last_done_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> bool:
-        """Returns the last done flag received for the specified agent."""
+        """Returns the last done flag for the specified AgentID.
+
+        Args:
+            agent_id: The agent's ID to get the last done flag for.
+
+        Returns:
+            Last done flag for the specified AgentID.
+        """
         if agent_id not in self._agent_to_last_done:
             self._agent_to_last_done[agent_id] = False
         return self._agent_to_last_done[agent_id]

     @DeveloperAPI
-    def last_pi_info_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> dict:
-        """Returns the last info object for the specified agent."""
-        return self._agent_to_last_pi_info[agent_id]
+    def last_extra_action_outs_for(
+            self,
+            agent_id: AgentID = _DUMMY_AGENT_ID,
+    ) -> dict:
+        """Returns the last extra-action outputs for the specified agent.
+
+        This data is returned by a call to
+        `Policy.compute_actions_from_input_dict` as the 3rd return value
+        (1st return value = action; 2nd return value = RNN state outs).
+
+        Args:
+            agent_id: The agent's ID to get the last extra-action outs for.
+
+        Returns:
+            The last extra-action outs for the specified AgentID.
+        """
+        return self._agent_to_last_extra_action_outs[agent_id]

     @DeveloperAPI
     def get_agents(self) -> List[AgentID]:
         """Returns list of agent IDs that have appeared in this episode.

         Returns:
-            List[AgentID]: The list of all agents that have appeared so
-                far in this episode.
+            The list of all agent IDs that have appeared so far in this
+            episode.
         """
         return list(self._agent_to_index.keys())
@@ -282,11 +401,31 @@ class MultiAgentEpisode:
                 self._agent_to_last_action[agent_id]
         self._agent_to_last_action[agent_id] = action

-    def _set_last_pi_info(self, agent_id, pi_info):
-        self._agent_to_last_pi_info[agent_id] = pi_info
+    def _set_last_extra_action_outs(self, agent_id, pi_info):
+        self._agent_to_last_extra_action_outs[agent_id] = pi_info

     def _agent_index(self, agent_id):
         if agent_id not in self._agent_to_index:
             self._agent_to_index[agent_id] = self._next_agent_index
             self._next_agent_index += 1
         return self._agent_to_index[agent_id]

+    @property
+    def _policy_mapping_fn(self):
+        deprecation_warning(
+            old="Episode._policy_mapping_fn",
+            new="Episode.policy_mapping_fn",
+            error=False,
+        )
+        return self.policy_mapping_fn
+
+    @Deprecated(new="Episode.last_extra_action_outs_for", error=False)
+    def last_pi_info_for(self, *args, **kwargs):
+        return self.last_extra_action_outs_for(*args, **kwargs)
+
+
+# Backward compatibility. The name MultiAgentEpisode implies that there is
+# also a (single agent?) Episode.
+@Deprecated(new="ray.rllib.evaluation.episode.Episode", error=False)
+class MultiAgentEpisode(Episode):
+    pass

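A sketch of the `last_pi_info_for` rename from a user's perspective; both calls return the same extra-action-outputs dict (e.g. action log-probs or value-function predictions, depending on the policy), and the agent ID is a placeholder:

```python
# New, preferred accessor:
extra_outs = episode.last_extra_action_outs_for("agent_0")

# Old accessor still works for now, but is deprecated and simply
# redirects to `last_extra_action_outs_for()`.
extra_outs_old = episode.last_pi_info_for("agent_0")
```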

@@ -1,14 +1,13 @@
+import collections
 import logging
 import numpy as np
-import collections
-from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
+from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

 import ray
 from ray import ObjectRef
 from ray.actor import ActorHandle
-from ray.rllib.evaluation.rollout_metrics import RolloutMetrics
-from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
 from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate
+from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
 from ray.rllib.utils.annotations import DeveloperAPI
 from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
 from ray.rllib.utils.typing import GradInfoDict, LearnerStatsDict, ResultDict

@@ -18,6 +17,17 @@ if TYPE_CHECKING:

 logger = logging.getLogger(__name__)

+RolloutMetrics = collections.namedtuple("RolloutMetrics", [
+    "episode_length",
+    "episode_reward",
+    "agent_rewards",
+    "custom_metrics",
+    "perf_stats",
+    "hist_data",
+    "media",
+])
+RolloutMetrics.__new__.__defaults__ = (0, 0, {}, {}, {}, {}, {})
+

 def extract_stats(stats: Dict, key: str) -> Dict[str, Any]:
     if key in stats:

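Since `RolloutMetrics` now lives in `ray.rllib.evaluation.metrics` (the standalone `rollout_metrics.py` module is deleted further below), imports should be updated accordingly. A minimal sketch; the field values are made up:

```python
from ray.rllib.evaluation.metrics import RolloutMetrics

# All fields default to 0 / empty dicts, so partial construction is valid.
metrics = RolloutMetrics(episode_length=200, episode_reward=195.0)
print(metrics.episode_reward, metrics.custom_metrics)
```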

@ -2,7 +2,7 @@ from typing import Dict
from ray.rllib.env import BaseEnv from ray.rllib.env import BaseEnv
from ray.rllib.policy import Policy from ray.rllib.policy import Policy
from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker from ray.rllib.evaluation import Episode, RolloutWorker
from ray.rllib.utils.framework import TensorType from ray.rllib.utils.framework import TensorType
from ray.rllib.utils.typing import AgentID, PolicyID from ray.rllib.utils.typing import AgentID, PolicyID
@ -22,7 +22,7 @@ class ObservationFunction:
def __call__(self, agent_obs: Dict[AgentID, TensorType], def __call__(self, agent_obs: Dict[AgentID, TensorType],
worker: RolloutWorker, base_env: BaseEnv, worker: RolloutWorker, base_env: BaseEnv,
policies: Dict[PolicyID, Policy], episode: MultiAgentEpisode, policies: Dict[PolicyID, Policy], episode: Episode,
**kw) -> Dict[AgentID, TensorType]: **kw) -> Dict[AgentID, TensorType]:
"""Callback run on each environment step to observe the environment. """Callback run on each environment step to observe the environment.
@ -45,7 +45,7 @@ class ObservationFunction:
retrieved by calling `base_env.get_sub_environments()`. retrieved by calling `base_env.get_sub_environments()`.
policies (dict): Mapping of policy id to policy objects. In single policies (dict): Mapping of policy id to policy objects. In single
agent mode there will only be a single "default" policy. agent mode there will only be a single "default" policy.
episode (MultiAgentEpisode): Episode state object. episode (Episode): Episode state object.
kwargs: Forward compatibility placeholder. kwargs: Forward compatibility placeholder.
Returns: Returns:


@@ -2,7 +2,7 @@ import numpy as np
 import scipy.signal
 from typing import Dict, Optional

-from ray.rllib.evaluation.episode import MultiAgentEpisode
+from ray.rllib.evaluation.episode import Episode
 from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.annotations import DeveloperAPI

@@ -135,7 +135,7 @@ def compute_gae_for_sample_batch(
         policy: Policy,
         sample_batch: SampleBatch,
         other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
-        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
+        episode: Optional[Episode] = None) -> SampleBatch:
     """Adds GAE (generalized advantage estimations) to a trajectory.

     The trajectory contains only data from one episode and from one agent.

@@ -153,7 +153,7 @@ def compute_gae_for_sample_batch(
         other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
             dict of AgentIDs mapping to other agents' trajectory data (from the
             same episode). NOTE: The other agents use the same policy.
-        episode (Optional[MultiAgentEpisode]): Optional multi-agent episode
+        episode (Optional[Episode]): Optional multi-agent episode
             object in which the agents operated.

     Returns:

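As the many one-line `postprocess_*` wrappers in the algorithm files above show, a policy's trajectory postprocessor typically just delegates to `compute_gae_for_sample_batch`. A sketch of such a postprocessing function with the updated `Episode` type hint (the function name is hypothetical):

```python
from typing import Dict, Optional

from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import AgentID


def my_postprocess_fn(
        policy: Policy,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
        episode: Optional[Episode] = None) -> SampleBatch:
    # Adds the advantages / value-target columns to the batch.
    return compute_gae_for_sample_batch(policy, sample_batch,
                                        other_agent_batches, episode)
```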

@@ -1,13 +0,0 @@
-import collections
-
-# Define this in its own file, see #5125
-RolloutMetrics = collections.namedtuple("RolloutMetrics", [
-    "episode_length",
-    "episode_reward",
-    "agent_rewards",
-    "custom_metrics",
-    "perf_stats",
-    "hist_data",
-    "media",
-])
-RolloutMetrics.__new__.__defaults__ = (0, 0, {}, {}, {}, {}, {})

(File diff suppressed because it is too large.)


@@ -4,7 +4,7 @@ import numpy as np
 from typing import List, Any, Dict, Optional, TYPE_CHECKING

 from ray.rllib.env.base_env import _DUMMY_AGENT_ID
-from ray.rllib.evaluation.episode import MultiAgentEpisode
+from ray.rllib.evaluation.episode import Episode
 from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
 from ray.rllib.utils.annotations import Deprecated, DeveloperAPI

@@ -152,15 +152,15 @@ class MultiAgentSampleBatchBuilder:
             self.agent_builders[agent_id].add_values(**values)

-    def postprocess_batch_so_far(
-            self, episode: Optional[MultiAgentEpisode] = None) -> None:
+    def postprocess_batch_so_far(self,
+                                 episode: Optional[Episode] = None) -> None:
         """Apply policy postprocessors to any unprocessed rows.

         This pushes the postprocessed per-agent batches onto the per-policy
         builders, clearing per-agent state.

         Args:
-            episode (Optional[MultiAgentEpisode]): The Episode object that
+            episode (Optional[Episode]): The Episode object that
                 holds this MultiAgentBatchBuilder object.
         """

@@ -234,15 +234,15 @@ class MultiAgentSampleBatchBuilder:
                 "Alternatively, set no_done_at_end=True to allow this.")

     @DeveloperAPI
-    def build_and_reset(self, episode: Optional[MultiAgentEpisode] = None
-                        ) -> MultiAgentBatch:
+    def build_and_reset(self,
+                        episode: Optional[Episode] = None) -> MultiAgentBatch:
         """Returns the accumulated sample batches for each policy.

         Any unprocessed rows will be first postprocessed with a policy
         postprocessor. The internal state of this builder will be reset.

         Args:
-            episode (Optional[MultiAgentEpisode]): The Episode object that
+            episode (Optional[Episode]): The Episode object that
                 holds this MultiAgentBatchBuilder object or None.

         Returns:


@@ -14,8 +14,8 @@ from ray.rllib.evaluation.collectors.sample_collector import \
     SampleCollector
 from ray.rllib.evaluation.collectors.simple_list_collector import \
     SimpleListCollector
-from ray.rllib.evaluation.episode import MultiAgentEpisode
-from ray.rllib.evaluation.rollout_metrics import RolloutMetrics
+from ray.rllib.evaluation.episode import Episode
+from ray.rllib.evaluation.metrics import RolloutMetrics
 from ray.rllib.evaluation.sample_batch_builder import \
     MultiAgentSampleBatchBuilder
 from ray.rllib.env.base_env import BaseEnv, ASYNC_RESET_RETURN
@ -110,16 +110,45 @@ class SamplerInput(InputReader, metaclass=ABCMeta):
@abstractmethod @abstractmethod
@DeveloperAPI @DeveloperAPI
def get_data(self) -> SampleBatchType: def get_data(self) -> SampleBatchType:
"""Called by `self.next()` to return the next batch of data.
Override this in child classes.
Returns:
The next batch of data.
"""
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
@DeveloperAPI @DeveloperAPI
def get_metrics(self) -> List[RolloutMetrics]: def get_metrics(self) -> List[RolloutMetrics]:
"""Returns list of episode metrics since the last call to this method.
The list will contain one RolloutMetrics object per completed episode.
Returns:
List of RolloutMetrics objects, one per completed episode since
the last call to this method.
"""
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
@DeveloperAPI @DeveloperAPI
def get_extra_batches(self) -> List[SampleBatchType]: def get_extra_batches(self) -> List[SampleBatchType]:
"""Returns list of extra batches since the last call to this method.
The list will contain all SampleBatches or
MultiAgentBatches that the user has provided thus far. Users can
add these "extra batches" to an episode by calling the episode's
`add_extra_batch([SampleBatchType])` method. This can be done from
inside an overridden `Policy.compute_actions_from_input_dict(...,
episodes)` or from a custom callback's `on_episode_[start|step|end]()`
methods.
Returns:
List of SampleBatches or MultiAgentBatches provided thus far by
the user since the last call to this method.
"""
raise NotImplementedError raise NotImplementedError
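# A minimal sketch, assuming a CartPole-like single-agent setup, of how a user
# can feed "extra batches" to the sampler as described in `get_extra_batches()`
# above: build a SampleBatch inside a custom callback and hand it to the
# ongoing episode via `Episode.add_extra_batch()`. The column names and values
# below are purely illustrative.
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.policy.sample_batch import SampleBatch

class ModelBasedExtraBatches(DefaultCallbacks):
    def on_episode_step(self, *, worker, base_env, policies=None,
                        episode, env_index=None, **kwargs):
        # Pretend a model-based rollout produced one extra transition.
        extra = SampleBatch({"obs": [[0.0, 0.0, 0.0, 0.0]], "actions": [0],
                             "rewards": [0.0], "dones": [False]})
        episode.add_extra_batch(extra)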
@ -143,7 +172,7 @@ class SyncSampler(SamplerInput):
clip_actions: bool = False, clip_actions: bool = False,
soft_horizon: bool = False, soft_horizon: bool = False,
no_done_at_end: bool = False, no_done_at_end: bool = False,
observation_fn: "ObservationFunction" = None, observation_fn: Optional["ObservationFunction"] = None,
sample_collector_class: Optional[Type[SampleCollector]] = None, sample_collector_class: Optional[Type[SampleCollector]] = None,
render: bool = False, render: bool = False,
# Obsolete. # Obsolete.
@ -153,41 +182,44 @@ class SyncSampler(SamplerInput):
obs_filters=None, obs_filters=None,
tf_sess=None, tf_sess=None,
): ):
"""Initializes a SyncSampler object. """Initializes a SyncSampler instance.
Args: Args:
worker (RolloutWorker): The RolloutWorker that will use this worker: The RolloutWorker that will use this Sampler for sampling.
Sampler for sampling. env: Any Env object. Will be converted into an RLlib BaseEnv.
env (Env): Any Env object. Will be converted into an RLlib BaseEnv. clip_rewards: True for +/-1.0 clipping,
clip_rewards (Union[bool, float]): True for +/-1.0 clipping,
actual float value for +/- value clipping. False for no actual float value for +/- value clipping. False for no
clipping. clipping.
rollout_fragment_length (int): The length of a fragment to collect rollout_fragment_length: The length of a fragment to collect
before building a SampleBatch from the data and resetting before building a SampleBatch from the data and resetting
the SampleBatchBuilder object. the SampleBatchBuilder object.
callbacks (Callbacks): The Callbacks object to use when episode count_steps_by: One of "env_steps" (default) or "agent_steps".
Use "agent_steps", if you want rollout lengths to be counted
by individual agent steps. In a multi-agent env,
a single env_step contains one or more agent_steps, depending
on how many agents are present at any given time in the
ongoing episode.
callbacks: The Callbacks object to use when episode
events happen during rollout. events happen during rollout.
horizon (Optional[int]): Hard-reset the Env horizon: Hard-reset the Env after this many timesteps.
multiple_episodes_in_batch (bool): Whether to pack multiple multiple_episodes_in_batch: Whether to pack multiple
episodes into each batch. This guarantees batches will be episodes into each batch. This guarantees batches will be
exactly `rollout_fragment_length` in size. exactly `rollout_fragment_length` in size.
normalize_actions (bool): Whether to normalize actions to the normalize_actions: Whether to normalize actions to the
action space's bounds. action space's bounds.
clip_actions (bool): Whether to clip actions according to the clip_actions: Whether to clip actions according to the
given action_space's bounds. given action_space's bounds.
soft_horizon (bool): If True, calculate bootstrapped values as if soft_horizon: If True, calculate bootstrapped values as if
episode had ended, but don't physically reset the environment episode had ended, but don't physically reset the environment
when the horizon is hit. when the horizon is hit.
no_done_at_end (bool): Ignore the done=True at the end of the no_done_at_end: Ignore the done=True at the end of the
episode and instead record done=False. episode and instead record done=False.
observation_fn (Optional[ObservationFunction]): Optional observation_fn: Optional multi-agent observation func to use for
multi-agent observation func to use for preprocessing preprocessing observations.
observations. sample_collector_class: An optional SampleCollector sub-class to
sample_collector_class (Optional[Type[SampleCollector]]): An use to collect, store, and retrieve environment-, model-,
optional Samplecollector sub-class to use to collect, store, and sampler data.
and retrieve environment-, model-, and sampler data. render: Whether to try to render the environment after each step.
render (bool): Whether to try to render the environment after each
step.
""" """
# All of the following arguments are deprecated. They will instead be # All of the following arguments are deprecated. They will instead be
# provided via the passed in `worker` arg, e.g. `worker.policy_map`. # provided via the passed in `worker` arg, e.g. `worker.policy_map`.
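# A minimal config sketch, assuming CartPole and a PGTrainer, showing where
# `count_steps_by` normally comes from: the Trainer's "multiagent" config,
# which is forwarded to the sampler via the passed-in `worker` mentioned above.
from ray.rllib.agents.pg import PGTrainer

trainer = PGTrainer(
    env="CartPole-v0",
    config={
        "num_workers": 0,
        "rollout_fragment_length": 50,
        # Count `rollout_fragment_length` in agent steps instead of env steps.
        "multiagent": {"count_steps_by": "agent_steps"},
    })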
@ -262,8 +294,9 @@ class SyncSampler(SamplerInput):
class AsyncSampler(threading.Thread, SamplerInput): class AsyncSampler(threading.Thread, SamplerInput):
"""Async SamplerInput that collects experiences in thread and queues them. """Async SamplerInput that collects experiences in thread and queues them.
Once started, experiences are continuously collected and put into a Queue, Once started, experiences are continuously collected in the background
from where they can be unqueued by the caller of `get_data()`. and put into a Queue, from where they can be unqueued by the caller
of `get_data()`.
""" """
def __init__( def __init__(
@ -279,12 +312,12 @@ class AsyncSampler(threading.Thread, SamplerInput):
multiple_episodes_in_batch: bool = False, multiple_episodes_in_batch: bool = False,
normalize_actions: bool = True, normalize_actions: bool = True,
clip_actions: bool = False, clip_actions: bool = False,
blackhole_outputs: bool = False,
soft_horizon: bool = False, soft_horizon: bool = False,
no_done_at_end: bool = False, no_done_at_end: bool = False,
observation_fn: Optional["ObservationFunction"] = None, observation_fn: Optional["ObservationFunction"] = None,
sample_collector_class: Optional[Type[SampleCollector]] = None, sample_collector_class: Optional[Type[SampleCollector]] = None,
render: bool = False, render: bool = False,
blackhole_outputs: bool = False,
# Obsolete. # Obsolete.
policies=None, policies=None,
policy_mapping_fn=None, policy_mapping_fn=None,
@ -292,45 +325,44 @@ class AsyncSampler(threading.Thread, SamplerInput):
obs_filters=None, obs_filters=None,
tf_sess=None, tf_sess=None,
): ):
"""Initializes a AsyncSampler object. """Initializes an AsyncSampler instance.
Args: Args:
worker (RolloutWorker): The RolloutWorker that will use this worker: The RolloutWorker that will use this Sampler for sampling.
Sampler for sampling. env: Any Env object. Will be converted into an RLlib BaseEnv.
env (Env): Any Env object. Will be converted into an RLlib BaseEnv. clip_rewards: True for +/-1.0 clipping,
clip_rewards (Union[bool, float]): True for +/-1.0 clipping,
actual float value for +/- value clipping. False for no actual float value for +/- value clipping. False for no
clipping. clipping.
rollout_fragment_length (int): The length of a fragment to collect rollout_fragment_length: The length of a fragment to collect
before building a SampleBatch from the data and resetting before building a SampleBatch from the data and resetting
the SampleBatchBuilder object. the SampleBatchBuilder object.
count_steps_by (str): Either "env_steps" or "agent_steps". count_steps_by: One of "env_steps" (default) or "agent_steps".
Refers to the unit of `rollout_fragment_length`. Use "agent_steps", if you want rollout lengths to be counted
callbacks (Callbacks): The Callbacks object to use when episode by individual agent steps. In a multi-agent env,
events happen during rollout. a single env_step contains one or more agent_steps, depending
on how many agents are present at any given time in the
ongoing episode.
horizon: Hard-reset the Env after this many timesteps. horizon: Hard-reset the Env after this many timesteps.
multiple_episodes_in_batch (bool): Whether to pack multiple multiple_episodes_in_batch: Whether to pack multiple
episodes into each batch. This guarantees batches will be episodes into each batch. This guarantees batches will be
exactly `rollout_fragment_length` in size. exactly `rollout_fragment_length` in size.
normalize_actions (bool): Whether to normalize actions to the normalize_actions: Whether to normalize actions to the
action space's bounds. action space's bounds.
clip_actions (bool): Whether to clip actions according to the clip_actions: Whether to clip actions according to the
given action_space's bounds. given action_space's bounds.
blackhole_outputs (bool): Whether to collect samples, but then blackhole_outputs: Whether to collect samples, but then
not further process or store them (throw away all samples). not further process or store them (throw away all samples).
soft_horizon (bool): If True, calculate bootstrapped values as if soft_horizon: If True, calculate bootstrapped values as if
episode had ended, but don't physically reset the environment episode had ended, but don't physically reset the environment
when the horizon is hit. when the horizon is hit.
no_done_at_end (bool): Ignore the done=True at the end of the no_done_at_end: Ignore the done=True at the end of the
episode and instead record done=False. episode and instead record done=False.
observation_fn (Optional[ObservationFunction]): Optional observation_fn: Optional multi-agent observation func to use for
multi-agent observation func to use for preprocessing preprocessing observations.
observations. sample_collector_class: An optional SampleCollector sub-class to
sample_collector_class (Optional[Type[SampleCollector]]): An use to collect, store, and retrieve environment-, model-,
optional Samplecollector sub-class to use to collect, store, and sampler data.
and retrieve environment-, model-, and sampler data. render: Whether to try to render the environment after each step.
render (bool): Whether to try to render the environment after each
step.
""" """
# All of the following arguments are deprecated. They will instead be # All of the following arguments are deprecated. They will instead be
# provided via the passed in `worker` arg, e.g. `worker.policy_map`. # provided via the passed in `worker` arg, e.g. `worker.policy_map`.
@ -467,32 +499,32 @@ def _env_runner(
"""This implements the common experience collection logic. """This implements the common experience collection logic.
Args: Args:
worker (RolloutWorker): Reference to the current rollout worker. worker: Reference to the current rollout worker.
base_env (BaseEnv): Env implementing BaseEnv. base_env: Env implementing BaseEnv.
extra_batch_callback (fn): function to send extra batch data to. extra_batch_callback: function to send extra batch data to.
horizon: Horizon of the episode. horizon: Horizon of the episode.
multiple_episodes_in_batch (bool): Whether to pack multiple multiple_episodes_in_batch: Whether to pack multiple
episodes into each batch. This guarantees batches will be exactly episodes into each batch. This guarantees batches will be exactly
`rollout_fragment_length` in size. `rollout_fragment_length` in size.
normalize_actions (bool): Whether to normalize actions to the action normalize_actions: Whether to normalize actions to the action
space's bounds. space's bounds.
clip_actions (bool): Whether to clip actions to the space range. clip_actions: Whether to clip actions to the space range.
callbacks (DefaultCallbacks): User callbacks to run on episode events. callbacks: User callbacks to run on episode events.
perf_stats (_PerfStats): Record perf stats into this object. perf_stats: Record perf stats into this object.
soft_horizon (bool): Calculate rewards but don't reset the soft_horizon: Calculate rewards but don't reset the
environment when the horizon is hit. environment when the horizon is hit.
no_done_at_end (bool): Ignore the done=True at the end of the episode no_done_at_end: Ignore the done=True at the end of the episode
and instead record done=False. and instead record done=False.
observation_fn (ObservationFunction): Optional multi-agent observation_fn: Optional multi-agent
observation func to use for preprocessing observations. observation func to use for preprocessing observations.
sample_collector (Optional[SampleCollector]): An optional sample_collector: An optional
SampleCollector object to use. SampleCollector object to use.
render (bool): Whether to try to render the environment after each render: Whether to try to render the environment after each
step. step.
Yields: Yields:
rollout (SampleBatch): Object containing state, action, reward, Object containing state, action, reward, terminal condition,
terminal condition, and other fields as dictated by `policy`. and other fields as dictated by `policy`.
""" """
# May be populated for image rendering # May be populated for image rendering
@ -546,7 +578,7 @@ def _env_runner(
return None return None
def new_episode(env_id): def new_episode(env_id):
episode = MultiAgentEpisode( episode = Episode(
worker.policy_map, worker.policy_map,
worker.policy_mapping_fn, worker.policy_mapping_fn,
get_batch_builder, get_batch_builder,
@ -576,7 +608,7 @@ def _env_runner(
) )
return episode return episode
active_episodes: Dict[EnvID, MultiAgentEpisode] = \ active_episodes: Dict[EnvID, Episode] = \
NewEpisodeDefaultDict(new_episode) NewEpisodeDefaultDict(new_episode)
while True: while True:
@ -684,7 +716,7 @@ def _process_observations(
*, *,
worker: "RolloutWorker", worker: "RolloutWorker",
base_env: BaseEnv, base_env: BaseEnv,
active_episodes: Dict[EnvID, MultiAgentEpisode], active_episodes: Dict[EnvID, Episode],
unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]], unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
rewards: Dict[EnvID, Dict[AgentID, float]], rewards: Dict[EnvID, Dict[AgentID, float]],
dones: Dict[EnvID, Dict[AgentID, bool]], dones: Dict[EnvID, Dict[AgentID, bool]],
@ -701,38 +733,37 @@ def _process_observations(
"""Record new data from the environment and prepare for policy evaluation. """Record new data from the environment and prepare for policy evaluation.
Args: Args:
worker (RolloutWorker): Reference to the current rollout worker. worker: Reference to the current rollout worker.
base_env (BaseEnv): Env implementing BaseEnv. base_env: Env implementing BaseEnv.
active_episodes (Dict[EnvID, MultiAgentEpisode]): Mapping from active_episodes: Mapping from
episode ID to currently ongoing MultiAgentEpisode object. episode ID to currently ongoing Episode object.
unfiltered_obs (dict): Doubly keyed dict of env-ids -> agent ids unfiltered_obs: Doubly keyed dict of env-ids -> agent ids
-> unfiltered observation tensor, returned by a `BaseEnv.poll()` -> unfiltered observation tensor, returned by a `BaseEnv.poll()`
call. call.
rewards (dict): Doubly keyed dict of env-ids -> agent ids -> rewards: Doubly keyed dict of env-ids -> agent ids ->
rewards tensor, returned by a `BaseEnv.poll()` call. rewards tensor, returned by a `BaseEnv.poll()` call.
dones (dict): Doubly keyed dict of env-ids -> agent ids -> dones: Doubly keyed dict of env-ids -> agent ids ->
boolean done flags, returned by a `BaseEnv.poll()` call. boolean done flags, returned by a `BaseEnv.poll()` call.
infos (dict): Doubly keyed dict of env-ids -> agent ids -> infos: Doubly keyed dict of env-ids -> agent ids ->
info dicts, returned by a `BaseEnv.poll()` call. info dicts, returned by a `BaseEnv.poll()` call.
horizon (int): Horizon of the episode. horizon: Horizon of the episode.
multiple_episodes_in_batch (bool): Whether to pack multiple multiple_episodes_in_batch: Whether to pack multiple
episodes into each batch. This guarantees batches will be exactly episodes into each batch. This guarantees batches will be exactly
`rollout_fragment_length` in size. `rollout_fragment_length` in size.
callbacks (DefaultCallbacks): User callbacks to run on episode events. callbacks: User callbacks to run on episode events.
soft_horizon (bool): Calculate rewards but don't reset the soft_horizon: Calculate rewards but don't reset the
environment when the horizon is hit. environment when the horizon is hit.
no_done_at_end (bool): Ignore the done=True at the end of the episode no_done_at_end: Ignore the done=True at the end of the episode
and instead record done=False. and instead record done=False.
observation_fn (ObservationFunction): Optional multi-agent observation_fn: Optional multi-agent
observation func to use for preprocessing observations. observation func to use for preprocessing observations.
sample_collector (SampleCollector): The SampleCollector object sample_collector: The SampleCollector object
used to store and retrieve environment samples. used to store and retrieve environment samples.
Returns: Returns:
Tuple: Tuple consisting of 1) active_envs: Set of non-terminated env ids.
- active_envs: Set of non-terminated env ids. 2) to_eval: Map of policy_id to list of agent PolicyEvalData.
- to_eval: Map of policy_id to list of agent PolicyEvalData. 3) outputs: List of metrics and samples to return from the sampler.
- outputs: List of metrics and samples to return from the sampler.
""" """
# Output objects. # Output objects.
@ -744,7 +775,7 @@ def _process_observations(
# types: EnvID, Dict[AgentID, EnvObsType] # types: EnvID, Dict[AgentID, EnvObsType]
for env_id, all_agents_obs in unfiltered_obs.items(): for env_id, all_agents_obs in unfiltered_obs.items():
is_new_episode: bool = env_id not in active_episodes is_new_episode: bool = env_id not in active_episodes
episode: MultiAgentEpisode = active_episodes[env_id] episode: Episode = active_episodes[env_id]
if not is_new_episode: if not is_new_episode:
sample_collector.episode_step(episode) sample_collector.episode_step(episode)
@ -854,9 +885,11 @@ def _process_observations(
# Next observation. # Next observation.
SampleBatch.NEXT_OBS: filtered_obs, SampleBatch.NEXT_OBS: filtered_obs,
} }
# Add extra-action-fetches to collectors. # Add extra-action-fetches (policy-inference infos) to
# collectors.
pol = worker.policy_map[policy_id] pol = worker.policy_map[policy_id]
for key, value in episode.last_pi_info_for(agent_id).items(): for key, value in episode.last_extra_action_outs_for(
agent_id).items():
if key in pol.view_requirements: if key in pol.view_requirements:
values_dict[key] = value values_dict[key] = value
# Env infos for this agent. # Env infos for this agent.
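# Hedged sketch of the accessor renamed above (formerly `last_pi_info_for`):
# the latest extra action fetches of an agent can also be read from an Episode
# inside a user callback. The "vf_preds" key is an assumption; it depends on
# which extra fetches the Policy actually returns.
from ray.rllib.agents.callbacks import DefaultCallbacks

class LogExtraActionOuts(DefaultCallbacks):
    def on_episode_end(self, *, worker, base_env, policies=None,
                       episode, env_index=None, **kwargs):
        for agent_id in episode.get_agents():
            outs = episode.last_extra_action_outs_for(agent_id)
            episode.custom_metrics["last_vf_pred_{}".format(agent_id)] = \
                float(outs.get("vf_preds", 0.0))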
@ -947,7 +980,7 @@ def _process_observations(
# Creates a new episode if this is not async return. # Creates a new episode if this is not async return.
# If reset is async, we will get its result in some future poll. # If reset is async, we will get its result in some future poll.
elif resetted_obs != ASYNC_RESET_RETURN: elif resetted_obs != ASYNC_RESET_RETURN:
new_episode: MultiAgentEpisode = active_episodes[env_id] new_episode: Episode = active_episodes[env_id]
if observation_fn: if observation_fn:
resetted_obs: Dict[AgentID, EnvObsType] = observation_fn( resetted_obs: Dict[AgentID, EnvObsType] = observation_fn(
agent_obs=resetted_obs, agent_obs=resetted_obs,
@ -995,21 +1028,20 @@ def _do_policy_eval(
to_eval: Dict[PolicyID, List[PolicyEvalData]], to_eval: Dict[PolicyID, List[PolicyEvalData]],
policies: PolicyMap, policies: PolicyMap,
sample_collector, sample_collector,
active_episodes: Dict[EnvID, MultiAgentEpisode], active_episodes: Dict[EnvID, Episode],
) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]: ) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
"""Call compute_actions on collected episode/model data to get next action. """Call compute_actions on collected episode/model data to get next action.
Args: Args:
to_eval (Dict[PolicyID, List[PolicyEvalData]]): Mapping of policy to_eval: Mapping of policy IDs to lists of PolicyEvalData objects
IDs to lists of PolicyEvalData objects (items in these lists will (items in these lists will be the batch's items for the model
be the batch's items for the model forward pass). forward pass).
policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy policies: Mapping from policy ID to Policy obj.
obj. sample_collector: The SampleCollector object to use.
sample_collector (SampleCollector): The SampleCollector object to use. active_episodes: Mapping of EnvID to its currently active episode.
active_episodes (Dict[EnvID, MultiAgentEpisode]): Mapping of
Returns: Returns:
eval_results: dict of policy to compute_action() outputs. Dict mapping PolicyIDs to compute_actions_from_input_dict() outputs.
""" """
eval_results: Dict[PolicyID, TensorStructType] = {} eval_results: Dict[PolicyID, TensorStructType] = {}
@ -1052,7 +1084,7 @@ def _process_policy_eval_results(
to_eval: Dict[PolicyID, List[PolicyEvalData]], to_eval: Dict[PolicyID, List[PolicyEvalData]],
eval_results: Dict[PolicyID, Tuple[TensorStructType, StateBatch, eval_results: Dict[PolicyID, Tuple[TensorStructType, StateBatch,
dict]], dict]],
active_episodes: Dict[EnvID, MultiAgentEpisode], active_episodes: Dict[EnvID, Episode],
active_envs: Set[int], active_envs: Set[int],
off_policy_actions: MultiEnvDict, off_policy_actions: MultiEnvDict,
policies: Dict[PolicyID, Policy], policies: Dict[PolicyID, Policy],
@ -1065,24 +1097,22 @@ def _process_policy_eval_results(
returns replies to send back to agents in the env. returns replies to send back to agents in the env.
Args: Args:
to_eval (Dict[PolicyID, List[PolicyEvalData]]): Mapping of policy IDs to_eval: Mapping of policy IDs to lists of PolicyEvalData objects.
to lists of PolicyEvalData objects. eval_results: Mapping of policy IDs to list of
eval_results (Dict[PolicyID, List]): Mapping of policy IDs to list of
actions, rnn-out states, extra-action-fetches dicts. actions, rnn-out states, extra-action-fetches dicts.
active_episodes (Dict[EnvID, MultiAgentEpisode]): Mapping from active_episodes: Mapping from episode ID to currently ongoing
episode ID to currently ongoing MultiAgentEpisode object. Episode object.
active_envs (Set[int]): Set of non-terminated env ids. active_envs: Set of non-terminated env ids.
off_policy_actions (dict): Doubly keyed dict of env-ids -> agent ids -> off_policy_actions: Doubly keyed dict of env-ids -> agent ids ->
off-policy-action, returned by a `BaseEnv.poll()` call. off-policy-action, returned by a `BaseEnv.poll()` call.
policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy. policies: Mapping from policy ID to Policy.
normalize_actions (bool): Whether to normalize actions to the action normalize_actions: Whether to normalize actions to the action
space's bounds. space's bounds.
clip_actions (bool): Whether to clip actions to the action space's clip_actions: Whether to clip actions to the action space's bounds.
bounds.
Returns: Returns:
actions_to_send: Nested dict of env id -> agent id -> actions to be Nested dict of env id -> agent id -> actions to be sent to
sent to Env (np.ndarrays). Env (np.ndarrays).
""" """
actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = \ actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = \
@ -1098,7 +1128,7 @@ def _process_policy_eval_results(
actions = convert_to_numpy(actions) actions = convert_to_numpy(actions)
rnn_out_cols: StateBatch = eval_results[policy_id][1] rnn_out_cols: StateBatch = eval_results[policy_id][1]
pi_info_cols: dict = eval_results[policy_id][2] extra_action_out_cols: dict = eval_results[policy_id][2]
# In case actions is a list (representing the 0th dim of a batch of # In case actions is a list (representing the 0th dim of a batch of
# primitive actions), try converting it first. # primitive actions), try converting it first.
@ -1107,7 +1137,7 @@ def _process_policy_eval_results(
# Store RNN state ins/outs and extra-action fetches to episode. # Store RNN state ins/outs and extra-action fetches to episode.
for f_i, column in enumerate(rnn_out_cols): for f_i, column in enumerate(rnn_out_cols):
pi_info_cols["state_out_{}".format(f_i)] = column extra_action_out_cols["state_out_{}".format(f_i)] = column
policy: Policy = _get_or_raise(policies, policy_id) policy: Policy = _get_or_raise(policies, policy_id)
# Split action-component batches into single action rows. # Split action-component batches into single action rows.
@ -1127,11 +1157,11 @@ def _process_policy_eval_results(
env_id: int = eval_data[i].env_id env_id: int = eval_data[i].env_id
agent_id: AgentID = eval_data[i].agent_id agent_id: AgentID = eval_data[i].agent_id
episode: MultiAgentEpisode = active_episodes[env_id] episode: Episode = active_episodes[env_id]
episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols]) episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
episode._set_last_pi_info( episode._set_last_extra_action_outs(
agent_id, {k: v[i] agent_id, {k: v[i]
for k, v in pi_info_cols.items()}) for k, v in extra_action_out_cols.items()})
if env_id in off_policy_actions and \ if env_id in off_policy_actions and \
agent_id in off_policy_actions[env_id]: agent_id in off_policy_actions[env_id]:
episode._set_last_action(agent_id, episode._set_last_action(agent_id,


@ -81,7 +81,7 @@ class EchoPolicy(Policy):
return obs_batch.argmax(axis=1), [], {} return obs_batch.argmax(axis=1), [], {}
class MultiAgentEpisodeEnv(MultiAgentEnv): class EpisodeEnv(MultiAgentEnv):
def __init__(self, episode_length, num): def __init__(self, episode_length, num):
self.agents = [MockEnv3(episode_length) for _ in range(num)] self.agents = [MockEnv3(episode_length) for _ in range(num)]
self.dones = set() self.dones = set()
@ -123,9 +123,9 @@ class TestEpisodeLastValues(unittest.TestCase):
ev.sample() ev.sample()
def test_multiagent_env(self): def test_multiagent_env(self):
temp_env = MultiAgentEpisodeEnv(NUM_STEPS, NUM_AGENTS) temp_env = EpisodeEnv(NUM_STEPS, NUM_AGENTS)
ev = RolloutWorker( ev = RolloutWorker(
env_creator=lambda _: MultiAgentEpisodeEnv(NUM_STEPS, NUM_AGENTS), env_creator=lambda _: EpisodeEnv(NUM_STEPS, NUM_AGENTS),
policy_spec={ policy_spec={
str(agent_id): (EchoPolicy, temp_env.observation_space, str(agent_id): (EchoPolicy, temp_env.observation_space,
temp_env.action_space, {}) temp_env.action_space, {})


@ -29,9 +29,9 @@ T = TypeVar("T")
@DeveloperAPI @DeveloperAPI
class WorkerSet: class WorkerSet:
"""Represents a set of RolloutWorkers. """Set of RolloutWorkers with n @ray.remote workers and one local worker.
There must be one local worker copy, and zero or more remote workers. Where n may be 0.
""" """
def __init__(self, def __init__(self,
@ -43,20 +43,20 @@ class WorkerSet:
num_workers: int = 0, num_workers: int = 0,
logdir: Optional[str] = None, logdir: Optional[str] = None,
_setup: bool = True): _setup: bool = True):
"""Create a new WorkerSet and initialize its workers. """Initializes a WorkerSet instance.
Args: Args:
env_creator (Optional[Callable[[EnvContext], EnvType]]): Function env_creator: Function that returns env given env config.
that returns env given env config. validate_env: Optional callable to validate the generated
validate_env (Optional[Callable[[EnvType], None]]): Optional environment (only on worker=0).
callable to validate the generated environment (only on policy_class: An optional Policy class. If None, PolicySpecs can be
worker=0). generated automatically by using the Trainer's default class
policy (Optional[Type[Policy]]): A rllib.policy.Policy class. or via a given multi-agent policy config dict.
trainer_config (Optional[TrainerConfigDict]): Optional dict that trainer_config: Optional dict that extends the common config of
extends the common config of the Trainer class. the Trainer class.
num_workers (int): Number of remote rollout workers to create. num_workers: Number of remote rollout workers to create.
logdir (Optional[str]): Optional logging directory for workers. logdir: Optional logging directory for workers.
_setup (bool): Whether to setup workers. This is only for testing. _setup: Whether to setup workers. This is only for testing.
""" """
if not trainer_config: if not trainer_config:
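# A minimal construction sketch under assumptions (CartPole env, the example
# RandomPolicy, zero remote workers); normally a Trainer builds and owns its
# WorkerSet internally.
import gym
import ray
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.examples.policy.random_policy import RandomPolicy

ray.init(ignore_reinit_error=True)
workers = WorkerSet(
    env_creator=lambda ctx: gym.make("CartPole-v0"),
    policy_class=RandomPolicy,
    num_workers=0)
print(workers.local_worker().sample().count)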
@ -119,15 +119,20 @@ class WorkerSet:
) )
def local_worker(self) -> RolloutWorker: def local_worker(self) -> RolloutWorker:
"""Return the local rollout worker.""" """Returns the local rollout worker."""
return self._local_worker return self._local_worker
def remote_workers(self) -> List[ActorHandle]: def remote_workers(self) -> List[ActorHandle]:
"""Return a list of remote rollout workers.""" """Returns a list of remote rollout workers."""
return self._remote_workers return self._remote_workers
def sync_weights(self, policies: Optional[List[PolicyID]] = None) -> None: def sync_weights(self, policies: Optional[List[PolicyID]] = None) -> None:
"""Syncs weights from the local worker to all remote workers.""" """Syncs model weights from the local worker to all remote workers.
Args:
policies: An optional list of policy IDs to sync for. If None,
sync all policies.
"""
if self.remote_workers(): if self.remote_workers():
weights = ray.put(self.local_worker().get_weights(policies)) weights = ray.put(self.local_worker().get_weights(policies))
for e in self.remote_workers(): for e in self.remote_workers():
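# Hedged usage sketch, assuming `workers` is a WorkerSet (e.g. the
# `trainer.workers` attribute of a built Trainer): after the local worker's
# policies were updated, push the new weights out to all remote workers.
workers.sync_weights()                              # sync all policies ...
workers.sync_weights(policies=["default_policy"])   # ... or only selected IDs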
@ -136,8 +141,11 @@ class WorkerSet:
def add_workers(self, num_workers: int) -> None: def add_workers(self, num_workers: int) -> None:
"""Creates and adds a number of remote workers to this worker set. """Creates and adds a number of remote workers to this worker set.
Can be called several times on the same WorkerSet to add more
RolloutWorkers to the set.
Args: Args:
num_workers (int): The number of remote Workers to add to this num_workers: The number of remote Workers to add to this
WorkerSet. WorkerSet.
""" """
remote_args = { remote_args = {
@ -159,25 +167,36 @@ class WorkerSet:
]) ])
def reset(self, new_remote_workers: List[ActorHandle]) -> None: def reset(self, new_remote_workers: List[ActorHandle]) -> None:
"""Called to change the set of remote workers.""" """Hard overrides the remote workers in this set with the given one.
Args:
new_remote_workers: A list of new RolloutWorkers
(as `ActorHandles`) to use as remote workers.
"""
self._remote_workers = new_remote_workers self._remote_workers = new_remote_workers
def stop(self) -> None: def stop(self) -> None:
"""Stop all rollout workers.""" """Calls `stop` on all rollout workers (including the local one)."""
try: try:
self.local_worker().stop() self.local_worker().stop()
tids = [w.stop.remote() for w in self.remote_workers()] tids = [w.stop.remote() for w in self.remote_workers()]
ray.get(tids) ray.get(tids)
except Exception: except Exception:
logger.exception("Failed to stop workers") logger.exception("Failed to stop workers!")
finally: finally:
for w in self.remote_workers(): for w in self.remote_workers():
w.__ray_terminate__.remote() w.__ray_terminate__.remote()
@DeveloperAPI @DeveloperAPI
def foreach_worker(self, func: Callable[[RolloutWorker], T]) -> List[T]: def foreach_worker(self, func: Callable[[RolloutWorker], T]) -> List[T]:
"""Apply the given function to each worker instance.""" """Calls the given function with each worker instance as arg.
Args:
func: The function to call for each worker (as only arg).
Returns:
The list of return values of all calls to `func([worker])`.
"""
local_result = [func(self.local_worker())] local_result = [func(self.local_worker())]
remote_results = ray.get( remote_results = ray.get(
[w.apply.remote(func) for w in self.remote_workers()]) [w.apply.remote(func) for w in self.remote_workers()])
@ -186,11 +205,23 @@ class WorkerSet:
@DeveloperAPI @DeveloperAPI
def foreach_worker_with_index( def foreach_worker_with_index(
self, func: Callable[[RolloutWorker, int], T]) -> List[T]: self, func: Callable[[RolloutWorker, int], T]) -> List[T]:
"""Apply the given function to each worker instance. """Calls `func` with each worker instance and worker idx as args.
The index will be passed as the second arg to the given function. The index will be passed as the second arg to the given function.
Args:
func: The function to call for each worker and its index
(as args). The local worker has index 0, all remote workers
have indices > 0.
Returns:
The list of return values of all calls to `func([worker, idx])`.
The first entry in this list is the result of the local
worker, followed by all remote workers' results.
""" """
# Local worker: Index=0.
local_result = [func(self.local_worker(), 0)] local_result = [func(self.local_worker(), 0)]
# Remote workers: Index > 0.
remote_results = ray.get([ remote_results = ray.get([
w.apply.remote(func, i + 1) w.apply.remote(func, i + 1)
for i, w in enumerate(self.remote_workers()) for i, w in enumerate(self.remote_workers())
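# Hedged sketch, assuming `workers` is a WorkerSet: call a function on every
# worker (local first, then remotes); with `foreach_worker_with_index` the
# local worker receives index 0 and remote workers indices 1, 2, ...
sample_counts = workers.foreach_worker(lambda w: w.sample().count)
worker_indices = workers.foreach_worker_with_index(lambda w, i: i)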
@ -199,15 +230,22 @@ class WorkerSet:
@DeveloperAPI @DeveloperAPI
def foreach_policy(self, func: Callable[[Policy, PolicyID], T]) -> List[T]: def foreach_policy(self, func: Callable[[Policy, PolicyID], T]) -> List[T]:
"""Apply the given function to each worker's (policy, policy_id) tuple. """Calls `func` with each worker's (policy, PolicyID) tuple.
Note that in the multi-agent case, each worker may have more than one
policy.
Args: Args:
func (callable): A function - taking a Policy and its ID - that is func: A function - taking a Policy and its ID - that is
called on all workers' Policies. called on all workers' Policies.
Returns: Returns:
List[any]: The list of return values of func over all workers' The list of return values of func over all workers' policies. The
policies. length of this list is:
(num_workers + 1 (local-worker)) *
[num policies in the multi-agent config dict].
The local worker's results are first, followed by all remote
workers' results.
""" """
results = self.local_worker().foreach_policy(func) results = self.local_worker().foreach_policy(func)
ray_gets = [] ray_gets = []
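# Hedged sketch, assuming `workers` is a WorkerSet: `func` receives a
# (Policy, PolicyID) pair; e.g. collect every policy ID across the local and
# all remote workers.
policy_ids = workers.foreach_policy(lambda policy, pid: pid)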
@ -221,8 +259,8 @@ class WorkerSet:
@DeveloperAPI @DeveloperAPI
def trainable_policies(self) -> List[PolicyID]: def trainable_policies(self) -> List[PolicyID]:
"""Return the list of trainable policy ids.""" """Returns the list of trainable policy ids."""
return self.local_worker().foreach_trainable_policy(lambda _, pid: pid) return self.local_worker().policies_to_train
@DeveloperAPI @DeveloperAPI
def foreach_trainable_policy( def foreach_trainable_policy(
@ -230,7 +268,7 @@ class WorkerSet:
"""Apply `func` to all workers' Policies iff in `policies_to_train`. """Apply `func` to all workers' Policies iff in `policies_to_train`.
Args: Args:
func (callable): A function - taking a Policy and its ID - that is func: A function - taking a Policy and its ID - that is
called on all workers' Policies in `worker.policies_to_train`. called on all workers' Policies in `worker.policies_to_train`.
Returns: Returns:
@ -250,7 +288,7 @@ class WorkerSet:
@DeveloperAPI @DeveloperAPI
def foreach_env(self, func: Callable[[EnvType], List[T]]) -> List[List[T]]: def foreach_env(self, func: Callable[[EnvType], List[T]]) -> List[List[T]]:
"""Apply `func` to all workers' underlying sub environments. """Calls `func` with all workers' sub environments as args.
An "underlying sub environment" is a single clone of an env within An "underlying sub environment" is a single clone of an env within
a vectorized environment. a vectorized environment.
@ -258,13 +296,12 @@ class WorkerSet:
gym.Env object. gym.Env object.
Args: Args:
func (Callable[[EnvType], T]): A function - taking an EnvType func: A function - taking an EnvType (normally a gym.Env object)
(normally a gym.Env object) as arg and returning a list of as arg and returning a list of lists of return values, one
return values over sub environments for each worker. value per underlying sub-environment per worker.
Returns: Returns:
List[List[T]]: The list (workers) of lists (sub environments) of The list (workers) of lists (sub environments) of results.
results.
""" """
local_results = [self.local_worker().foreach_env(func)] local_results = [self.local_worker().foreach_env(func)]
ray_gets = [] ray_gets = []
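# Hedged sketch, assuming `workers` is a WorkerSet: run a function on each
# sub-environment of every worker; here we collect the sub-envs' class names.
# The result is a list (one per worker) of lists (one per sub-environment).
env_class_names = workers.foreach_env(lambda env: type(env).__name__)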
@ -276,7 +313,7 @@ class WorkerSet:
def foreach_env_with_context( def foreach_env_with_context(
self, self,
func: Callable[[BaseEnv, EnvContext], List[T]]) -> List[List[T]]: func: Callable[[BaseEnv, EnvContext], List[T]]) -> List[List[T]]:
"""Apply `func` to all workers' underlying sub environments. """Call `func` with all workers' sub-environments and env_ctx as args.
An "underlying sub environment" is a single clone of an env within An "underlying sub environment" is a single clone of an env within
a vectorized environment. a vectorized environment.
@ -284,13 +321,13 @@ class WorkerSet:
as args. as args.
Args: Args:
func (Callable[[BaseEnv], T]): A function - taking a BaseEnv func: A function - taking a BaseEnv object and an EnvContext as
object as arg and returning a list of return values over envs arg - and returning a list of lists of return values over envs
of the worker. of the worker.
Returns: Returns:
List[List[T]]: The list (workers) of lists (environments) of The list (1 item per worker) of lists (1 item per sub-environment)
results. of results.
""" """
local_results = [self.local_worker().foreach_env_with_context(func)] local_results = [self.local_worker().foreach_env_with_context(func)]
ray_gets = [] ray_gets = []


@ -13,7 +13,7 @@ import ray
from ray import tune from ray import tune
from ray.rllib.agents.callbacks import DefaultCallbacks from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env import BaseEnv from ray.rllib.env import BaseEnv
from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker from ray.rllib.evaluation import Episode, RolloutWorker
from ray.rllib.policy import Policy from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.sample_batch import SampleBatch
@ -28,8 +28,8 @@ parser.add_argument("--stop-iters", type=int, default=2000)
class MyCallbacks(DefaultCallbacks): class MyCallbacks(DefaultCallbacks):
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv,
policies: Dict[str, Policy], policies: Dict[str, Policy], episode: Episode,
episode: MultiAgentEpisode, env_index: int, **kwargs): env_index: int, **kwargs):
# Make sure this episode has just been started (only initial obs # Make sure this episode has just been started (only initial obs
# logged so far). # logged so far).
assert episode.length == 0, \ assert episode.length == 0, \
@ -41,8 +41,8 @@ class MyCallbacks(DefaultCallbacks):
episode.hist_data["pole_angles"] = [] episode.hist_data["pole_angles"] = []
def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv,
policies: Dict[str, Policy], policies: Dict[str, Policy], episode: Episode,
episode: MultiAgentEpisode, env_index: int, **kwargs): env_index: int, **kwargs):
# Make sure this episode is ongoing. # Make sure this episode is ongoing.
assert episode.length > 0, \ assert episode.length > 0, \
"ERROR: `on_episode_step()` callback should not be called right " \ "ERROR: `on_episode_step()` callback should not be called right " \
@ -53,7 +53,7 @@ class MyCallbacks(DefaultCallbacks):
episode.user_data["pole_angles"].append(pole_angle) episode.user_data["pole_angles"].append(pole_angle)
def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv,
policies: Dict[str, Policy], episode: MultiAgentEpisode, policies: Dict[str, Policy], episode: Episode,
env_index: int, **kwargs): env_index: int, **kwargs):
# Make sure this episode is really done. # Make sure this episode is really done.
assert episode.batch_builder.policy_collectors[ assert episode.batch_builder.policy_collectors[
@ -84,8 +84,8 @@ class MyCallbacks(DefaultCallbacks):
policy, result["sum_actions_in_train_batch"])) policy, result["sum_actions_in_train_batch"]))
def on_postprocess_trajectory( def on_postprocess_trajectory(
self, *, worker: RolloutWorker, episode: MultiAgentEpisode, self, *, worker: RolloutWorker, episode: Episode, agent_id: str,
agent_id: str, policy_id: str, policies: Dict[str, Policy], policy_id: str, policies: Dict[str, Policy],
postprocessed_batch: SampleBatch, postprocessed_batch: SampleBatch,
original_batches: Dict[str, SampleBatch], **kwargs): original_batches: Dict[str, SampleBatch], **kwargs):
print("postprocessed {} steps".format(postprocessed_batch.count)) print("postprocessed {} steps".format(postprocessed_batch.count))


@ -23,7 +23,7 @@ tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch() torch, _ = try_import_torch()
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.rllib.evaluation import MultiAgentEpisode from ray.rllib.evaluation import Episode
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -142,7 +142,7 @@ class Policy(metaclass=ABCMeta):
prev_reward: Optional[TensorStructType] = None, prev_reward: Optional[TensorStructType] = None,
info: dict = None, info: dict = None,
input_dict: Optional[SampleBatch] = None, input_dict: Optional[SampleBatch] = None,
episode: Optional["MultiAgentEpisode"] = None, episode: Optional["Episode"] = None,
explore: Optional[bool] = None, explore: Optional[bool] = None,
timestep: Optional[int] = None, timestep: Optional[int] = None,
# Kwargs placeholder for future compatibility. # Kwargs placeholder for future compatibility.
@ -238,7 +238,7 @@ class Policy(metaclass=ABCMeta):
input_dict: Union[SampleBatch, Dict[str, TensorStructType]], input_dict: Union[SampleBatch, Dict[str, TensorStructType]],
explore: bool = None, explore: bool = None,
timestep: Optional[int] = None, timestep: Optional[int] = None,
episodes: Optional[List["MultiAgentEpisode"]] = None, episodes: Optional[List["Episode"]] = None,
**kwargs) -> \ **kwargs) -> \
Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
"""Computes actions from collected samples (across multiple-agents). """Computes actions from collected samples (across multiple-agents).
@ -300,7 +300,7 @@ class Policy(metaclass=ABCMeta):
prev_reward_batch: Union[List[TensorStructType], prev_reward_batch: Union[List[TensorStructType],
TensorStructType] = None, TensorStructType] = None,
info_batch: Optional[Dict[str, list]] = None, info_batch: Optional[Dict[str, list]] = None,
episodes: Optional[List["MultiAgentEpisode"]] = None, episodes: Optional[List["Episode"]] = None,
explore: Optional[bool] = None, explore: Optional[bool] = None,
timestep: Optional[int] = None, timestep: Optional[int] = None,
**kwargs) -> \ **kwargs) -> \
@ -313,7 +313,7 @@ class Policy(metaclass=ABCMeta):
prev_action_batch: Batch of previous action values. prev_action_batch: Batch of previous action values.
prev_reward_batch: Batch of previous rewards. prev_reward_batch: Batch of previous rewards.
info_batch: Batch of info objects. info_batch: Batch of info objects.
episodes: List of MultiAgentEpisodes, one for each obs in episodes: List of Episode objects, one for each obs in
obs_batch. This provides access to all of the internal obs_batch. This provides access to all of the internal
episode state, which may be useful for model-based or episode state, which may be useful for model-based or
multi-agent algorithms. multi-agent algorithms.
@ -377,7 +377,7 @@ class Policy(metaclass=ABCMeta):
sample_batch: SampleBatch, sample_batch: SampleBatch,
other_agent_batches: Optional[Dict[AgentID, Tuple[ other_agent_batches: Optional[Dict[AgentID, Tuple[
"Policy", SampleBatch]]] = None, "Policy", SampleBatch]]] = None,
episode: Optional["MultiAgentEpisode"] = None) -> SampleBatch: episode: Optional["Episode"] = None) -> SampleBatch:
"""Implements algorithm-specific trajectory postprocessing. """Implements algorithm-specific trajectory postprocessing.
This will be called on each trajectory fragment computed during policy This will be called on each trajectory fragment computed during policy


@ -19,7 +19,7 @@ from ray.rllib.utils.typing import ModelGradients, TensorType, \
TrainerConfigDict TrainerConfigDict
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.rllib.evaluation import MultiAgentEpisode # noqa from ray.rllib.evaluation.episode import Episode # noqa
jax, _ = try_import_jax() jax, _ = try_import_jax()
torch, _ = try_import_torch() torch, _ = try_import_torch()
@ -39,7 +39,7 @@ def build_policy_class(
str, TensorType]]] = None, str, TensorType]]] = None,
postprocess_fn: Optional[Callable[[ postprocess_fn: Optional[Callable[[
Policy, SampleBatch, Optional[Dict[Any, SampleBatch]], Optional[ Policy, SampleBatch, Optional[Dict[Any, SampleBatch]], Optional[
"MultiAgentEpisode"] "Episode"]
], SampleBatch]] = None, ], SampleBatch]] = None,
extra_action_out_fn: Optional[Callable[[ extra_action_out_fn: Optional[Callable[[
Policy, Dict[str, TensorType], List[TensorType], ModelV2, Policy, Dict[str, TensorType], List[TensorType], ModelV2,
@ -98,7 +98,7 @@ def build_policy_class(
overrides. If None, uses only(!) the user-provided overrides. If None, uses only(!) the user-provided
PartialTrainerConfigDict as dict for this Policy. PartialTrainerConfigDict as dict for this Policy.
postprocess_fn (Optional[Callable[[Policy, SampleBatch, postprocess_fn (Optional[Callable[[Policy, SampleBatch,
Optional[Dict[Any, SampleBatch]], Optional["MultiAgentEpisode"]], Optional[Dict[Any, SampleBatch]], Optional["Episode"]],
SampleBatch]]): Optional callable for post-processing experience SampleBatch]]): Optional callable for post-processing experience
batches (called after the super's `postprocess_trajectory` method). batches (called after the super's `postprocess_trajectory` method).
stats_fn (Optional[Callable[[Policy, SampleBatch], stats_fn (Optional[Callable[[Policy, SampleBatch],


@ -28,7 +28,7 @@ from ray.rllib.utils.typing import LocalOptimizer, ModelGradients, \
TensorType, TrainerConfigDict TensorType, TrainerConfigDict
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.rllib.evaluation import MultiAgentEpisode from ray.rllib.evaluation import Episode
tf1, tf, tfv = try_import_tf() tf1, tf, tfv = try_import_tf()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -277,7 +277,7 @@ class TFPolicy(Policy):
input_dict: Union[SampleBatch, Dict[str, TensorType]], input_dict: Union[SampleBatch, Dict[str, TensorType]],
explore: bool = None, explore: bool = None,
timestep: Optional[int] = None, timestep: Optional[int] = None,
episodes: Optional[List["MultiAgentEpisode"]] = None, episodes: Optional[List["Episode"]] = None,
**kwargs) -> \ **kwargs) -> \
Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
@ -308,7 +308,7 @@ class TFPolicy(Policy):
prev_action_batch: Union[List[TensorType], TensorType] = None, prev_action_batch: Union[List[TensorType], TensorType] = None,
prev_reward_batch: Union[List[TensorType], TensorType] = None, prev_reward_batch: Union[List[TensorType], TensorType] = None,
info_batch: Optional[Dict[str, list]] = None, info_batch: Optional[Dict[str, list]] = None,
episodes: Optional[List["MultiAgentEpisode"]] = None, episodes: Optional[List["Episode"]] = None,
explore: Optional[bool] = None, explore: Optional[bool] = None,
timestep: Optional[int] = None, timestep: Optional[int] = None,
**kwargs): **kwargs):


@ -18,7 +18,7 @@ from ray.rllib.utils.typing import AgentID, ModelGradients, TensorType, \
TrainerConfigDict TrainerConfigDict
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.rllib.evaluation import MultiAgentEpisode from ray.rllib.evaluation import Episode
tf1, tf, tfv = try_import_tf() tf1, tf, tfv = try_import_tf()
@ -34,7 +34,7 @@ def build_tf_policy(
TrainerConfigDict]] = None, TrainerConfigDict]] = None,
postprocess_fn: Optional[Callable[[ postprocess_fn: Optional[Callable[[
Policy, SampleBatch, Optional[Dict[AgentID, SampleBatch]], Policy, SampleBatch, Optional[Dict[AgentID, SampleBatch]],
Optional["MultiAgentEpisode"] Optional["Episode"]
], SampleBatch]] = None, ], SampleBatch]] = None,
stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[ stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[
str, TensorType]]] = None, str, TensorType]]] = None,
@ -107,7 +107,7 @@ def build_tf_policy(
overrides. If None, uses only(!) the user-provided overrides. If None, uses only(!) the user-provided
PartialTrainerConfigDict as dict for this Policy. PartialTrainerConfigDict as dict for this Policy.
postprocess_fn (Optional[Callable[[Policy, SampleBatch, postprocess_fn (Optional[Callable[[Policy, SampleBatch,
Optional[Dict[AgentID, SampleBatch]], MultiAgentEpisode], None]]): Optional[Dict[AgentID, SampleBatch]], Episode], None]]):
Optional callable for post-processing experience batches (called Optional callable for post-processing experience batches (called
after the parent class' `postprocess_trajectory` method). after the parent class' `postprocess_trajectory` method).
stats_fn (Optional[Callable[[Policy, SampleBatch], stats_fn (Optional[Callable[[Policy, SampleBatch],


@ -31,7 +31,7 @@ from ray.rllib.utils.typing import ModelGradients, ModelWeights, TensorType, \
TensorStructType, TrainerConfigDict TensorStructType, TrainerConfigDict
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.rllib.evaluation import MultiAgentEpisode # noqa from ray.rllib.evaluation import Episode # noqa
torch, nn = try_import_torch() torch, nn = try_import_torch()
@ -279,7 +279,7 @@ class TorchPolicy(Policy):
prev_reward_batch: Union[List[TensorStructType], prev_reward_batch: Union[List[TensorStructType],
TensorStructType] = None, TensorStructType] = None,
info_batch: Optional[Dict[str, list]] = None, info_batch: Optional[Dict[str, list]] = None,
episodes: Optional[List["MultiAgentEpisode"]] = None, episodes: Optional[List["Episode"]] = None,
explore: Optional[bool] = None, explore: Optional[bool] = None,
timestep: Optional[int] = None, timestep: Optional[int] = None,
**kwargs) -> \ **kwargs) -> \


@ -7,7 +7,7 @@ import ray
from ray.tune.registry import register_env from ray.tune.registry import register_env
from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.agents.pg import PGTrainer from ray.rllib.agents.pg import PGTrainer
from ray.rllib.evaluation.episode import MultiAgentEpisode from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.rollout_worker import get_global_worker from ray.rllib.evaluation.rollout_worker import get_global_worker
from ray.rllib.examples.policy.random_policy import RandomPolicy from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \ from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \
@ -336,9 +336,9 @@ class TestMultiAgentEnv(unittest.TestCase):
# Pretend we did a model-based rollout and want to return # Pretend we did a model-based rollout and want to return
# the extra trajectory. # the extra trajectory.
env_id = episodes[0].env_id env_id = episodes[0].env_id
fake_eps = MultiAgentEpisode( fake_eps = Episode(episodes[0].policy_map,
episodes[0].policy_map, episodes[0].policy_mapping_fn, episodes[0].policy_mapping_fn,
lambda: None, lambda x: None, env_id) lambda: None, lambda x: None, env_id)
builder = get_global_worker().sampler.sample_collector builder = get_global_worker().sampler.sample_collector
agent_id = "extra_0" agent_id = "extra_0"
policy_id = "p1" # use p1 so we can easily check it policy_id = "p1" # use p1 so we can easily check it