[RLlib; Docs overhaul] Docstring cleanup: Evaluation (#19783)

Sven Mika 2021-10-29 12:03:56 +02:00 committed by GitHub
parent f2773267c7
commit 9c73871da0
30 changed files with 1059 additions and 705 deletions

View file

@ -5,20 +5,20 @@ import gym
import ray
from ray.rllib.agents.ppo.ppo_tf_policy import ValueNetworkMixin
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
Postprocessing
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.tf_policy import LearningRateSchedule, \
EntropyCoeffSchedule
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils.annotations import Deprecated
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_ops import explained_variance
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \
PolicyID, LocalOptimizer, ModelGradients
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.evaluation import MultiAgentEpisode
tf1, tf, tfv = try_import_tf()
@ -31,7 +31,7 @@ def postprocess_advantages(
policy: Policy,
sample_batch: SampleBatch,
other_agent_batches: Optional[Dict[PolicyID, SampleBatch]] = None,
episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
episode: Optional[Episode] = None) -> SampleBatch:
return compute_gae_for_sample_batch(policy, sample_batch,
other_agent_batches, episode)

View file

@ -3,7 +3,7 @@ from typing import Optional, Dict
import ray
from ray.rllib.agents.ppo.ppo_torch_policy import ValueNetworkMixin
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
Postprocessing
from ray.rllib.models.action_dist import ActionDistribution
@ -30,7 +30,7 @@ def add_advantages(
policy: Policy,
sample_batch: SampleBatch,
other_agent_batches: Optional[Dict[PolicyID, SampleBatch]] = None,
episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
episode: Optional[Episode] = None) -> SampleBatch:
return compute_gae_for_sample_batch(policy, sample_batch,
other_agent_batches, episode)

View file

@ -5,7 +5,7 @@ from typing import Dict, Optional, TYPE_CHECKING
from ray.rllib.env import BaseEnv
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation import MultiAgentEpisode
from ray.rllib.evaluation.episode import Episode
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.typing import AgentID, PolicyID
@ -35,28 +35,23 @@ class DefaultCallbacks:
"a class extending rllib.agents.callbacks.DefaultCallbacks")
self.legacy_callbacks = legacy_callbacks_dict or {}
def on_episode_start(self,
*,
worker: "RolloutWorker",
base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
episode: MultiAgentEpisode,
env_index: Optional[int] = None,
def on_episode_start(self, *, worker: "RolloutWorker", base_env: BaseEnv,
policies: Dict[PolicyID, Policy], episode: Episode,
**kwargs) -> None:
"""Callback run on the rollout worker before each episode starts.
Args:
worker: Reference to the current rollout worker.
base_env: BaseEnv running the episode. The underlying
sub environment objects can be received by calling
sub environment objects can be retrieved by calling
`base_env.get_sub_environments()`.
policies: Mapping of policy id to policy objects. In single
agent mode there will only be a single "default" policy.
episode (MultiAgentEpisode): Episode object which contains episode
episode: Episode object which contains the episode's
state. You can use the `episode.user_data` dict to store
temporary data, and `episode.custom_metrics` to store custom
metrics for the episode.
env_index (EnvID): Obsoleted: The ID of the environment, which the
env_index: Obsolete: The ID of the environment that the
episode belongs to.
kwargs: Forward compatibility placeholder.
"""
@ -73,20 +68,19 @@ class DefaultCallbacks:
worker: "RolloutWorker",
base_env: BaseEnv,
policies: Optional[Dict[PolicyID, Policy]] = None,
episode: MultiAgentEpisode,
env_index: Optional[int] = None,
episode: Episode,
**kwargs) -> None:
"""Runs on each episode step.
Args:
worker (RolloutWorker): Reference to the current rollout worker.
base_env (BaseEnv): BaseEnv running the episode. The underlying
sub environment objects can be gotten by calling
sub environment objects can be retrieved by calling
`base_env.get_sub_environments()`.
policies (Optional[Dict[PolicyID, Policy]]): Mapping of policy id
to policy objects. In single agent mode there will only be a
single "default_policy".
episode (MultiAgentEpisode): Episode object which contains episode
episode (Episode): Episode object which contains episode
state. You can use the `episode.user_data` dict to store
temporary data, and `episode.custom_metrics` to store custom
metrics for the episode.
@ -101,13 +95,8 @@ class DefaultCallbacks:
"episode": episode
})
def on_episode_end(self,
*,
worker: "RolloutWorker",
base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
episode: MultiAgentEpisode,
env_index: Optional[int] = None,
def on_episode_end(self, *, worker: "RolloutWorker", base_env: BaseEnv,
policies: Dict[PolicyID, Policy], episode: Episode,
**kwargs) -> None:
"""Runs when an episode is done.
@ -119,7 +108,7 @@ class DefaultCallbacks:
policies (Dict[PolicyID, Policy]): Mapping of policy id to policy
objects. In single agent mode there will only be a single
"default_policy".
episode (MultiAgentEpisode): Episode object which contains episode
episode (Episode): Episode object which contains episode
state. You can use the `episode.user_data` dict to store
temporary data, and `episode.custom_metrics` to store custom
metrics for the episode.
@ -136,7 +125,7 @@ class DefaultCallbacks:
})
def on_postprocess_trajectory(
self, *, worker: "RolloutWorker", episode: MultiAgentEpisode,
self, *, worker: "RolloutWorker", episode: Episode,
agent_id: AgentID, policy_id: PolicyID,
policies: Dict[PolicyID, Policy], postprocessed_batch: SampleBatch,
original_batches: Dict[AgentID, SampleBatch], **kwargs) -> None:
@ -148,7 +137,7 @@ class DefaultCallbacks:
Args:
worker (RolloutWorker): Reference to the current rollout worker.
episode (MultiAgentEpisode): Episode object.
episode (Episode): Episode object.
agent_id (str): Id of the current agent.
policy_id (str): Id of the current policy for the agent.
policies (dict): Mapping of policy id to policy objects. In single
@ -253,7 +242,7 @@ class MemoryTrackingCallbacks(DefaultCallbacks):
worker: "RolloutWorker",
base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
episode: MultiAgentEpisode,
episode: Episode,
env_index: Optional[int] = None,
**kwargs) -> None:
snapshot = tracemalloc.take_snapshot()
@ -311,7 +300,7 @@ class MultiCallbacks(DefaultCallbacks):
worker: "RolloutWorker",
base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
episode: MultiAgentEpisode,
episode: Episode,
env_index: Optional[int] = None,
**kwargs) -> None:
for callback in self._callback_list:
@ -328,7 +317,7 @@ class MultiCallbacks(DefaultCallbacks):
worker: "RolloutWorker",
base_env: BaseEnv,
policies: Optional[Dict[PolicyID, Policy]] = None,
episode: MultiAgentEpisode,
episode: Episode,
env_index: Optional[int] = None,
**kwargs) -> None:
for callback in self._callback_list:
@ -345,7 +334,7 @@ class MultiCallbacks(DefaultCallbacks):
worker: "RolloutWorker",
base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
episode: MultiAgentEpisode,
episode: Episode,
env_index: Optional[int] = None,
**kwargs) -> None:
for callback in self._callback_list:
@ -358,7 +347,7 @@ class MultiCallbacks(DefaultCallbacks):
**kwargs)
def on_postprocess_trajectory(
self, *, worker: "RolloutWorker", episode: MultiAgentEpisode,
self, *, worker: "RolloutWorker", episode: Episode,
agent_id: AgentID, policy_id: PolicyID,
policies: Dict[PolicyID, Policy], postprocessed_batch: SampleBatch,
original_batches: Dict[AgentID, SampleBatch], **kwargs) -> None:

View file

@ -5,7 +5,7 @@ from typing import Dict, Optional
import ray
from ray.rllib.agents.dreamer.utils import FreezeParameters
from ray.rllib.evaluation import MultiAgentEpisode
from ray.rllib.evaluation.episode import Episode
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
@ -247,7 +247,7 @@ def preprocess_episode(
policy: Policy,
sample_batch: SampleBatch,
other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
episode: Optional[Episode] = None) -> SampleBatch:
"""Batch format should be in the form of (s_t, a_(t-1), r_(t-1))
When t=0, the reset obs is paired with an action and reward of 0.
"""

View file

@ -59,7 +59,7 @@ def postprocess_advantages(
other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
dict of AgentIDs mapping to other agents' trajectory data (from the
same episode). NOTE: The other agents use the same policy.
episode (Optional[MultiAgentEpisode]): Optional multi-agent episode
episode (Optional[Episode]): Optional multi-agent episode
object in which the agents operated.
Returns:

View file

@ -1,6 +1,6 @@
from typing import List, Optional
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
@ -10,7 +10,7 @@ def post_process_advantages(
policy: Policy,
sample_batch: SampleBatch,
other_agent_batches: Optional[List[SampleBatch]] = None,
episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
episode: Optional[Episode] = None) -> SampleBatch:
"""Adds the "advantages" column to `sample_batch`.
Args:
@ -18,7 +18,7 @@ def post_process_advantages(
sample_batch (SampleBatch): The actual sample batch to post-process.
other_agent_batches (Optional[List[SampleBatch]]): Optional list of
other agents' SampleBatch objects.
episode (MultiAgentEpisode): The multi-agent episode object, from which
episode (Episode): The multi-agent episode object, from which
`sample_batch` was generated.
Returns:

View file

@ -13,7 +13,7 @@ from typing import Dict, List, Optional, Type, Union
from ray.rllib.agents.impala import vtrace_tf as vtrace
from ray.rllib.agents.impala.vtrace_tf_policy import _make_time_major, \
clip_gradients, choose_optimizer
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
Postprocessing
from ray.rllib.models.tf.tf_action_dist import Categorical
@ -325,7 +325,7 @@ def postprocess_trajectory(
policy: Policy,
sample_batch: SampleBatch,
other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
episode: Optional[Episode] = None) -> SampleBatch:
"""Postprocesses a trajectory and returns the processed trajectory.
The trajectory contains only data from one episode and from one agent.
@ -343,7 +343,7 @@ def postprocess_trajectory(
other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
dict of AgentIDs mapping to other agents' trajectory data (from the
same episode). NOTE: The other agents use the same policy.
episode (Optional[MultiAgentEpisode]): Optional multi-agent episode
episode (Optional[Episode]): Optional multi-agent episode
object in which the agents operated.
Returns:

View file

@ -7,7 +7,7 @@ import logging
from typing import Dict, List, Optional, Type, Union
import ray
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
Postprocessing
from ray.rllib.models.modelv2 import ModelV2
@ -357,7 +357,7 @@ def postprocess_ppo_gae(
policy: Policy,
sample_batch: SampleBatch,
other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
episode: Optional[Episode] = None) -> SampleBatch:
return compute_gae_for_sample_batch(policy, sample_batch,
other_agent_batches, episode)

View file

@ -16,7 +16,7 @@ from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \
PRIO_WEIGHTS
from ray.rllib.agents.sac.sac_tf_model import SACTFModel
from ray.rllib.agents.sac.sac_torch_model import SACTorchModel
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.episode import Episode
from ray.rllib.models import ModelCatalog, MODEL_DEFAULTS
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import Beta, Categorical, \
@ -106,7 +106,7 @@ def postprocess_trajectory(
policy: Policy,
sample_batch: SampleBatch,
other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
episode: Optional[Episode] = None) -> SampleBatch:
"""Postprocesses a trajectory and returns the processed trajectory.
The trajectory contains only data from one episode and from one agent.
@ -124,7 +124,7 @@ def postprocess_trajectory(
other_agent_batches (Optional[Dict[AgentID, SampleBatch]]): Optional
dict of AgentIDs mapping to other agents' trajectory data (from the
same episode). NOTE: The other agents use the same policy.
episode (Optional[MultiAgentEpisode]): Optional multi-agent episode
episode (Optional[Episode]): Optional multi-agent episode
object in which the agents operated.
Returns:

View file

@ -19,7 +19,7 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.utils import gym_env_creator
from ray.rllib.evaluation.collectors.simple_list_collector import \
SimpleListCollector
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.worker_set import WorkerSet
@ -1086,7 +1086,7 @@ class Trainer(Trainable):
full_fetch: bool = False,
explore: Optional[bool] = None,
timestep: Optional[int] = None,
episode: Optional[MultiAgentEpisode] = None,
episode: Optional[Episode] = None,
unsquash_action: Optional[bool] = None,
clip_action: Optional[bool] = None,
@ -1240,7 +1240,7 @@ class Trainer(Trainable):
full_fetch: bool = False,
explore: Optional[bool] = None,
timestep: Optional[int] = None,
episodes: Optional[List[MultiAgentEpisode]] = None,
episodes: Optional[List[Episode]] = None,
unsquash_actions: Optional[bool] = None,
clip_actions: Optional[bool] = None,
# Deprecated.

View file

@ -1,4 +1,4 @@
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.episode import Episode, MultiAgentEpisode
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.sample_batch_builder import (
SampleBatchBuilder, MultiAgentSampleBatchBuilder)
@ -17,5 +17,6 @@ __all__ = [
"AsyncSampler",
"compute_advantages",
"collect_metrics",
"MultiAgentEpisode",
"Episode",
"MultiAgentEpisode", # Deprecated -> Use `Episode` instead.
]

View file

@ -2,7 +2,7 @@ from abc import abstractmethod, ABCMeta
import logging
from typing import Dict, List, Optional, TYPE_CHECKING, Union
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.episode import Episode
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, \
@ -57,7 +57,7 @@ class SampleCollector(metaclass=ABCMeta):
self.count_steps_by = count_steps_by
@abstractmethod
def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
def add_init_obs(self, episode: Episode, agent_id: AgentID,
policy_id: PolicyID, t: int,
init_obs: TensorType) -> None:
"""Adds an initial obs (after reset) to this collector.
@ -70,7 +70,7 @@ class SampleCollector(metaclass=ABCMeta):
called for that same agent/episode-pair.
Args:
episode (MultiAgentEpisode): The MultiAgentEpisode, for which we
episode (Episode): The Episode for which we
are adding an Agent's initial observation.
agent_id (AgentID): Unique id for the agent we are adding
values for.
@ -126,11 +126,11 @@ class SampleCollector(metaclass=ABCMeta):
raise NotImplementedError
@abstractmethod
def episode_step(self, episode: MultiAgentEpisode) -> None:
def episode_step(self, episode: Episode) -> None:
"""Increases the episode step counter (across all agents) by one.
Args:
episode (MultiAgentEpisode): Episode we are stepping through.
episode (Episode): Episode we are stepping through.
Useful for handling counting because it is called once across
all agents that are inside this episode.
"""
@ -200,7 +200,7 @@ class SampleCollector(metaclass=ABCMeta):
@abstractmethod
def postprocess_episode(self,
episode: MultiAgentEpisode,
episode: Episode,
is_done: bool = False,
check_dones: bool = False,
build: bool = False) -> Optional[MultiAgentBatch]:
@ -214,7 +214,7 @@ class SampleCollector(metaclass=ABCMeta):
correctly added to the buffers.
Args:
episode (MultiAgentEpisode): The Episode object for which
episode (Episode): The Episode object for which
to post-process data.
is_done (bool): Whether the given episode is actually terminated
(all agents are done OR we hit a hard horizon). If True, the

View file

@ -8,7 +8,7 @@ from typing import Any, Dict, List, Tuple, TYPE_CHECKING, Union
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.episode import Episode
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
@ -505,11 +505,11 @@ class SimpleListCollector(SampleCollector):
# Maps episode ID to the (non-built) individual agent steps in this
# episode.
self.agent_steps: Dict[EpisodeID, int] = collections.defaultdict(int)
# Maps episode ID to MultiAgentEpisode.
self.episodes: Dict[EpisodeID, MultiAgentEpisode] = {}
# Maps episode ID to Episode.
self.episodes: Dict[EpisodeID, Episode] = {}
@override(SampleCollector)
def episode_step(self, episode: MultiAgentEpisode) -> None:
def episode_step(self, episode: Episode) -> None:
episode_id = episode.episode_id
# In the rare case that an "empty" step is taken at the beginning of
# the episode (none of the agents has an observation in the obs-dict
@ -550,8 +550,8 @@ class SimpleListCollector(SampleCollector):
if not self.multiple_episodes_in_batch else ""))
@override(SampleCollector)
def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
env_id: EnvID, policy_id: PolicyID, t: int,
def add_init_obs(self, episode: Episode, agent_id: AgentID, env_id: EnvID,
policy_id: PolicyID, t: int,
init_obs: TensorType) -> None:
# Make sure our mappings are up to date.
agent_key = (episode.episode_id, agent_id)
@ -707,7 +707,7 @@ class SimpleListCollector(SampleCollector):
@override(SampleCollector)
def postprocess_episode(
self,
episode: MultiAgentEpisode,
episode: Episode,
is_done: bool = False,
check_dones: bool = False,
build: bool = False) -> Union[None, SampleBatch, MultiAgentBatch]:
@ -834,7 +834,7 @@ class SimpleListCollector(SampleCollector):
if build:
return self._build_multi_agent_batch(episode)
def _build_multi_agent_batch(self, episode: MultiAgentEpisode) -> \
def _build_multi_agent_batch(self, episode: Episode) -> \
Union[MultiAgentBatch, SampleBatch]:
ma_batch = {}

View file

@ -1,11 +1,11 @@
from collections import defaultdict
import numpy as np
import random
from typing import List, Dict, Callable, Any, TYPE_CHECKING
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.annotations import Deprecated, DeveloperAPI
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray
from ray.rllib.utils.typing import SampleBatchType, AgentID, PolicyID, \
@ -19,7 +19,7 @@ if TYPE_CHECKING:
@DeveloperAPI
class MultiAgentEpisode:
class Episode:
"""Tracks the current state of a (possibly multi-agent) episode.
Attributes:
@ -53,15 +53,28 @@ class MultiAgentEpisode:
def __init__(
self,
policies: PolicyMap,
policy_mapping_fn: Callable[
[AgentID, "MultiAgentEpisode", "RolloutWorker"], PolicyID],
policy_mapping_fn: Callable[[AgentID, "Episode", "RolloutWorker"],
PolicyID],
batch_builder_factory: Callable[[],
"MultiAgentSampleBatchBuilder"],
extra_batch_callback: Callable[[SampleBatchType], None],
env_id: EnvID,
*,
worker: "RolloutWorker" = None,
worker: Optional["RolloutWorker"] = None,
):
"""Initializes an Episode instance.
Args:
policies: The PolicyMap object (mapping PolicyIDs to Policy
objects) to use for determining which policy is used for
which agent.
policy_mapping_fn: The mapping function mapping AgentIDs to
PolicyIDs.
batch_builder_factory: Function that returns a new
MultiAgentSampleBatchBuilder for collecting this episode's data.
extra_batch_callback: Function to be called with any extra
SampleBatches added to this episode via `add_extra_batch()`.
env_id: The environment's ID in which this episode runs.
worker: The RolloutWorker instance in which this episode runs.
"""
self.new_batch_builder: Callable[
[], "MultiAgentSampleBatchBuilder"] = batch_builder_factory
self.add_extra_batch: Callable[[SampleBatchType],
@ -80,9 +93,8 @@ class MultiAgentEpisode:
self.media: Dict[str, Any] = {}
self.policy_map: PolicyMap = policies
self._policies = self.policy_map # backward compatibility
self.policy_mapping_fn: Callable[[
AgentID, "MultiAgentEpisode", "RolloutWorker"
], PolicyID] = policy_mapping_fn
self.policy_mapping_fn: Callable[[AgentID, "Episode", "RolloutWorker"],
PolicyID] = policy_mapping_fn
self._next_agent_index: int = 0
self._agent_to_index: Dict[AgentID, int] = {}
self._agent_to_policy: Dict[AgentID, PolicyID] = {}
@ -92,21 +104,11 @@ class MultiAgentEpisode:
self._agent_to_last_done: Dict[AgentID, bool] = {}
self._agent_to_last_info: Dict[AgentID, EnvInfoDict] = {}
self._agent_to_last_action: Dict[AgentID, EnvActionType] = {}
self._agent_to_last_pi_info: Dict[AgentID, dict] = {}
self._agent_to_last_extra_action_outs: Dict[AgentID, dict] = {}
self._agent_to_prev_action: Dict[AgentID, EnvActionType] = {}
self._agent_reward_history: Dict[AgentID, List[int]] = defaultdict(
list)
# TODO: Deprecated.
@property
def _policy_mapping_fn(self):
deprecation_warning(
old="MultiAgentEpisode._policy_mapping_fn",
new="MultiAgentEpisode.policy_mapping_fn",
error=False,
)
return self.policy_mapping_fn
@DeveloperAPI
def soft_reset(self) -> None:
"""Clears rewards and metrics, but retains RNN and other state.
@ -126,15 +128,17 @@ class MultiAgentEpisode:
If the agent is new, the policy mapping fn will be called to bind the
agent to a policy for the duration of the entire episode (even if the
policy_mapping_fn is changed in the meantime).
policy_mapping_fn is changed in the meantime!).
Args:
agent_id (AgentID): The agent ID to lookup the policy ID for.
agent_id: The agent ID to lookup the policy ID for.
Returns:
PolicyID: The policy ID for the specified agent.
The policy ID for the specified agent.
"""
# Perform a new policy_mapping_fn lookup and bind AgentID for the
# duration of this episode to the returned PolicyID.
if agent_id not in self._agent_to_policy:
# Try new API: pass in agent_id and episode as named args.
# New signature should be: (agent_id, episode, worker, **kwargs)
@ -153,8 +157,11 @@ class MultiAgentEpisode:
self.policy_mapping_fn(agent_id)
else:
raise e
# Use already determined PolicyID.
else:
policy_id = self._agent_to_policy[agent_id]
# PolicyID not found in policy map -> Error.
if policy_id not in self.policy_map:
raise KeyError("policy_mapping_fn returned invalid policy id "
f"'{policy_id}'!")
@ -162,33 +169,70 @@ class MultiAgentEpisode:
@DeveloperAPI
def last_observation_for(
self, agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvObsType:
"""Returns the last observation for the specified agent."""
self, agent_id: AgentID = _DUMMY_AGENT_ID) -> Optional[EnvObsType]:
"""Returns the last observation for the specified AgentID.
Args:
agent_id: The agent's ID to get the last observation for.
Returns:
Last observation the specified AgentID has seen. None in case
the agent has never made any observations in the episode.
"""
return self._agent_to_last_obs.get(agent_id)
@DeveloperAPI
def last_raw_obs_for(self,
agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvObsType:
"""Returns the last un-preprocessed obs for the specified agent."""
def last_raw_obs_for(
self, agent_id: AgentID = _DUMMY_AGENT_ID) -> Optional[EnvObsType]:
"""Returns the last un-preprocessed obs for the specified AgentID.
Args:
agent_id: The agent's ID to get the last un-preprocessed
observation for.
Returns:
Last un-preprocessed observation the specified AgentID has seen.
None in case the agent has never made any observations in the
episode.
"""
return self._agent_to_last_raw_obs.get(agent_id)
@DeveloperAPI
def last_info_for(self,
agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvInfoDict:
"""Returns the last info for the specified agent."""
def last_info_for(self, agent_id: AgentID = _DUMMY_AGENT_ID
) -> Optional[EnvInfoDict]:
"""Returns the last info for the specified AgentID.
Args:
agent_id: The agent's ID to get the last info for.
Returns:
Last info dict the specified AgentID has seen.
None in case the agent has never made any observations in the
episode.
"""
return self._agent_to_last_info.get(agent_id)
@DeveloperAPI
def last_action_for(self,
agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvActionType:
"""Returns the last action for the specified agent, or zeros."""
"""Returns the last action for the specified AgentID, or zeros.
The "last" action is the most recent one taken by the agent.
Args:
agent_id: The agent's ID to get the last action for.
Returns:
Last action the specified AgentID has executed.
Zeros in case the agent has never performed any actions in the
episode.
"""
# Agent has already taken at least one action in the episode.
if agent_id in self._agent_to_last_action:
return flatten_to_single_ndarray(
self._agent_to_last_action[agent_id])
# Agent has not acted yet, return all zeros.
else:
policy_id = self.policy_for(agent_id)
policy = self.policy_map[policy_id]
@ -200,29 +244,84 @@ class MultiAgentEpisode:
@DeveloperAPI
def prev_action_for(self,
agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvActionType:
"""Returns the previous action for the specified agent."""
"""Returns the previous action for the specified agent, or zeros.
The "previous" action is the one taken one timestep before the
most recent action taken by the agent.
Args:
agent_id: The agent's ID to get the previous action for.
Returns:
Previous action the specified AgentID has executed.
Zeros in case the agent has never performed any actions (or only
one) in the episode.
"""
# We are at t > 1 -> There has been a previous action by this agent.
if agent_id in self._agent_to_prev_action:
return flatten_to_single_ndarray(
self._agent_to_prev_action[agent_id])
# We're at t <= 1, so return all zeros.
else:
# We're at t=0, so return all zeros.
return np.zeros_like(self.last_action_for(agent_id))
@DeveloperAPI
def prev_reward_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> float:
"""Returns the previous reward for the specified agent."""
def last_reward_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> float:
"""Returns the last reward for the specified agent, or zero.
The "last" reward is the one received most recently by the agent.
Args:
agent_id: The agent's ID to get the last reward for.
Returns:
Last reward for the specified AgentID.
Zero in case the agent has never performed any actions
(and thus received rewards) in the episode.
"""
history = self._agent_reward_history[agent_id]
# We are at t > 0 -> Return previously received reward.
if len(history) >= 1:
return history[-1]
# We're at t=0, so there is no previous reward, just return zero.
else:
return 0.0
@DeveloperAPI
def prev_reward_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> float:
"""Returns the previous reward for the specified agent, or zero.
The "previous" reward is the one received one timestep before the
most recently received reward of the agent.
Args:
agent_id: The agent's ID to get the previous reward for.
Returns:
Previous reward for the specified AgentID.
Zero in case the agent has never performed any actions (or only
one) in the episode.
"""
history = self._agent_reward_history[agent_id]
# We are at t > 1 -> Return reward prior to most recent (last) one.
if len(history) >= 2:
return history[-2]
# We're at t <= 1, so there is no previous reward, just return zero.
else:
# We're at t=0, so there is no previous reward, just return zero.
return 0.0
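To make the "last" vs. "previous" naming above concrete, a small hedged sketch (the agent ID and reward values are invented):

    # Suppose an agent's reward history so far in this episode is
    # [1.0, 0.5, 2.0] (most recent last). Then:
    episode.last_reward_for("agent_0")  # -> 2.0  (history[-1])
    episode.prev_reward_for("agent_0")  # -> 0.5  (history[-2])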
@DeveloperAPI
def rnn_state_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> List[Any]:
"""Returns the last RNN state for the specified agent."""
"""Returns the last RNN state for the specified agent.
Args:
agent_id: The agent's ID to get the most recent RNN state for.
Returns:
Most recent RNN state of the specified AgentID.
"""
if agent_id not in self._agent_to_rnn_state:
policy_id = self.policy_for(agent_id)
@ -232,24 +331,44 @@ class MultiAgentEpisode:
@DeveloperAPI
def last_done_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> bool:
"""Returns the last done flag received for the specified agent."""
"""Returns the last done flag for the specified AgentID.
Args:
agent_id: The agent's ID to get the last done flag for.
Returns:
Last done flag for the specified AgentID.
"""
if agent_id not in self._agent_to_last_done:
self._agent_to_last_done[agent_id] = False
return self._agent_to_last_done[agent_id]
@DeveloperAPI
def last_pi_info_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> dict:
"""Returns the last info object for the specified agent."""
def last_extra_action_outs_for(
self,
agent_id: AgentID = _DUMMY_AGENT_ID,
) -> dict:
"""Returns the last extra-action outputs for the specified agent.
return self._agent_to_last_pi_info[agent_id]
This data is returned by a call to
`Policy.compute_actions_from_input_dict` as the 3rd return value
(1st return value = action; 2nd return value = RNN state outs).
Args:
agent_id: The agent's ID to get the last extra-action outs for.
Returns:
The last extra-action outs for the specified AgentID.
"""
return self._agent_to_last_extra_action_outs[agent_id]
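For illustration only (which keys are present depends on the Policy; `action_logp` is an assumed example, not something this diff guarantees):

    # E.g. inside an `on_episode_step` callback:
    outs = episode.last_extra_action_outs_for("agent_0")
    # The sampled action's log-probability, if the policy emits one:
    logp = outs.get("action_logp")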
@DeveloperAPI
def get_agents(self) -> List[AgentID]:
"""Returns list of agent IDs that have appeared in this episode.
Returns:
List[AgentID]: The list of all agents that have appeared so
far in this episode.
The list of all agent IDs that have appeared so far in this
episode.
"""
return list(self._agent_to_index.keys())
@ -282,11 +401,31 @@ class MultiAgentEpisode:
self._agent_to_last_action[agent_id]
self._agent_to_last_action[agent_id] = action
def _set_last_pi_info(self, agent_id, pi_info):
self._agent_to_last_pi_info[agent_id] = pi_info
def _set_last_extra_action_outs(self, agent_id, pi_info):
self._agent_to_last_extra_action_outs[agent_id] = pi_info
def _agent_index(self, agent_id):
if agent_id not in self._agent_to_index:
self._agent_to_index[agent_id] = self._next_agent_index
self._next_agent_index += 1
return self._agent_to_index[agent_id]
@property
def _policy_mapping_fn(self):
deprecation_warning(
old="Episode._policy_mapping_fn",
new="Episode.policy_mapping_fn",
error=False,
)
return self.policy_mapping_fn
@Deprecated(new="Episode.last_extra_action_outs_for", error=False)
def last_pi_info_for(self, *args, **kwargs):
return self.last_extra_action_outs_for(*args, **kwargs)
# Backward compatibility: The old name MultiAgentEpisode wrongly implies that
# there is also a separate (single-agent) Episode class, which there is not.
@Deprecated(new="ray.rllib.evaluation.episode.Episode", error=False)
class MultiAgentEpisode(Episode):
pass
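A hedged sketch of how existing user code keeps running through the shims added here (only names visible in this diff are used):

    from ray.rllib.evaluation import Episode, MultiAgentEpisode

    # The old class is now a thin, deprecated alias of the new one.
    assert issubclass(MultiAgentEpisode, Episode)

    def legacy_inspect(episode: MultiAgentEpisode) -> dict:
        # Emits a deprecation warning and forwards to
        # `last_extra_action_outs_for()`.
        return episode.last_pi_info_for()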

View file

@ -1,14 +1,13 @@
import collections
import logging
import numpy as np
import collections
from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import ray
from ray import ObjectRef
from ray.actor import ActorHandle
from ray.rllib.evaluation.rollout_metrics import RolloutMetrics
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.typing import GradInfoDict, LearnerStatsDict, ResultDict
@ -18,6 +17,17 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
RolloutMetrics = collections.namedtuple("RolloutMetrics", [
"episode_length",
"episode_reward",
"agent_rewards",
"custom_metrics",
"perf_stats",
"hist_data",
"media",
])
RolloutMetrics.__new__.__defaults__ = (0, 0, {}, {}, {}, {}, {})
def extract_stats(stats: Dict, key: str) -> Dict[str, Any]:
if key in stats:

View file

@ -2,7 +2,7 @@ from typing import Dict
from ray.rllib.env import BaseEnv
from ray.rllib.policy import Policy
from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
from ray.rllib.evaluation import Episode, RolloutWorker
from ray.rllib.utils.framework import TensorType
from ray.rllib.utils.typing import AgentID, PolicyID
@ -22,7 +22,7 @@ class ObservationFunction:
def __call__(self, agent_obs: Dict[AgentID, TensorType],
worker: RolloutWorker, base_env: BaseEnv,
policies: Dict[PolicyID, Policy], episode: MultiAgentEpisode,
policies: Dict[PolicyID, Policy], episode: Episode,
**kw) -> Dict[AgentID, TensorType]:
"""Callback run on each environment step to observe the environment.
@ -45,7 +45,7 @@ class ObservationFunction:
retrieved by calling `base_env.get_sub_environments()`.
policies (dict): Mapping of policy id to policy objects. In single
agent mode there will only be a single "default" policy.
episode (MultiAgentEpisode): Episode state object.
episode (Episode): Episode state object.
kwargs: Forward compatibility placeholder.
Returns:

View file

@ -2,7 +2,7 @@ import numpy as np
import scipy.signal
from typing import Dict, Optional
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.episode import Episode
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI
@ -135,7 +135,7 @@ def compute_gae_for_sample_batch(
policy: Policy,
sample_batch: SampleBatch,
other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
episode: Optional[Episode] = None) -> SampleBatch:
"""Adds GAE (generalized advantage estimations) to a trajectory.
The trajectory contains only data from one episode and from one agent.
@ -153,7 +153,7 @@ def compute_gae_for_sample_batch(
other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
dict of AgentIDs mapping to other agents' trajectory data (from the
same episode). NOTE: The other agents use the same policy.
episode (Optional[MultiAgentEpisode]): Optional multi-agent episode
episode (Optional[Episode]): Optional multi-agent episode
object in which the agents operated.
Returns:

View file

@ -1,13 +0,0 @@
import collections
# Define this in its own file, see #5125
RolloutMetrics = collections.namedtuple("RolloutMetrics", [
"episode_length",
"episode_reward",
"agent_rewards",
"custom_metrics",
"perf_stats",
"hist_data",
"media",
])
RolloutMetrics.__new__.__defaults__ = (0, 0, {}, {}, {}, {}, {})

File diff suppressed because it is too large

View file

@ -4,7 +4,7 @@ import numpy as np
from typing import List, Any, Dict, Optional, TYPE_CHECKING
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.episode import Episode
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils.annotations import Deprecated, DeveloperAPI
@ -152,15 +152,15 @@ class MultiAgentSampleBatchBuilder:
self.agent_builders[agent_id].add_values(**values)
def postprocess_batch_so_far(
self, episode: Optional[MultiAgentEpisode] = None) -> None:
def postprocess_batch_so_far(self,
episode: Optional[Episode] = None) -> None:
"""Apply policy postprocessors to any unprocessed rows.
This pushes the postprocessed per-agent batches onto the per-policy
builders, clearing per-agent state.
Args:
episode (Optional[MultiAgentEpisode]): The Episode object that
episode (Optional[Episode]): The Episode object that
holds this MultiAgentBatchBuilder object.
"""
@ -234,15 +234,15 @@ class MultiAgentSampleBatchBuilder:
"Alternatively, set no_done_at_end=True to allow this.")
@DeveloperAPI
def build_and_reset(self, episode: Optional[MultiAgentEpisode] = None
) -> MultiAgentBatch:
def build_and_reset(self,
episode: Optional[Episode] = None) -> MultiAgentBatch:
"""Returns the accumulated sample batches for each policy.
Any unprocessed rows will be first postprocessed with a policy
postprocessor. The internal state of this builder will be reset.
Args:
episode (Optional[MultiAgentEpisode]): The Episode object that
episode (Optional[Episode]): The Episode object that
holds this MultiAgentBatchBuilder object or None.
Returns:

View file

@ -14,8 +14,8 @@ from ray.rllib.evaluation.collectors.sample_collector import \
SampleCollector
from ray.rllib.evaluation.collectors.simple_list_collector import \
SimpleListCollector
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.rollout_metrics import RolloutMetrics
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.metrics import RolloutMetrics
from ray.rllib.evaluation.sample_batch_builder import \
MultiAgentSampleBatchBuilder
from ray.rllib.env.base_env import BaseEnv, ASYNC_RESET_RETURN
@ -110,16 +110,45 @@ class SamplerInput(InputReader, metaclass=ABCMeta):
@abstractmethod
@DeveloperAPI
def get_data(self) -> SampleBatchType:
"""Called by `self.next()` to return the next batch of data.
Override this in child classes.
Returns:
The next batch of data.
"""
raise NotImplementedError
@abstractmethod
@DeveloperAPI
def get_metrics(self) -> List[RolloutMetrics]:
"""Returns list of episode metrics since the last call to this method.
The list will contain one RolloutMetrics object per completed episode.
Returns:
List of RolloutMetrics objects, one per completed episode since
the last call to this method.
"""
raise NotImplementedError
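As a usage sketch tying the two methods above together (construction of a concrete sampler such as SyncSampler is elided):

    from ray.rllib.evaluation.sampler import SamplerInput

    def pull_one_batch(sampler: SamplerInput):
        # SamplerInput is an InputReader, so data is pulled batch by batch;
        # `next()` internally calls `get_data()`.
        batch = sampler.next()
        # RolloutMetrics of episodes completed since the previous call.
        metrics = sampler.get_metrics()
        return batch, metrics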
@abstractmethod
@DeveloperAPI
def get_extra_batches(self) -> List[SampleBatchType]:
"""Returns list of extra batches since the last call to this method.
The list will contain all SampleBatches or
MultiAgentBatches that the user has provided thus far. Users can
add these "extra batches" to an episode by calling the episode's
`add_extra_batch([SampleBatchType])` method. This can be done from
inside an overridden `Policy.compute_actions_from_input_dict(...,
episodes)` or from a custom callback's `on_episode_[start|step|end]()`
methods.
Returns:
List of SampleBatches or MultiAgentBatches provided thus far by
the user since the last call to this method.
"""
raise NotImplementedError
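To make the round trip described above concrete, a hedged sketch of the user side (the batch keys and values are dummy placeholders):

    from ray.rllib.agents.callbacks import DefaultCallbacks
    from ray.rllib.policy.sample_batch import SampleBatch

    class ExtraBatchCallbacks(DefaultCallbacks):
        def on_episode_end(self, *, worker, base_env, policies, episode,
                           **kwargs) -> None:
            # Hand a user-built batch to the sampler; it is returned later
            # through `SamplerInput.get_extra_batches()`.
            episode.add_extra_batch(
                SampleBatch({"obs": [0.0], "rewards": [0.0]}))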
@ -143,7 +172,7 @@ class SyncSampler(SamplerInput):
clip_actions: bool = False,
soft_horizon: bool = False,
no_done_at_end: bool = False,
observation_fn: "ObservationFunction" = None,
observation_fn: Optional["ObservationFunction"] = None,
sample_collector_class: Optional[Type[SampleCollector]] = None,
render: bool = False,
# Obsolete.
@ -153,41 +182,44 @@ class SyncSampler(SamplerInput):
obs_filters=None,
tf_sess=None,
):
"""Initializes a SyncSampler object.
"""Initializes a SyncSampler instance.
Args:
worker (RolloutWorker): The RolloutWorker that will use this
Sampler for sampling.
env (Env): Any Env object. Will be converted into an RLlib BaseEnv.
clip_rewards (Union[bool, float]): True for +/-1.0 clipping,
worker: The RolloutWorker that will use this Sampler for sampling.
env: Any Env object. Will be converted into an RLlib BaseEnv.
clip_rewards: True for +/-1.0 clipping,
actual float value for +/- value clipping. False for no
clipping.
rollout_fragment_length (int): The length of a fragment to collect
rollout_fragment_length: The length of a fragment to collect
before building a SampleBatch from the data and resetting
the SampleBatchBuilder object.
callbacks (Callbacks): The Callbacks object to use when episode
count_steps_by: One of "env_steps" (default) or "agent_steps".
Use "agent_steps", if you want rollout lengths to be counted
by individual agent steps. In a multi-agent env,
a single env_step contains one or more agent_steps, depending
on how many agents are present at any given time in the
ongoing episode.
callbacks: The Callbacks object to use when episode
events happen during rollout.
horizon (Optional[int]): Hard-reset the Env
multiple_episodes_in_batch (bool): Whether to pack multiple
horizon: Hard-reset the Env after this many timesteps.
multiple_episodes_in_batch: Whether to pack multiple
episodes into each batch. This guarantees batches will be
exactly `rollout_fragment_length` in size.
normalize_actions (bool): Whether to normalize actions to the
normalize_actions: Whether to normalize actions to the
action space's bounds.
clip_actions (bool): Whether to clip actions according to the
clip_actions: Whether to clip actions according to the
given action_space's bounds.
soft_horizon (bool): If True, calculate bootstrapped values as if
soft_horizon: If True, calculate bootstrapped values as if
episode had ended, but don't physically reset the environment
when the horizon is hit.
no_done_at_end (bool): Ignore the done=True at the end of the
no_done_at_end: Ignore the done=True at the end of the
episode and instead record done=False.
observation_fn (Optional[ObservationFunction]): Optional
multi-agent observation func to use for preprocessing
observations.
sample_collector_class (Optional[Type[SampleCollector]]): An
optional Samplecollector sub-class to use to collect, store,
and retrieve environment-, model-, and sampler data.
render (bool): Whether to try to render the environment after each
step.
observation_fn: Optional multi-agent observation func to use for
preprocessing observations.
sample_collector_class: An optional SampleCollector sub-class to
use to collect, store, and retrieve environment-, model-,
and sampler data.
render: Whether to try to render the environment after each step.
"""
# All of the following arguments are deprecated. They will instead be
# provided via the passed in `worker` arg, e.g. `worker.policy_map`.
@ -262,8 +294,9 @@ class SyncSampler(SamplerInput):
class AsyncSampler(threading.Thread, SamplerInput):
"""Async SamplerInput that collects experiences in thread and queues them.
Once started, experiences are continuously collected and put into a Queue,
from where they can be unqueued by the caller of `get_data()`.
Once started, experiences are continuously collected in the background
and put into a Queue, from which they can be dequeued by the caller
of `get_data()`.
"""
def __init__(
@ -279,12 +312,12 @@ class AsyncSampler(threading.Thread, SamplerInput):
multiple_episodes_in_batch: bool = False,
normalize_actions: bool = True,
clip_actions: bool = False,
blackhole_outputs: bool = False,
soft_horizon: bool = False,
no_done_at_end: bool = False,
observation_fn: Optional["ObservationFunction"] = None,
sample_collector_class: Optional[Type[SampleCollector]] = None,
render: bool = False,
blackhole_outputs: bool = False,
# Obsolete.
policies=None,
policy_mapping_fn=None,
@ -292,45 +325,44 @@ class AsyncSampler(threading.Thread, SamplerInput):
obs_filters=None,
tf_sess=None,
):
"""Initializes a AsyncSampler object.
"""Initializes an AsyncSampler instance.
Args:
worker (RolloutWorker): The RolloutWorker that will use this
Sampler for sampling.
env (Env): Any Env object. Will be converted into an RLlib BaseEnv.
clip_rewards (Union[bool, float]): True for +/-1.0 clipping,
worker: The RolloutWorker that will use this Sampler for sampling.
env: Any Env object. Will be converted into an RLlib BaseEnv.
clip_rewards: True for +/-1.0 clipping,
actual float value for +/- value clipping. False for no
clipping.
rollout_fragment_length (int): The length of a fragment to collect
rollout_fragment_length: The length of a fragment to collect
before building a SampleBatch from the data and resetting
the SampleBatchBuilder object.
count_steps_by (str): Either "env_steps" or "agent_steps".
Refers to the unit of `rollout_fragment_length`.
callbacks (Callbacks): The Callbacks object to use when episode
events happen during rollout.
count_steps_by: One of "env_steps" (default) or "agent_steps".
Use "agent_steps", if you want rollout lengths to be counted
by individual agent steps. In a multi-agent env,
a single env_step contains one or more agent_steps, depending
on how many agents are present at any given time in the
ongoing episode.
horizon: Hard-reset the Env after this many timesteps.
multiple_episodes_in_batch (bool): Whether to pack multiple
multiple_episodes_in_batch: Whether to pack multiple
episodes into each batch. This guarantees batches will be
exactly `rollout_fragment_length` in size.
normalize_actions (bool): Whether to normalize actions to the
normalize_actions: Whether to normalize actions to the
action space's bounds.
clip_actions (bool): Whether to clip actions according to the
clip_actions: Whether to clip actions according to the
given action_space's bounds.
blackhole_outputs (bool): Whether to collect samples, but then
blackhole_outputs: Whether to collect samples, but then
not further process or store them (throw away all samples).
soft_horizon (bool): If True, calculate bootstrapped values as if
soft_horizon: If True, calculate bootstrapped values as if
episode had ended, but don't physically reset the environment
when the horizon is hit.
no_done_at_end (bool): Ignore the done=True at the end of the
no_done_at_end: Ignore the done=True at the end of the
episode and instead record done=False.
observation_fn (Optional[ObservationFunction]): Optional
multi-agent observation func to use for preprocessing
observations.
sample_collector_class (Optional[Type[SampleCollector]]): An
optional Samplecollector sub-class to use to collect, store,
and retrieve environment-, model-, and sampler data.
render (bool): Whether to try to render the environment after each
step.
observation_fn: Optional multi-agent observation func to use for
preprocessing observations.
sample_collector_class: An optional SampleCollector sub-class to
use to collect, store, and retrieve environment-, model-,
and sampler data.
render: Whether to try to render the environment after each step.
"""
# All of the following arguments are deprecated. They will instead be
# provided via the passed in `worker` arg, e.g. `worker.policy_map`.
@ -467,32 +499,32 @@ def _env_runner(
"""This implements the common experience collection logic.
Args:
worker (RolloutWorker): Reference to the current rollout worker.
base_env (BaseEnv): Env implementing BaseEnv.
extra_batch_callback (fn): function to send extra batch data to.
worker: Reference to the current rollout worker.
base_env: Env implementing BaseEnv.
extra_batch_callback: function to send extra batch data to.
horizon: Horizon of the episode.
multiple_episodes_in_batch (bool): Whether to pack multiple
multiple_episodes_in_batch: Whether to pack multiple
episodes into each batch. This guarantees batches will be exactly
`rollout_fragment_length` in size.
normalize_actions (bool): Whether to normalize actions to the action
normalize_actions: Whether to normalize actions to the action
space's bounds.
clip_actions (bool): Whether to clip actions to the space range.
callbacks (DefaultCallbacks): User callbacks to run on episode events.
perf_stats (_PerfStats): Record perf stats into this object.
soft_horizon (bool): Calculate rewards but don't reset the
clip_actions: Whether to clip actions to the space range.
callbacks: User callbacks to run on episode events.
perf_stats: Record perf stats into this object.
soft_horizon: Calculate rewards but don't reset the
environment when the horizon is hit.
no_done_at_end (bool): Ignore the done=True at the end of the episode
no_done_at_end: Ignore the done=True at the end of the episode
and instead record done=False.
observation_fn (ObservationFunction): Optional multi-agent
observation_fn: Optional multi-agent
observation func to use for preprocessing observations.
sample_collector (Optional[SampleCollector]): An optional
sample_collector: An optional
SampleCollector object to use.
render (bool): Whether to try to render the environment after each
render: Whether to try to render the environment after each
step.
Yields:
rollout (SampleBatch): Object containing state, action, reward,
terminal condition, and other fields as dictated by `policy`.
Object containing state, action, reward, terminal condition,
and other fields as dictated by `policy`.
"""
# May be populated with a viewer used for image rendering.
@ -546,7 +578,7 @@ def _env_runner(
return None
def new_episode(env_id):
episode = MultiAgentEpisode(
episode = Episode(
worker.policy_map,
worker.policy_mapping_fn,
get_batch_builder,
@ -576,7 +608,7 @@ def _env_runner(
)
return episode
active_episodes: Dict[EnvID, MultiAgentEpisode] = \
active_episodes: Dict[EnvID, Episode] = \
NewEpisodeDefaultDict(new_episode)
while True:
@ -684,7 +716,7 @@ def _process_observations(
*,
worker: "RolloutWorker",
base_env: BaseEnv,
active_episodes: Dict[EnvID, MultiAgentEpisode],
active_episodes: Dict[EnvID, Episode],
unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
rewards: Dict[EnvID, Dict[AgentID, float]],
dones: Dict[EnvID, Dict[AgentID, bool]],
@ -701,38 +733,37 @@ def _process_observations(
"""Record new data from the environment and prepare for policy evaluation.
Args:
worker (RolloutWorker): Reference to the current rollout worker.
base_env (BaseEnv): Env implementing BaseEnv.
active_episodes (Dict[EnvID, MultiAgentEpisode]): Mapping from
episode ID to currently ongoing MultiAgentEpisode object.
unfiltered_obs (dict): Doubly keyed dict of env-ids -> agent ids
worker: Reference to the current rollout worker.
base_env: Env implementing BaseEnv.
active_episodes: Mapping from
env ID to the currently ongoing Episode object.
unfiltered_obs: Doubly keyed dict of env-ids -> agent ids
-> unfiltered observation tensor, returned by a `BaseEnv.poll()`
call.
rewards (dict): Doubly keyed dict of env-ids -> agent ids ->
rewards: Doubly keyed dict of env-ids -> agent ids ->
rewards tensor, returned by a `BaseEnv.poll()` call.
dones (dict): Doubly keyed dict of env-ids -> agent ids ->
dones: Doubly keyed dict of env-ids -> agent ids ->
boolean done flags, returned by a `BaseEnv.poll()` call.
infos (dict): Doubly keyed dict of env-ids -> agent ids ->
infos: Doubly keyed dict of env-ids -> agent ids ->
info dicts, returned by a `BaseEnv.poll()` call.
horizon (int): Horizon of the episode.
multiple_episodes_in_batch (bool): Whether to pack multiple
horizon: Horizon of the episode.
multiple_episodes_in_batch: Whether to pack multiple
episodes into each batch. This guarantees batches will be exactly
`rollout_fragment_length` in size.
callbacks (DefaultCallbacks): User callbacks to run on episode events.
soft_horizon (bool): Calculate rewards but don't reset the
callbacks: User callbacks to run on episode events.
soft_horizon: Calculate rewards but don't reset the
environment when the horizon is hit.
no_done_at_end (bool): Ignore the done=True at the end of the episode
no_done_at_end: Ignore the done=True at the end of the episode
and instead record done=False.
observation_fn (ObservationFunction): Optional multi-agent
observation_fn: Optional multi-agent
observation func to use for preprocessing observations.
sample_collector (SampleCollector): The SampleCollector object
sample_collector: The SampleCollector object
used to store and retrieve environment samples.
Returns:
Tuple:
- active_envs: Set of non-terminated env ids.
- to_eval: Map of policy_id to list of agent PolicyEvalData.
- outputs: List of metrics and samples to return from the sampler.
Tuple consisting of 1) active_envs: Set of non-terminated env ids.
2) to_eval: Map of policy_id to list of agent PolicyEvalData.
3) outputs: List of metrics and samples to return from the sampler.
"""
# Output objects.
@ -744,7 +775,7 @@ def _process_observations(
# types: EnvID, Dict[AgentID, EnvObsType]
for env_id, all_agents_obs in unfiltered_obs.items():
is_new_episode: bool = env_id not in active_episodes
episode: MultiAgentEpisode = active_episodes[env_id]
episode: Episode = active_episodes[env_id]
if not is_new_episode:
sample_collector.episode_step(episode)
@ -854,9 +885,11 @@ def _process_observations(
# Next observation.
SampleBatch.NEXT_OBS: filtered_obs,
}
# Add extra-action-fetches to collectors.
# Add extra-action-fetches (policy-inference infos) to
# collectors.
pol = worker.policy_map[policy_id]
for key, value in episode.last_pi_info_for(agent_id).items():
for key, value in episode.last_extra_action_outs_for(
agent_id).items():
if key in pol.view_requirements:
values_dict[key] = value
# Env infos for this agent.
@ -947,7 +980,7 @@ def _process_observations(
# Creates a new episode if this is not async return.
# If reset is async, we will get its result in some future poll.
elif resetted_obs != ASYNC_RESET_RETURN:
new_episode: MultiAgentEpisode = active_episodes[env_id]
new_episode: Episode = active_episodes[env_id]
if observation_fn:
resetted_obs: Dict[AgentID, EnvObsType] = observation_fn(
agent_obs=resetted_obs,
@ -995,21 +1028,20 @@ def _do_policy_eval(
to_eval: Dict[PolicyID, List[PolicyEvalData]],
policies: PolicyMap,
sample_collector,
active_episodes: Dict[EnvID, MultiAgentEpisode],
active_episodes: Dict[EnvID, Episode],
) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
"""Call compute_actions on collected episode/model data to get next action.
Args:
to_eval (Dict[PolicyID, List[PolicyEvalData]]): Mapping of policy
IDs to lists of PolicyEvalData objects (items in these lists will
be the batch's items for the model forward pass).
policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy
obj.
sample_collector (SampleCollector): The SampleCollector object to use.
active_episodes (Dict[EnvID, MultiAgentEpisode]): Mapping of
to_eval: Mapping of policy IDs to lists of PolicyEvalData objects
(items in these lists will be the batch's items for the model
forward pass).
policies: Mapping from policy ID to Policy obj.
sample_collector: The SampleCollector object to use.
active_episodes: Mapping of EnvID to its currently active episode.
Returns:
eval_results: dict of policy to compute_action() outputs.
Dict mapping PolicyIDs to compute_actions_from_input_dict() outputs.
"""
eval_results: Dict[PolicyID, TensorStructType] = {}
@ -1052,7 +1084,7 @@ def _process_policy_eval_results(
to_eval: Dict[PolicyID, List[PolicyEvalData]],
eval_results: Dict[PolicyID, Tuple[TensorStructType, StateBatch,
dict]],
active_episodes: Dict[EnvID, MultiAgentEpisode],
active_episodes: Dict[EnvID, Episode],
active_envs: Set[int],
off_policy_actions: MultiEnvDict,
policies: Dict[PolicyID, Policy],
@ -1065,24 +1097,22 @@ def _process_policy_eval_results(
returns replies to send back to agents in the env.
Args:
to_eval (Dict[PolicyID, List[PolicyEvalData]]): Mapping of policy IDs
to lists of PolicyEvalData objects.
eval_results (Dict[PolicyID, List]): Mapping of policy IDs to list of
to_eval: Mapping of policy IDs to lists of PolicyEvalData objects.
eval_results: Mapping of policy IDs to list of
actions, rnn-out states, extra-action-fetches dicts.
active_episodes (Dict[EnvID, MultiAgentEpisode]): Mapping from
episode ID to currently ongoing MultiAgentEpisode object.
active_envs (Set[int]): Set of non-terminated env ids.
off_policy_actions (dict): Doubly keyed dict of env-ids -> agent ids ->
active_episodes: Mapping from env ID to the currently ongoing
Episode object.
active_envs: Set of non-terminated env ids.
off_policy_actions: Doubly keyed dict of env-ids -> agent ids ->
off-policy-action, returned by a `BaseEnv.poll()` call.
policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy.
normalize_actions (bool): Whether to normalize actions to the action
policies: Mapping from policy ID to Policy.
normalize_actions: Whether to normalize actions to the action
space's bounds.
clip_actions (bool): Whether to clip actions to the action space's
bounds.
clip_actions: Whether to clip actions to the action space's bounds.
Returns:
actions_to_send: Nested dict of env id -> agent id -> actions to be
sent to Env (np.ndarrays).
Nested dict of env id -> agent id -> actions to be sent to
Env (np.ndarrays).
"""
actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = \
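For illustration, the returned `actions_to_send` ends up shaped roughly like this (env/agent IDs and action values are made up):
# {
#     0: {"agent_0": np.array(1), "agent_1": np.array(0)},  # env 0
#     1: {"agent_0": np.array(2)},                          # env 1
# }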
@ -1098,7 +1128,7 @@ def _process_policy_eval_results(
actions = convert_to_numpy(actions)
rnn_out_cols: StateBatch = eval_results[policy_id][1]
pi_info_cols: dict = eval_results[policy_id][2]
extra_action_out_cols: dict = eval_results[policy_id][2]
# In case actions is a list (representing the 0th dim of a batch of
# primitive actions), try converting it first.
@ -1107,7 +1137,7 @@ def _process_policy_eval_results(
# Store RNN state ins/outs and extra-action fetches to episode.
for f_i, column in enumerate(rnn_out_cols):
pi_info_cols["state_out_{}".format(f_i)] = column
extra_action_out_cols["state_out_{}".format(f_i)] = column
policy: Policy = _get_or_raise(policies, policy_id)
# Split action-component batches into single action rows.
@ -1127,11 +1157,11 @@ def _process_policy_eval_results(
env_id: int = eval_data[i].env_id
agent_id: AgentID = eval_data[i].agent_id
episode: MultiAgentEpisode = active_episodes[env_id]
episode: Episode = active_episodes[env_id]
episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
episode._set_last_pi_info(
episode._set_last_extra_action_outs(
agent_id, {k: v[i]
for k, v in pi_info_cols.items()})
for k, v in extra_action_out_cols.items()})
if env_id in off_policy_actions and \
agent_id in off_policy_actions[env_id]:
episode._set_last_action(agent_id,

View file

@ -81,7 +81,7 @@ class EchoPolicy(Policy):
return obs_batch.argmax(axis=1), [], {}
class MultiAgentEpisodeEnv(MultiAgentEnv):
class EpisodeEnv(MultiAgentEnv):
def __init__(self, episode_length, num):
self.agents = [MockEnv3(episode_length) for _ in range(num)]
self.dones = set()
@ -123,9 +123,9 @@ class TestEpisodeLastValues(unittest.TestCase):
ev.sample()
def test_multiagent_env(self):
temp_env = MultiAgentEpisodeEnv(NUM_STEPS, NUM_AGENTS)
temp_env = EpisodeEnv(NUM_STEPS, NUM_AGENTS)
ev = RolloutWorker(
env_creator=lambda _: MultiAgentEpisodeEnv(NUM_STEPS, NUM_AGENTS),
env_creator=lambda _: EpisodeEnv(NUM_STEPS, NUM_AGENTS),
policy_spec={
str(agent_id): (EchoPolicy, temp_env.observation_space,
temp_env.action_space, {})

View file

@ -29,9 +29,9 @@ T = TypeVar("T")
@DeveloperAPI
class WorkerSet:
"""Represents a set of RolloutWorkers.
"""Set of RolloutWorkers with n @ray.remote workers and one local worker.
There must be one local worker copy, and zero or more remote workers.
Where n may be 0.
"""
def __init__(self,
@ -43,20 +43,20 @@ class WorkerSet:
num_workers: int = 0,
logdir: Optional[str] = None,
_setup: bool = True):
"""Create a new WorkerSet and initialize its workers.
"""Initializes a WorkerSet instance.
Args:
env_creator (Optional[Callable[[EnvContext], EnvType]]): Function
that returns env given env config.
validate_env (Optional[Callable[[EnvType], None]]): Optional
callable to validate the generated environment (only on
worker=0).
policy (Optional[Type[Policy]]): A rllib.policy.Policy class.
trainer_config (Optional[TrainerConfigDict]): Optional dict that
extends the common config of the Trainer class.
num_workers (int): Number of remote rollout workers to create.
logdir (Optional[str]): Optional logging directory for workers.
_setup (bool): Whether to setup workers. This is only for testing.
env_creator: Function that returns env given env config.
validate_env: Optional callable to validate the generated
environment (only on worker=0).
policy_class: An optional Policy class. If None, PolicySpecs can be
generated automatically by using the Trainer's default class
or via a given multi-agent policy config dict.
trainer_config: Optional dict that extends the common config of
the Trainer class.
num_workers: Number of remote rollout workers to create.
logdir: Optional logging directory for workers.
_setup: Whether to set up workers. This is only for testing.
"""
if not trainer_config:
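A hedged construction sketch (RandomPolicy and CartPole are just example choices here; a Trainer normally builds its WorkerSet internally):
import gym
import ray
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.examples.policy.random_policy import RandomPolicy

ray.init()
# One local worker plus two @ray.remote workers, all running CartPole.
workers = WorkerSet(
    env_creator=lambda env_ctx: gym.make("CartPole-v0"),
    policy_class=RandomPolicy,
    num_workers=2,
)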
@ -119,15 +119,20 @@ class WorkerSet:
)
def local_worker(self) -> RolloutWorker:
"""Return the local rollout worker."""
"""Returns the local rollout worker."""
return self._local_worker
def remote_workers(self) -> List[ActorHandle]:
"""Return a list of remote rollout workers."""
"""Returns a list of remote rollout workers."""
return self._remote_workers
def sync_weights(self, policies: Optional[List[PolicyID]] = None) -> None:
"""Syncs weights from the local worker to all remote workers."""
"""Syncs model weights from the local worker to all remote workers.
Args:
policies: An optional list of policy IDs to sync for. If None,
sync all policies.
"""
if self.remote_workers():
weights = ray.put(self.local_worker().get_weights(policies))
for e in self.remote_workers():
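Typical usage after updating the local worker's policies (the policy ID below is RLlib's default single-agent ID, used here as an example):
workers.sync_weights()                               # push all policies
workers.sync_weights(policies=["default_policy"])    # or only a subset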
@ -136,8 +141,11 @@ class WorkerSet:
def add_workers(self, num_workers: int) -> None:
"""Creates and adds a number of remote workers to this worker set.
Can be called several times on the same WorkerSet to add more
RolloutWorkers to the set.
Args:
num_workers (int): The number of remote Workers to add to this
num_workers: The number of remote Workers to add to this
WorkerSet.
"""
remote_args = {
@ -159,25 +167,36 @@ class WorkerSet:
])
def reset(self, new_remote_workers: List[ActorHandle]) -> None:
"""Called to change the set of remote workers."""
"""Hard overrides the remote workers in this set with the given one.
Args:
new_remote_workers: A list of new RolloutWorkers
(as `ActorHandles`) to use as remote workers.
"""
self._remote_workers = new_remote_workers
def stop(self) -> None:
"""Stop all rollout workers."""
"""Calls `stop` on all rollout workers (including the local one)."""
try:
self.local_worker().stop()
tids = [w.stop.remote() for w in self.remote_workers()]
ray.get(tids)
except Exception:
logger.exception("Failed to stop workers")
logger.exception("Failed to stop workers!")
finally:
for w in self.remote_workers():
w.__ray_terminate__.remote()
@DeveloperAPI
def foreach_worker(self, func: Callable[[RolloutWorker], T]) -> List[T]:
"""Apply the given function to each worker instance."""
"""Calls the given function with each worker instance as arg.
Args:
func: The function to call for each worker (as only arg).
Returns:
The list of return values of all calls to `func(worker)`.
"""
local_result = [func(self.local_worker())]
remote_results = ray.get(
[w.apply.remote(func) for w in self.remote_workers()])
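For example (assuming each worker has an env it can sample from), collecting one rollout per worker and counting the sampled timesteps:
steps_per_worker = workers.foreach_worker(lambda w: w.sample().count)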
@ -186,11 +205,23 @@ class WorkerSet:
@DeveloperAPI
def foreach_worker_with_index(
self, func: Callable[[RolloutWorker, int], T]) -> List[T]:
"""Apply the given function to each worker instance.
"""Calls `func` with each worker instance and worker idx as args.
The index will be passed as the second arg to the given function.
Args:
func: The function to call for each worker and its index
(as args). The local worker has index 0; all remote workers
have indices > 0.
Returns:
The list of return values of all calls to `func(worker, idx)`.
The first entry in this list is the result of the local
worker, followed by all remote workers' results.
"""
# Local worker: Index=0.
local_result = [func(self.local_worker(), 0)]
# Remote workers: Index > 0.
remote_results = ray.get([
w.apply.remote(func, i + 1)
for i, w in enumerate(self.remote_workers())
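For example, tagging each worker's sample count with its index (0 = local; assumes samplable envs):
counts = workers.foreach_worker_with_index(
    lambda w, i: (i, w.sample().count))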
@ -199,15 +230,22 @@ class WorkerSet:
@DeveloperAPI
def foreach_policy(self, func: Callable[[Policy, PolicyID], T]) -> List[T]:
"""Apply the given function to each worker's (policy, policy_id) tuple.
"""Calls `func` with each worker's (policy, PolicyID) tuple.
Note that in the multi-agent case, each worker may have more than one
policy.
Args:
func (callable): A function - taking a Policy and its ID - that is
func: A function - taking a Policy and its ID - that is
called on all workers' Policies.
Returns:
List[any]: The list of return values of func over all workers'
policies.
The list of return values of func over all workers' policies. The
length of this list is:
(num_workers + 1 (local-worker)) *
[num policies in the multi-agent config dict].
The local worker's results come first, followed by all remote
workers' results.
"""
results = self.local_worker().foreach_policy(func)
ray_gets = []
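For example, gathering all policy IDs present across all workers:
all_policy_ids = workers.foreach_policy(lambda policy, pid: pid)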
@ -221,8 +259,8 @@ class WorkerSet:
@DeveloperAPI
def trainable_policies(self) -> List[PolicyID]:
"""Return the list of trainable policy ids."""
return self.local_worker().foreach_trainable_policy(lambda _, pid: pid)
"""Returns the list of trainable policy ids."""
return self.local_worker().policies_to_train
@DeveloperAPI
def foreach_trainable_policy(
@ -230,7 +268,7 @@ class WorkerSet:
"""Apply `func` to all workers' Policies iff in `policies_to_train`.
Args:
func (callable): A function - taking a Policy and its ID - that is
func: A function - taking a Policy and its ID - that is
called on all workers' Policies in `worker.policies_to_train`.
Returns:
@ -250,7 +288,7 @@ class WorkerSet:
@DeveloperAPI
def foreach_env(self, func: Callable[[EnvType], List[T]]) -> List[List[T]]:
"""Apply `func` to all workers' underlying sub environments.
"""Calls `func` with all workers' sub environments as args.
An "underlying sub environment" is a single clone of an env within
a vectorized environment.
@ -258,13 +296,12 @@ class WorkerSet:
gym.Env object.
Args:
func (Callable[[EnvType], T]): A function - taking an EnvType
(normally a gym.Env object) as arg and returning a list of
return values over sub environments for each worker.
func: A function - taking an EnvType (normally a gym.Env object)
as arg and returning a list of lists of return values, one
value per underlying sub-environment of each worker.
Returns:
List[List[T]]: The list (workers) of lists (sub environments) of
results.
The list (workers) of lists (sub environments) of results.
"""
local_results = [self.local_worker().foreach_env(func)]
ray_gets = []
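For example (a hedged sketch), looking up each sub-environment's observation space across all workers:
obs_spaces = workers.foreach_env(lambda env: env.observation_space)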
@ -276,7 +313,7 @@ class WorkerSet:
def foreach_env_with_context(
self,
func: Callable[[BaseEnv, EnvContext], List[T]]) -> List[List[T]]:
"""Apply `func` to all workers' underlying sub environments.
"""Call `func` with all workers' sub-environments and env_ctx as args.
An "underlying sub environment" is a single clone of an env within
a vectorized environment.
@ -284,13 +321,13 @@ class WorkerSet:
as args.
Args:
func (Callable[[BaseEnv], T]): A function - taking a BaseEnv
object as arg and returning a list of return values over envs
func: A function - taking a BaseEnv object and an EnvContext as
args - and returning a list of lists of return values over envs
of the worker.
Returns:
List[List[T]]: The list (workers) of lists (environments) of
results.
The list (1 item per worker) of lists (1 item per sub-environment)
of results.
"""
local_results = [self.local_worker().foreach_env_with_context(func)]
ray_gets = []
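For example (assuming the standard EnvContext fields), checking which worker index each sub-environment lives on:
worker_indices = workers.foreach_env_with_context(
    lambda env, ctx: ctx.worker_index)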

View file

@ -13,7 +13,7 @@ import ray
from ray import tune
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env import BaseEnv
from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
from ray.rllib.evaluation import Episode, RolloutWorker
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
@ -28,8 +28,8 @@ parser.add_argument("--stop-iters", type=int, default=2000)
class MyCallbacks(DefaultCallbacks):
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv,
policies: Dict[str, Policy],
episode: MultiAgentEpisode, env_index: int, **kwargs):
policies: Dict[str, Policy], episode: Episode,
env_index: int, **kwargs):
# Make sure this episode has just been started (only initial obs
# logged so far).
assert episode.length == 0, \
@ -41,8 +41,8 @@ class MyCallbacks(DefaultCallbacks):
episode.hist_data["pole_angles"] = []
def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv,
policies: Dict[str, Policy],
episode: MultiAgentEpisode, env_index: int, **kwargs):
policies: Dict[str, Policy], episode: Episode,
env_index: int, **kwargs):
# Make sure this episode is ongoing.
assert episode.length > 0, \
"ERROR: `on_episode_step()` callback should not be called right " \
@ -53,7 +53,7 @@ class MyCallbacks(DefaultCallbacks):
episode.user_data["pole_angles"].append(pole_angle)
def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv,
policies: Dict[str, Policy], episode: MultiAgentEpisode,
policies: Dict[str, Policy], episode: Episode,
env_index: int, **kwargs):
# Make sure this episode is really done.
assert episode.batch_builder.policy_collectors[
@ -84,8 +84,8 @@ class MyCallbacks(DefaultCallbacks):
policy, result["sum_actions_in_train_batch"]))
def on_postprocess_trajectory(
self, *, worker: RolloutWorker, episode: MultiAgentEpisode,
agent_id: str, policy_id: str, policies: Dict[str, Policy],
self, *, worker: RolloutWorker, episode: Episode, agent_id: str,
policy_id: str, policies: Dict[str, Policy],
postprocessed_batch: SampleBatch,
original_batches: Dict[str, SampleBatch], **kwargs):
print("postprocessed {} steps".format(postprocessed_batch.count))

View file

@ -23,7 +23,7 @@ tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
if TYPE_CHECKING:
from ray.rllib.evaluation import MultiAgentEpisode
from ray.rllib.evaluation import Episode
logger = logging.getLogger(__name__)
@ -142,7 +142,7 @@ class Policy(metaclass=ABCMeta):
prev_reward: Optional[TensorStructType] = None,
info: dict = None,
input_dict: Optional[SampleBatch] = None,
episode: Optional["MultiAgentEpisode"] = None,
episode: Optional["Episode"] = None,
explore: Optional[bool] = None,
timestep: Optional[int] = None,
# Kwargs placeholder for future compatibility.
@ -238,7 +238,7 @@ class Policy(metaclass=ABCMeta):
input_dict: Union[SampleBatch, Dict[str, TensorStructType]],
explore: bool = None,
timestep: Optional[int] = None,
episodes: Optional[List["MultiAgentEpisode"]] = None,
episodes: Optional[List["Episode"]] = None,
**kwargs) -> \
Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
"""Computes actions from collected samples (across multiple-agents).
@ -300,7 +300,7 @@ class Policy(metaclass=ABCMeta):
prev_reward_batch: Union[List[TensorStructType],
TensorStructType] = None,
info_batch: Optional[Dict[str, list]] = None,
episodes: Optional[List["MultiAgentEpisode"]] = None,
episodes: Optional[List["Episode"]] = None,
explore: Optional[bool] = None,
timestep: Optional[int] = None,
**kwargs) -> \
@ -313,7 +313,7 @@ class Policy(metaclass=ABCMeta):
prev_action_batch: Batch of previous action values.
prev_reward_batch: Batch of previous rewards.
info_batch: Batch of info objects.
episodes: List of MultiAgentEpisodes, one for each obs in
episodes: List of Episode objects, one for each obs in
obs_batch. This provides access to all of the internal
episode state, which may be useful for model-based or
multi-agent algorithms.
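A hedged call sketch for `compute_actions` (assuming `policy` is an already-built Policy over a 4-dim Box observation space):
import numpy as np

actions, state_outs, extra_fetches = policy.compute_actions(
    np.array([[0.1, 0.2, 0.3, 0.4]]), explore=False)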
@ -377,7 +377,7 @@ class Policy(metaclass=ABCMeta):
sample_batch: SampleBatch,
other_agent_batches: Optional[Dict[AgentID, Tuple[
"Policy", SampleBatch]]] = None,
episode: Optional["MultiAgentEpisode"] = None) -> SampleBatch:
episode: Optional["Episode"] = None) -> SampleBatch:
"""Implements algorithm-specific trajectory postprocessing.
This will be called on each trajectory fragment computed during policy

View file

@ -19,7 +19,7 @@ from ray.rllib.utils.typing import ModelGradients, TensorType, \
TrainerConfigDict
if TYPE_CHECKING:
from ray.rllib.evaluation import MultiAgentEpisode # noqa
from ray.rllib.evaluation.episode import Episode # noqa
jax, _ = try_import_jax()
torch, _ = try_import_torch()
@ -39,7 +39,7 @@ def build_policy_class(
str, TensorType]]] = None,
postprocess_fn: Optional[Callable[[
Policy, SampleBatch, Optional[Dict[Any, SampleBatch]], Optional[
"MultiAgentEpisode"]
"Episode"]
], SampleBatch]] = None,
extra_action_out_fn: Optional[Callable[[
Policy, Dict[str, TensorType], List[TensorType], ModelV2,
@ -98,7 +98,7 @@ def build_policy_class(
overrides. If None, uses only(!) the user-provided
PartialTrainerConfigDict as dict for this Policy.
postprocess_fn (Optional[Callable[[Policy, SampleBatch,
Optional[Dict[Any, SampleBatch]], Optional["MultiAgentEpisode"]],
Optional[Dict[Any, SampleBatch]], Optional["Episode"]],
SampleBatch]]): Optional callable for post-processing experience
batches (called after the super's `postprocess_trajectory` method).
stats_fn (Optional[Callable[[Policy, SampleBatch],
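A minimal custom `postprocess_fn` matching the annotated signature above might look like this (purely illustrative):
def my_postprocess_fn(policy, sample_batch, other_agent_batches=None,
                      episode=None):
    # No-op postprocessing; real implementations typically add
    # advantage/return columns to the batch here.
    return sample_batch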

View file

@ -28,7 +28,7 @@ from ray.rllib.utils.typing import LocalOptimizer, ModelGradients, \
TensorType, TrainerConfigDict
if TYPE_CHECKING:
from ray.rllib.evaluation import MultiAgentEpisode
from ray.rllib.evaluation import Episode
tf1, tf, tfv = try_import_tf()
logger = logging.getLogger(__name__)
@ -277,7 +277,7 @@ class TFPolicy(Policy):
input_dict: Union[SampleBatch, Dict[str, TensorType]],
explore: bool = None,
timestep: Optional[int] = None,
episodes: Optional[List["MultiAgentEpisode"]] = None,
episodes: Optional[List["Episode"]] = None,
**kwargs) -> \
Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
@ -308,7 +308,7 @@ class TFPolicy(Policy):
prev_action_batch: Union[List[TensorType], TensorType] = None,
prev_reward_batch: Union[List[TensorType], TensorType] = None,
info_batch: Optional[Dict[str, list]] = None,
episodes: Optional[List["MultiAgentEpisode"]] = None,
episodes: Optional[List["Episode"]] = None,
explore: Optional[bool] = None,
timestep: Optional[int] = None,
**kwargs):

View file

@ -18,7 +18,7 @@ from ray.rllib.utils.typing import AgentID, ModelGradients, TensorType, \
TrainerConfigDict
if TYPE_CHECKING:
from ray.rllib.evaluation import MultiAgentEpisode
from ray.rllib.evaluation import Episode
tf1, tf, tfv = try_import_tf()
@ -34,7 +34,7 @@ def build_tf_policy(
TrainerConfigDict]] = None,
postprocess_fn: Optional[Callable[[
Policy, SampleBatch, Optional[Dict[AgentID, SampleBatch]],
Optional["MultiAgentEpisode"]
Optional["Episode"]
], SampleBatch]] = None,
stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[
str, TensorType]]] = None,
@ -107,7 +107,7 @@ def build_tf_policy(
overrides. If None, uses only(!) the user-provided
PartialTrainerConfigDict as dict for this Policy.
postprocess_fn (Optional[Callable[[Policy, SampleBatch,
Optional[Dict[AgentID, SampleBatch]], MultiAgentEpisode], None]]):
Optional[Dict[AgentID, SampleBatch]], Episode], None]]):
Optional callable for post-processing experience batches (called
after the parent class' `postprocess_trajectory` method).
stats_fn (Optional[Callable[[Policy, SampleBatch],

View file

@ -31,7 +31,7 @@ from ray.rllib.utils.typing import ModelGradients, ModelWeights, TensorType, \
TensorStructType, TrainerConfigDict
if TYPE_CHECKING:
from ray.rllib.evaluation import MultiAgentEpisode # noqa
from ray.rllib.evaluation import Episode # noqa
torch, nn = try_import_torch()
@ -279,7 +279,7 @@ class TorchPolicy(Policy):
prev_reward_batch: Union[List[TensorStructType],
TensorStructType] = None,
info_batch: Optional[Dict[str, list]] = None,
episodes: Optional[List["MultiAgentEpisode"]] = None,
episodes: Optional[List["Episode"]] = None,
explore: Optional[bool] = None,
timestep: Optional[int] = None,
**kwargs) -> \

View file

@ -7,7 +7,7 @@ import ray
from ray.tune.registry import register_env
from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.rollout_worker import get_global_worker
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \
@ -336,9 +336,9 @@ class TestMultiAgentEnv(unittest.TestCase):
# Pretend we did a model-based rollout and want to return
# the extra trajectory.
env_id = episodes[0].env_id
fake_eps = MultiAgentEpisode(
episodes[0].policy_map, episodes[0].policy_mapping_fn,
lambda: None, lambda x: None, env_id)
fake_eps = Episode(episodes[0].policy_map,
episodes[0].policy_mapping_fn,
lambda: None, lambda x: None, env_id)
builder = get_global_worker().sampler.sample_collector
agent_id = "extra_0"
policy_id = "p1" # use p1 so we can easily check it