[RLlib; Docs overhaul] Docstring cleanup: Evaluation (#19783)

Mirror of https://github.com/vale981/ray (synced 2025-03-06 10:31:39 -05:00)
Commit: 9c73871da0 (parent: f2773267c7)
30 changed files with 1059 additions and 705 deletions
@@ -5,20 +5,20 @@ import gym
 import ray
 from ray.rllib.agents.ppo.ppo_tf_policy import ValueNetworkMixin
 from ray.rllib.policy.sample_batch import SampleBatch
+from ray.rllib.evaluation.episode import Episode
 from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
     Postprocessing
-from ray.rllib.policy.tf_policy_template import build_tf_policy
+from ray.rllib.models.action_dist import ActionDistribution
+from ray.rllib.models.modelv2 import ModelV2
+from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.tf_policy import LearningRateSchedule, \
     EntropyCoeffSchedule
+from ray.rllib.policy.tf_policy_template import build_tf_policy
 from ray.rllib.utils.annotations import Deprecated
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.tf_ops import explained_variance
-from ray.rllib.policy.policy import Policy
 from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \
     PolicyID, LocalOptimizer, ModelGradients
-from ray.rllib.models.action_dist import ActionDistribution
-from ray.rllib.models.modelv2 import ModelV2
-from ray.rllib.evaluation import MultiAgentEpisode

 tf1, tf, tfv = try_import_tf()

@@ -31,7 +31,7 @@ def postprocess_advantages(
         policy: Policy,
         sample_batch: SampleBatch,
         other_agent_batches: Optional[Dict[PolicyID, SampleBatch]] = None,
-        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
+        episode: Optional[Episode] = None) -> SampleBatch:

     return compute_gae_for_sample_batch(policy, sample_batch,
                                         other_agent_batches, episode)
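
Illustrative example (not part of this diff; the function name is made up): a user-defined
postprocessing hook using the new Episode type hint can delegate to
compute_gae_for_sample_batch() exactly like the policy code above does.

    from typing import Dict, Optional

    from ray.rllib.evaluation.episode import Episode
    from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch
    from ray.rllib.policy.policy import Policy
    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.utils.typing import PolicyID


    def my_postprocess_advantages(
            policy: Policy,
            sample_batch: SampleBatch,
            other_agent_batches: Optional[Dict[PolicyID, SampleBatch]] = None,
            episode: Optional[Episode] = None) -> SampleBatch:
        # Compute GAE advantages for the just-collected trajectory.
        return compute_gae_for_sample_batch(policy, sample_batch,
                                            other_agent_batches, episode)
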
@@ -3,7 +3,7 @@ from typing import Optional, Dict

 import ray
 from ray.rllib.agents.ppo.ppo_torch_policy import ValueNetworkMixin
-from ray.rllib.evaluation.episode import MultiAgentEpisode
+from ray.rllib.evaluation.episode import Episode
 from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
     Postprocessing
 from ray.rllib.models.action_dist import ActionDistribution
@@ -30,7 +30,7 @@ def add_advantages(
         policy: Policy,
         sample_batch: SampleBatch,
         other_agent_batches: Optional[Dict[PolicyID, SampleBatch]] = None,
-        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
+        episode: Optional[Episode] = None) -> SampleBatch:

     return compute_gae_for_sample_batch(policy, sample_batch,
                                         other_agent_batches, episode)
@@ -5,7 +5,7 @@ from typing import Dict, Optional, TYPE_CHECKING
 from ray.rllib.env import BaseEnv
 from ray.rllib.policy import Policy
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.evaluation import MultiAgentEpisode
+from ray.rllib.evaluation.episode import Episode
 from ray.rllib.utils.annotations import PublicAPI
 from ray.rllib.utils.deprecation import deprecation_warning
 from ray.rllib.utils.typing import AgentID, PolicyID
@@ -35,28 +35,23 @@ class DefaultCallbacks:
                 "a class extending rllib.agents.callbacks.DefaultCallbacks")
         self.legacy_callbacks = legacy_callbacks_dict or {}

-    def on_episode_start(self,
-                         *,
-                         worker: "RolloutWorker",
-                         base_env: BaseEnv,
-                         policies: Dict[PolicyID, Policy],
-                         episode: MultiAgentEpisode,
-                         env_index: Optional[int] = None,
+    def on_episode_start(self, *, worker: "RolloutWorker", base_env: BaseEnv,
+                         policies: Dict[PolicyID, Policy], episode: Episode,
                          **kwargs) -> None:
         """Callback run on the rollout worker before each episode starts.

         Args:
             worker: Reference to the current rollout worker.
             base_env: BaseEnv running the episode. The underlying
-                sub environment objects can be received by calling
+                sub environment objects can be retrieved by calling
                 `base_env.get_sub_environments()`.
             policies: Mapping of policy id to policy objects. In single
                 agent mode there will only be a single "default" policy.
-            episode (MultiAgentEpisode): Episode object which contains episode
+            episode: Episode object which contains the episode's
                 state. You can use the `episode.user_data` dict to store
                 temporary data, and `episode.custom_metrics` to store custom
                 metrics for the episode.
-            env_index (EnvID): Obsoleted: The ID of the environment, which the
+            env_index: Obsoleted: The ID of the environment, which the
                 episode belongs to.
             kwargs: Forward compatibility placeholder.
         """
@@ -73,20 +68,19 @@ class DefaultCallbacks:
                         worker: "RolloutWorker",
                         base_env: BaseEnv,
                         policies: Optional[Dict[PolicyID, Policy]] = None,
-                        episode: MultiAgentEpisode,
-                        env_index: Optional[int] = None,
+                        episode: Episode,
                         **kwargs) -> None:
         """Runs on each episode step.

         Args:
             worker (RolloutWorker): Reference to the current rollout worker.
             base_env (BaseEnv): BaseEnv running the episode. The underlying
-                sub environment objects can be gotten by calling
+                sub environment objects can be retrieved by calling
                 `base_env.get_sub_environments()`.
             policies (Optional[Dict[PolicyID, Policy]]): Mapping of policy id
                 to policy objects. In single agent mode there will only be a
                 single "default_policy".
-            episode (MultiAgentEpisode): Episode object which contains episode
+            episode (Episode): Episode object which contains episode
                 state. You can use the `episode.user_data` dict to store
                 temporary data, and `episode.custom_metrics` to store custom
                 metrics for the episode.
@@ -101,13 +95,8 @@ class DefaultCallbacks:
                 "episode": episode
             })

-    def on_episode_end(self,
-                       *,
-                       worker: "RolloutWorker",
-                       base_env: BaseEnv,
-                       policies: Dict[PolicyID, Policy],
-                       episode: MultiAgentEpisode,
-                       env_index: Optional[int] = None,
+    def on_episode_end(self, *, worker: "RolloutWorker", base_env: BaseEnv,
+                       policies: Dict[PolicyID, Policy], episode: Episode,
                        **kwargs) -> None:
         """Runs when an episode is done.

@@ -119,7 +108,7 @@ class DefaultCallbacks:
             policies (Dict[PolicyID, Policy]): Mapping of policy id to policy
                 objects. In single agent mode there will only be a single
                 "default_policy".
-            episode (MultiAgentEpisode): Episode object which contains episode
+            episode (Episode): Episode object which contains episode
                 state. You can use the `episode.user_data` dict to store
                 temporary data, and `episode.custom_metrics` to store custom
                 metrics for the episode.
@@ -136,7 +125,7 @@ class DefaultCallbacks:
             })

     def on_postprocess_trajectory(
-            self, *, worker: "RolloutWorker", episode: MultiAgentEpisode,
+            self, *, worker: "RolloutWorker", episode: Episode,
             agent_id: AgentID, policy_id: PolicyID,
             policies: Dict[PolicyID, Policy], postprocessed_batch: SampleBatch,
             original_batches: Dict[AgentID, SampleBatch], **kwargs) -> None:
@@ -148,7 +137,7 @@ class DefaultCallbacks:

         Args:
             worker (RolloutWorker): Reference to the current rollout worker.
-            episode (MultiAgentEpisode): Episode object.
+            episode (Episode): Episode object.
             agent_id (str): Id of the current agent.
             policy_id (str): Id of the current policy for the agent.
             policies (dict): Mapping of policy id to policy objects. In single
@@ -253,7 +242,7 @@ class MemoryTrackingCallbacks(DefaultCallbacks):
                          worker: "RolloutWorker",
                          base_env: BaseEnv,
                          policies: Dict[PolicyID, Policy],
-                         episode: MultiAgentEpisode,
+                         episode: Episode,
                          env_index: Optional[int] = None,
                          **kwargs) -> None:
         snapshot = tracemalloc.take_snapshot()
@@ -311,7 +300,7 @@ class MultiCallbacks(DefaultCallbacks):
                          worker: "RolloutWorker",
                          base_env: BaseEnv,
                          policies: Dict[PolicyID, Policy],
-                         episode: MultiAgentEpisode,
+                         episode: Episode,
                          env_index: Optional[int] = None,
                          **kwargs) -> None:
         for callback in self._callback_list:
@@ -328,7 +317,7 @@ class MultiCallbacks(DefaultCallbacks):
                         worker: "RolloutWorker",
                         base_env: BaseEnv,
                         policies: Optional[Dict[PolicyID, Policy]] = None,
-                        episode: MultiAgentEpisode,
+                        episode: Episode,
                         env_index: Optional[int] = None,
                         **kwargs) -> None:
         for callback in self._callback_list:
@@ -345,7 +334,7 @@ class MultiCallbacks(DefaultCallbacks):
                        worker: "RolloutWorker",
                        base_env: BaseEnv,
                        policies: Dict[PolicyID, Policy],
-                       episode: MultiAgentEpisode,
+                       episode: Episode,
                        env_index: Optional[int] = None,
                        **kwargs) -> None:
         for callback in self._callback_list:
@@ -358,7 +347,7 @@ class MultiCallbacks(DefaultCallbacks):
                 **kwargs)

     def on_postprocess_trajectory(
-            self, *, worker: "RolloutWorker", episode: MultiAgentEpisode,
+            self, *, worker: "RolloutWorker", episode: Episode,
             agent_id: AgentID, policy_id: PolicyID,
             policies: Dict[PolicyID, Policy], postprocessed_batch: SampleBatch,
             original_batches: Dict[AgentID, SampleBatch], **kwargs) -> None:
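
Illustrative example (not from this commit; the class and metric names are made up): a user
callback overriding the cleaned-up, keyword-only hooks above with the renamed Episode type.

    from typing import Dict, Optional

    from ray.rllib.agents.callbacks import DefaultCallbacks
    from ray.rllib.env import BaseEnv
    from ray.rllib.evaluation.episode import Episode
    from ray.rllib.policy import Policy
    from ray.rllib.utils.typing import PolicyID


    class MyCallbacks(DefaultCallbacks):
        def on_episode_start(self, *, worker, base_env: BaseEnv,
                             policies: Dict[PolicyID, Policy],
                             episode: Episode, **kwargs) -> None:
            # episode.user_data holds temporary, per-episode values.
            episode.user_data["num_steps"] = 0

        def on_episode_step(self, *, worker, base_env: BaseEnv,
                            policies: Optional[Dict[PolicyID, Policy]] = None,
                            episode: Episode = None, **kwargs) -> None:
            episode.user_data["num_steps"] += 1

        def on_episode_end(self, *, worker, base_env: BaseEnv,
                           policies: Dict[PolicyID, Policy],
                           episode: Episode, **kwargs) -> None:
            # episode.custom_metrics ends up in the reported training results.
            episode.custom_metrics["num_steps"] = episode.user_data["num_steps"]
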
@@ -5,7 +5,7 @@ from typing import Dict, Optional

 import ray
 from ray.rllib.agents.dreamer.utils import FreezeParameters
-from ray.rllib.evaluation import MultiAgentEpisode
+from ray.rllib.evaluation.episode import Episode
 from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.policy_template import build_policy_class
@@ -247,7 +247,7 @@ def preprocess_episode(
         policy: Policy,
         sample_batch: SampleBatch,
         other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
-        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
+        episode: Optional[Episode] = None) -> SampleBatch:
     """Batch format should be in the form of (s_t, a_(t-1), r_(t-1))
     When t=0, the resetted obs is paired with action and reward of 0.
     """
@@ -59,7 +59,7 @@ def postprocess_advantages(
         other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
             dict of AgentIDs mapping to other agents' trajectory data (from the
             same episode). NOTE: The other agents use the same policy.
-        episode (Optional[MultiAgentEpisode]): Optional multi-agent episode
+        episode (Optional[Episode]): Optional multi-agent episode
             object in which the agents operated.

     Returns:
@@ -1,6 +1,6 @@
 from typing import List, Optional

-from ray.rllib.evaluation.episode import MultiAgentEpisode
+from ray.rllib.evaluation.episode import Episode
 from ray.rllib.evaluation.postprocessing import compute_advantages
 from ray.rllib.policy import Policy
 from ray.rllib.policy.sample_batch import SampleBatch
@@ -10,7 +10,7 @@ def post_process_advantages(
         policy: Policy,
         sample_batch: SampleBatch,
         other_agent_batches: Optional[List[SampleBatch]] = None,
-        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
+        episode: Optional[Episode] = None) -> SampleBatch:
     """Adds the "advantages" column to `sample_batch`.

     Args:
@@ -18,7 +18,7 @@ def post_process_advantages(
         sample_batch (SampleBatch): The actual sample batch to post-process.
         other_agent_batches (Optional[List[SampleBatch]]): Optional list of
             other agents' SampleBatch objects.
-        episode (MultiAgentEpisode): The multi-agent episode object, from which
+        episode (Episode): The multi-agent episode object, from which
            `sample_batch` was generated.

     Returns:
@@ -13,7 +13,7 @@ from typing import Dict, List, Optional, Type, Union
 from ray.rllib.agents.impala import vtrace_tf as vtrace
 from ray.rllib.agents.impala.vtrace_tf_policy import _make_time_major, \
     clip_gradients, choose_optimizer
-from ray.rllib.evaluation.episode import MultiAgentEpisode
+from ray.rllib.evaluation.episode import Episode
 from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
     Postprocessing
 from ray.rllib.models.tf.tf_action_dist import Categorical
@@ -325,7 +325,7 @@ def postprocess_trajectory(
         policy: Policy,
         sample_batch: SampleBatch,
         other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
-        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
+        episode: Optional[Episode] = None) -> SampleBatch:
     """Postprocesses a trajectory and returns the processed trajectory.

     The trajectory contains only data from one episode and from one agent.
@@ -343,7 +343,7 @@ def postprocess_trajectory(
         other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
             dict of AgentIDs mapping to other agents' trajectory data (from the
             same episode). NOTE: The other agents use the same policy.
-        episode (Optional[MultiAgentEpisode]): Optional multi-agent episode
+        episode (Optional[Episode]): Optional multi-agent episode
             object in which the agents operated.

     Returns:
@@ -7,7 +7,7 @@ import logging
 from typing import Dict, List, Optional, Type, Union

 import ray
-from ray.rllib.evaluation.episode import MultiAgentEpisode
+from ray.rllib.evaluation.episode import Episode
 from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
     Postprocessing
 from ray.rllib.models.modelv2 import ModelV2
@@ -357,7 +357,7 @@ def postprocess_ppo_gae(
         policy: Policy,
         sample_batch: SampleBatch,
         other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
-        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
+        episode: Optional[Episode] = None) -> SampleBatch:

     return compute_gae_for_sample_batch(policy, sample_batch,
                                         other_agent_batches, episode)
@@ -16,7 +16,7 @@ from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \
     PRIO_WEIGHTS
 from ray.rllib.agents.sac.sac_tf_model import SACTFModel
 from ray.rllib.agents.sac.sac_torch_model import SACTorchModel
-from ray.rllib.evaluation.episode import MultiAgentEpisode
+from ray.rllib.evaluation.episode import Episode
 from ray.rllib.models import ModelCatalog, MODEL_DEFAULTS
 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.models.tf.tf_action_dist import Beta, Categorical, \
@@ -106,7 +106,7 @@ def postprocess_trajectory(
         policy: Policy,
         sample_batch: SampleBatch,
         other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
-        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
+        episode: Optional[Episode] = None) -> SampleBatch:
     """Postprocesses a trajectory and returns the processed trajectory.

     The trajectory contains only data from one episode and from one agent.
@@ -124,7 +124,7 @@ def postprocess_trajectory(
         other_agent_batches (Optional[Dict[AgentID, SampleBatch]]): Optional
             dict of AgentIDs mapping to other agents' trajectory data (from the
             same episode). NOTE: The other agents use the same policy.
-        episode (Optional[MultiAgentEpisode]): Optional multi-agent episode
+        episode (Optional[Episode]): Optional multi-agent episode
             object in which the agents operated.

     Returns:
@@ -19,7 +19,7 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from ray.rllib.env.utils import gym_env_creator
 from ray.rllib.evaluation.collectors.simple_list_collector import \
     SimpleListCollector
-from ray.rllib.evaluation.episode import MultiAgentEpisode
+from ray.rllib.evaluation.episode import Episode
 from ray.rllib.evaluation.metrics import collect_metrics
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
 from ray.rllib.evaluation.worker_set import WorkerSet
@@ -1086,7 +1086,7 @@ class Trainer(Trainable):
             full_fetch: bool = False,
             explore: Optional[bool] = None,
             timestep: Optional[int] = None,
-            episode: Optional[MultiAgentEpisode] = None,
+            episode: Optional[Episode] = None,
             unsquash_action: Optional[bool] = None,
             clip_action: Optional[bool] = None,

@@ -1240,7 +1240,7 @@ class Trainer(Trainable):
             full_fetch: bool = False,
             explore: Optional[bool] = None,
             timestep: Optional[int] = None,
-            episodes: Optional[List[MultiAgentEpisode]] = None,
+            episodes: Optional[List[Episode]] = None,
             unsquash_actions: Optional[bool] = None,
             clip_actions: Optional[bool] = None,
             # Deprecated.
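
Illustrative example (not from this diff): the episode/episodes arguments in the Trainer
signatures above are optional, so typical inference code does not need to construct an
Episode at all. The trainer choice and config values are assumptions for this sketch.

    import ray
    from ray.rllib.agents.ppo import PPOTrainer

    ray.init()
    trainer = PPOTrainer(env="CartPole-v0",
                         config={"framework": "tf", "num_workers": 0})
    obs = trainer.get_policy().observation_space.sample()
    # No Episode object passed; explore=False gives the deterministic action.
    action = trainer.compute_single_action(obs, explore=False)
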
@@ -1,4 +1,4 @@
-from ray.rllib.evaluation.episode import MultiAgentEpisode
+from ray.rllib.evaluation.episode import Episode, MultiAgentEpisode
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
 from ray.rllib.evaluation.sample_batch_builder import (
     SampleBatchBuilder, MultiAgentSampleBatchBuilder)
@@ -17,5 +17,6 @@ __all__ = [
     "AsyncSampler",
     "compute_advantages",
     "collect_metrics",
-    "MultiAgentEpisode",
+    "Episode",
+    "MultiAgentEpisode",  # Deprecated -> Use `Episode` instead.
 ]
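
Illustrative example (not from this diff): both names are now exported from
ray.rllib.evaluation, but new code should import Episode; MultiAgentEpisode only remains
as a deprecated alias (see the Episode class hunks further below).

    from ray.rllib.evaluation import Episode            # preferred
    from ray.rllib.evaluation import MultiAgentEpisode  # deprecated alias
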
@@ -2,7 +2,7 @@ from abc import abstractmethod, ABCMeta
 import logging
 from typing import Dict, List, Optional, TYPE_CHECKING, Union

-from ray.rllib.evaluation.episode import MultiAgentEpisode
+from ray.rllib.evaluation.episode import Episode
 from ray.rllib.policy.policy_map import PolicyMap
 from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
 from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, \
@@ -57,7 +57,7 @@ class SampleCollector(metaclass=ABCMeta):
         self.count_steps_by = count_steps_by

     @abstractmethod
-    def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
+    def add_init_obs(self, episode: Episode, agent_id: AgentID,
                      policy_id: PolicyID, t: int,
                      init_obs: TensorType) -> None:
         """Adds an initial obs (after reset) to this collector.
@@ -70,7 +70,7 @@ class SampleCollector(metaclass=ABCMeta):
         called for that same agent/episode-pair.

         Args:
-            episode (MultiAgentEpisode): The MultiAgentEpisode, for which we
+            episode (Episode): The Episode, for which we
                 are adding an Agent's initial observation.
             agent_id (AgentID): Unique id for the agent we are adding
                 values for.
@@ -126,11 +126,11 @@ class SampleCollector(metaclass=ABCMeta):
         raise NotImplementedError

     @abstractmethod
-    def episode_step(self, episode: MultiAgentEpisode) -> None:
+    def episode_step(self, episode: Episode) -> None:
         """Increases the episode step counter (across all agents) by one.

         Args:
-            episode (MultiAgentEpisode): Episode we are stepping through.
+            episode (Episode): Episode we are stepping through.
                 Useful for handling counting b/c it is called once across
                 all agents that are inside this episode.
         """
@@ -200,7 +200,7 @@ class SampleCollector(metaclass=ABCMeta):

     @abstractmethod
     def postprocess_episode(self,
-                            episode: MultiAgentEpisode,
+                            episode: Episode,
                             is_done: bool = False,
                             check_dones: bool = False,
                             build: bool = False) -> Optional[MultiAgentBatch]:
@@ -214,7 +214,7 @@ class SampleCollector(metaclass=ABCMeta):
         correctly added to the buffers.

         Args:
-            episode (MultiAgentEpisode): The Episode object for which
+            episode (Episode): The Episode object for which
                 to post-process data.
             is_done (bool): Whether the given episode is actually terminated
                 (all agents are done OR we hit a hard horizon). If True, the
@@ -8,7 +8,7 @@ from typing import Any, Dict, List, Tuple, TYPE_CHECKING, Union

 from ray.rllib.env.base_env import _DUMMY_AGENT_ID
 from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
-from ray.rllib.evaluation.episode import MultiAgentEpisode
+from ray.rllib.evaluation.episode import Episode
 from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.policy_map import PolicyMap
 from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
@@ -505,11 +505,11 @@ class SimpleListCollector(SampleCollector):
         # Maps episode ID to the (non-built) individual agent steps in this
         # episode.
         self.agent_steps: Dict[EpisodeID, int] = collections.defaultdict(int)
-        # Maps episode ID to MultiAgentEpisode.
-        self.episodes: Dict[EpisodeID, MultiAgentEpisode] = {}
+        # Maps episode ID to Episode.
+        self.episodes: Dict[EpisodeID, Episode] = {}

     @override(SampleCollector)
-    def episode_step(self, episode: MultiAgentEpisode) -> None:
+    def episode_step(self, episode: Episode) -> None:
         episode_id = episode.episode_id
         # In the rase case that an "empty" step is taken at the beginning of
         # the episode (none of the agents has an observation in the obs-dict
@@ -550,8 +550,8 @@ class SimpleListCollector(SampleCollector):
                     if not self.multiple_episodes_in_batch else ""))

     @override(SampleCollector)
-    def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
-                     env_id: EnvID, policy_id: PolicyID, t: int,
+    def add_init_obs(self, episode: Episode, agent_id: AgentID, env_id: EnvID,
+                     policy_id: PolicyID, t: int,
                      init_obs: TensorType) -> None:
         # Make sure our mappings are up to date.
         agent_key = (episode.episode_id, agent_id)
@@ -707,7 +707,7 @@ class SimpleListCollector(SampleCollector):
     @override(SampleCollector)
     def postprocess_episode(
             self,
-            episode: MultiAgentEpisode,
+            episode: Episode,
             is_done: bool = False,
             check_dones: bool = False,
             build: bool = False) -> Union[None, SampleBatch, MultiAgentBatch]:
@@ -834,7 +834,7 @@ class SimpleListCollector(SampleCollector):
         if build:
             return self._build_multi_agent_batch(episode)

-    def _build_multi_agent_batch(self, episode: MultiAgentEpisode) -> \
+    def _build_multi_agent_batch(self, episode: Episode) -> \
             Union[MultiAgentBatch, SampleBatch]:

         ma_batch = {}
@@ -1,11 +1,11 @@
 from collections import defaultdict
 import numpy as np
 import random
-from typing import List, Dict, Callable, Any, TYPE_CHECKING
+from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING

 from ray.rllib.env.base_env import _DUMMY_AGENT_ID
 from ray.rllib.policy.policy_map import PolicyMap
-from ray.rllib.utils.annotations import DeveloperAPI
+from ray.rllib.utils.annotations import Deprecated, DeveloperAPI
 from ray.rllib.utils.deprecation import deprecation_warning
 from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray
 from ray.rllib.utils.typing import SampleBatchType, AgentID, PolicyID, \
@@ -19,7 +19,7 @@ if TYPE_CHECKING:


 @DeveloperAPI
-class MultiAgentEpisode:
+class Episode:
     """Tracks the current state of a (possibly multi-agent) episode.

     Attributes:
@@ -53,15 +53,28 @@ class MultiAgentEpisode:
     def __init__(
             self,
             policies: PolicyMap,
-            policy_mapping_fn: Callable[
-                [AgentID, "MultiAgentEpisode", "RolloutWorker"], PolicyID],
+            policy_mapping_fn: Callable[[AgentID, "Episode", "RolloutWorker"],
+                                        PolicyID],
             batch_builder_factory: Callable[[],
                                             "MultiAgentSampleBatchBuilder"],
             extra_batch_callback: Callable[[SampleBatchType], None],
             env_id: EnvID,
             *,
-            worker: "RolloutWorker" = None,
+            worker: Optional["RolloutWorker"] = None,
     ):
+        """Initializes an Episode instance.
+
+        Args:
+            policies: The PolicyMap object (mapping PolicyIDs to Policy
+                objects) to use for determining, which policy is used for
+                which agent.
+            policy_mapping_fn: The mapping function mapping AgentIDs to
+                PolicyIDs.
+            batch_builder_factory:
+            extra_batch_callback:
+            env_id: The environment's ID in which this episode runs.
+            worker: The RolloutWorker instance, in which this episode runs.
+        """
         self.new_batch_builder: Callable[
             [], "MultiAgentSampleBatchBuilder"] = batch_builder_factory
         self.add_extra_batch: Callable[[SampleBatchType],
@@ -80,9 +93,8 @@ class MultiAgentEpisode:
         self.media: Dict[str, Any] = {}
         self.policy_map: PolicyMap = policies
         self._policies = self.policy_map  # backward compatibility
-        self.policy_mapping_fn: Callable[[
-            AgentID, "MultiAgentEpisode", "RolloutWorker"
-        ], PolicyID] = policy_mapping_fn
+        self.policy_mapping_fn: Callable[[AgentID, "Episode", "RolloutWorker"],
+                                         PolicyID] = policy_mapping_fn
         self._next_agent_index: int = 0
         self._agent_to_index: Dict[AgentID, int] = {}
         self._agent_to_policy: Dict[AgentID, PolicyID] = {}
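
Illustrative example (not from this commit): a policy_mapping_fn matching the new
Callable[[AgentID, "Episode", "RolloutWorker"], PolicyID] annotation above; the policy IDs
and mapping rule are made up.

    def policy_mapping_fn(agent_id, episode, worker, **kwargs):
        # Route agents whose ID ends in 0, 2 or 4 to policy_0, the rest to policy_1.
        return "policy_0" if str(agent_id).endswith(("0", "2", "4")) else "policy_1"
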
@@ -92,21 +104,11 @@ class MultiAgentEpisode:
         self._agent_to_last_done: Dict[AgentID, bool] = {}
         self._agent_to_last_info: Dict[AgentID, EnvInfoDict] = {}
         self._agent_to_last_action: Dict[AgentID, EnvActionType] = {}
-        self._agent_to_last_pi_info: Dict[AgentID, dict] = {}
+        self._agent_to_last_extra_action_outs: Dict[AgentID, dict] = {}
         self._agent_to_prev_action: Dict[AgentID, EnvActionType] = {}
         self._agent_reward_history: Dict[AgentID, List[int]] = defaultdict(
             list)

-    # TODO: Deprecated.
-    @property
-    def _policy_mapping_fn(self):
-        deprecation_warning(
-            old="MultiAgentEpisode._policy_mapping_fn",
-            new="MultiAgentEpisode.policy_mapping_fn",
-            error=False,
-        )
-        return self.policy_mapping_fn
-
     @DeveloperAPI
     def soft_reset(self) -> None:
         """Clears rewards and metrics, but retains RNN and other state.
@@ -126,15 +128,17 @@ class MultiAgentEpisode:

         If the agent is new, the policy mapping fn will be called to bind the
         agent to a policy for the duration of the entire episode (even if the
-        policy_mapping_fn is changed in the meantime).
+        policy_mapping_fn is changed in the meantime!).

         Args:
-            agent_id (AgentID): The agent ID to lookup the policy ID for.
+            agent_id: The agent ID to lookup the policy ID for.

         Returns:
-            PolicyID: The policy ID for the specified agent.
+            The policy ID for the specified agent.
         """

+        # Perform a new policy_mapping_fn lookup and bind AgentID for the
+        # duration of this episode to the returned PolicyID.
         if agent_id not in self._agent_to_policy:
             # Try new API: pass in agent_id and episode as named args.
             # New signature should be: (agent_id, episode, worker, **kwargs)
@@ -153,8 +157,11 @@ class MultiAgentEpisode:
                         self.policy_mapping_fn(agent_id)
                 else:
                     raise e
+        # Use already determined PolicyID.
         else:
             policy_id = self._agent_to_policy[agent_id]
+
+        # PolicyID not found in policy map -> Error.
         if policy_id not in self.policy_map:
             raise KeyError("policy_mapping_fn returned invalid policy id "
                            f"'{policy_id}'!")
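
Illustrative example (not from this commit): policy_for() performs the policy_mapping_fn
lookup once and then caches the result for the rest of the episode, so repeated calls
return the same PolicyID. The agent ID used here is an assumption.

    from ray.rllib.agents.callbacks import DefaultCallbacks
    from ray.rllib.evaluation.episode import Episode


    class PolicyLookupCallbacks(DefaultCallbacks):
        def on_episode_start(self, *, worker, base_env, policies,
                             episode: Episode, **kwargs) -> None:
            pid = episode.policy_for("agent_0")
            # Second lookup hits the per-episode cache -> same binding.
            assert episode.policy_for("agent_0") == pid
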
@@ -162,33 +169,70 @@ class MultiAgentEpisode:

     @DeveloperAPI
     def last_observation_for(
-            self, agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvObsType:
-        """Returns the last observation for the specified agent."""
+            self, agent_id: AgentID = _DUMMY_AGENT_ID) -> Optional[EnvObsType]:
+        """Returns the last observation for the specified AgentID.
+
+        Args:
+            agent_id: The agent's ID to get the last observation for.
+
+        Returns:
+            Last observation the specified AgentID has seen. None in case
+            the agent has never made any observations in the episode.
+        """

         return self._agent_to_last_obs.get(agent_id)

     @DeveloperAPI
-    def last_raw_obs_for(self,
-                         agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvObsType:
-        """Returns the last un-preprocessed obs for the specified agent."""
+    def last_raw_obs_for(
+            self, agent_id: AgentID = _DUMMY_AGENT_ID) -> Optional[EnvObsType]:
+        """Returns the last un-preprocessed obs for the specified AgentID.
+
+        Args:
+            agent_id: The agent's ID to get the last un-preprocessed
+                observation for.
+
+        Returns:
+            Last un-preprocessed observation the specified AgentID has seen.
+            None in case the agent has never made any observations in the
+            episode.
+        """
         return self._agent_to_last_raw_obs.get(agent_id)

     @DeveloperAPI
-    def last_info_for(self,
-                      agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvInfoDict:
-        """Returns the last info for the specified agent."""
+    def last_info_for(self, agent_id: AgentID = _DUMMY_AGENT_ID
+                      ) -> Optional[EnvInfoDict]:
+        """Returns the last info for the specified AgentID.
+
+        Args:
+            agent_id: The agent's ID to get the last info for.
+
+        Returns:
+            Last info dict the specified AgentID has seen.
+            None in case the agent has never made any observations in the
+            episode.
+        """
         return self._agent_to_last_info.get(agent_id)

     @DeveloperAPI
     def last_action_for(self,
                         agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvActionType:
-        """Returns the last action for the specified agent, or zeros."""
+        """Returns the last action for the specified AgentID, or zeros.
+
+        The "last" action is the most recent one taken by the agent.
+
+        Args:
+            agent_id: The agent's ID to get the last action for.
+
+        Returns:
+            Last action the specified AgentID has executed.
+            Zeros in case the agent has never performed any actions in the
+            episode.
+        """
+        # Agent has already taken at least one action in the episode.
         if agent_id in self._agent_to_last_action:
             return flatten_to_single_ndarray(
                 self._agent_to_last_action[agent_id])
+        # Agent has not acted yet, return all zeros.
         else:
             policy_id = self.policy_for(agent_id)
             policy = self.policy_map[policy_id]
@@ -200,29 +244,84 @@ class MultiAgentEpisode:
     @DeveloperAPI
     def prev_action_for(self,
                         agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvActionType:
-        """Returns the previous action for the specified agent."""
+        """Returns the previous action for the specified agent, or zeros.
+
+        The "previous" action is the one taken one timestep before the
+        most recent action taken by the agent.
+
+        Args:
+            agent_id: The agent's ID to get the previous action for.
+
+        Returns:
+            Previous action the specified AgentID has executed.
+            Zero in case the agent has never performed any actions (or only
+            one) in the episode.
+        """
+        # We are at t > 1 -> There has been a previous action by this agent.
         if agent_id in self._agent_to_prev_action:
             return flatten_to_single_ndarray(
                 self._agent_to_prev_action[agent_id])
+        # We're at t <= 1, so return all zeros.
         else:
-            # We're at t=0, so return all zeros.
             return np.zeros_like(self.last_action_for(agent_id))

     @DeveloperAPI
-    def prev_reward_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> float:
-        """Returns the previous reward for the specified agent."""
+    def last_reward_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> float:
+        """Returns the last reward for the specified agent, or zero.
+
+        The "last" reward is the one received most recently by the agent.
+
+        Args:
+            agent_id: The agent's ID to get the last reward for.
+
+        Returns:
+            Last reward for the the specified AgentID.
+            Zero in case the agent has never performed any actions
+            (and thus received rewards) in the episode.
+        """

         history = self._agent_reward_history[agent_id]
+        # We are at t > 0 -> Return previously received reward.
+        if len(history) >= 1:
+            return history[-1]
+        # We're at t=0, so there is no previous reward, just return zero.
+        else:
+            return 0.0
+
+    @DeveloperAPI
+    def prev_reward_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> float:
+        """Returns the previous reward for the specified agent, or zero.
+
+        The "previous" reward is the one received one timestep before the
+        most recently received reward of the agent.
+
+        Args:
+            agent_id: The agent's ID to get the previous reward for.
+
+        Returns:
+            Previous reward for the the specified AgentID.
+            Zero in case the agent has never performed any actions (or only
+            one) in the episode.
+        """
+
+        history = self._agent_reward_history[agent_id]
+        # We are at t > 1 -> Return reward prior to most recent (last) one.
         if len(history) >= 2:
             return history[-2]
+        # We're at t <= 1, so there is no previous reward, just return zero.
         else:
-            # We're at t=0, so there is no previous reward, just return zero.
             return 0.0

     @DeveloperAPI
     def rnn_state_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> List[Any]:
-        """Returns the last RNN state for the specified agent."""
+        """Returns the last RNN state for the specified agent.
+
+        Args:
+            agent_id: The agent's ID to get the most recent RNN state for.
+
+        Returns:
+            Most recent RNN state of the the specified AgentID.
+        """

         if agent_id not in self._agent_to_rnn_state:
             policy_id = self.policy_for(agent_id)
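
Illustrative sketch of the "last" vs. "previous" semantics documented above, reproduced on
a plain Python list (this is not RLlib code):

    def last_reward(history):
        # Mirrors Episode.last_reward_for(): most recent reward, or 0.0.
        return history[-1] if len(history) >= 1 else 0.0

    def prev_reward(history):
        # Mirrors Episode.prev_reward_for(): the reward one step earlier, or 0.0.
        return history[-2] if len(history) >= 2 else 0.0

    assert last_reward([1.0, 2.0, 3.0]) == 3.0
    assert prev_reward([1.0, 2.0, 3.0]) == 2.0
    assert last_reward([]) == 0.0 and prev_reward([1.0]) == 0.0
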
@@ -232,24 +331,44 @@ class MultiAgentEpisode:

     @DeveloperAPI
     def last_done_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> bool:
-        """Returns the last done flag received for the specified agent."""
+        """Returns the last done flag for the specified AgentID.
+
+        Args:
+            agent_id: The agent's ID to get the last done flag for.
+
+        Returns:
+            Last done flag for the specified AgentID.
+        """
         if agent_id not in self._agent_to_last_done:
             self._agent_to_last_done[agent_id] = False
         return self._agent_to_last_done[agent_id]

     @DeveloperAPI
-    def last_pi_info_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> dict:
-        """Returns the last info object for the specified agent."""
-
-        return self._agent_to_last_pi_info[agent_id]
+    def last_extra_action_outs_for(
+            self,
+            agent_id: AgentID = _DUMMY_AGENT_ID,
+    ) -> dict:
+        """Returns the last extra-action outputs for the specified agent.
+
+        This data is returned by a call to
+        `Policy.compute_actions_from_input_dict` as the 3rd return value
+        (1st return value = action; 2nd return value = RNN state outs).
+
+        Args:
+            agent_id: The agent's ID to get the last extra-action outs for.
+
+        Returns:
+            The last extra-action outs for the specified AgentID.
+        """
+        return self._agent_to_last_extra_action_outs[agent_id]

     @DeveloperAPI
     def get_agents(self) -> List[AgentID]:
         """Returns list of agent IDs that have appeared in this episode.

         Returns:
-            List[AgentID]: The list of all agents that have appeared so
-                far in this episode.
+            The list of all agent IDs that have appeared so far in this
+            episode.
         """
         return list(self._agent_to_index.keys())

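
Illustrative example (not from this commit): reading the renamed accessor from a callback.
Which keys appear in the returned dict depends on the policy (e.g. value-function
predictions for PPO); the agent ID is an assumption.

    from ray.rllib.agents.callbacks import DefaultCallbacks
    from ray.rllib.evaluation.episode import Episode


    class ExtraOutsLogger(DefaultCallbacks):
        def on_episode_step(self, *, worker, base_env, policies=None,
                            episode: Episode = None, **kwargs) -> None:
            # Extra action outs produced for this agent's most recent action.
            outs = episode.last_extra_action_outs_for("agent_0")
            episode.user_data.setdefault("extra_outs", []).append(outs)
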
@@ -282,11 +401,31 @@ class MultiAgentEpisode:
                 self._agent_to_last_action[agent_id]
         self._agent_to_last_action[agent_id] = action

-    def _set_last_pi_info(self, agent_id, pi_info):
-        self._agent_to_last_pi_info[agent_id] = pi_info
+    def _set_last_extra_action_outs(self, agent_id, pi_info):
+        self._agent_to_last_extra_action_outs[agent_id] = pi_info

     def _agent_index(self, agent_id):
         if agent_id not in self._agent_to_index:
             self._agent_to_index[agent_id] = self._next_agent_index
             self._next_agent_index += 1
         return self._agent_to_index[agent_id]
+
+    @property
+    def _policy_mapping_fn(self):
+        deprecation_warning(
+            old="Episode._policy_mapping_fn",
+            new="Episode.policy_mapping_fn",
+            error=False,
+        )
+        return self.policy_mapping_fn
+
+    @Deprecated(new="Episode.last_extra_action_outs_for", error=False)
+    def last_pi_info_for(self, *args, **kwargs):
+        return self.last_extra_action_outs_for(*args, **kwargs)
+
+
+# Backward compatibility. The name Episode implies that there is
+# also a (single agent?) Episode.
+@Deprecated(new="ray.rllib.evaluation.episode.Episode", error=False)
+class MultiAgentEpisode(Episode):
+    pass
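
Illustrative summary of the two shims above (not additional code from the commit):
MultiAgentEpisode is now just a deprecated subclass of Episode, and the old accessor name
is forwarded to the new one.

    from ray.rllib.evaluation.episode import Episode, MultiAgentEpisode

    assert issubclass(MultiAgentEpisode, Episode)
    # episode.last_pi_info_for(...) now simply forwards to
    # episode.last_extra_action_outs_for(...) and emits a deprecation warning.
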
@@ -1,14 +1,13 @@
+import collections
 import logging
 import numpy as np
-import collections
-from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
+from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

 import ray
 from ray import ObjectRef
 from ray.actor import ActorHandle
-from ray.rllib.evaluation.rollout_metrics import RolloutMetrics
-from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
 from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate
+from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
 from ray.rllib.utils.annotations import DeveloperAPI
 from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
 from ray.rllib.utils.typing import GradInfoDict, LearnerStatsDict, ResultDict
@@ -18,6 +17,17 @@ if TYPE_CHECKING:

 logger = logging.getLogger(__name__)

+RolloutMetrics = collections.namedtuple("RolloutMetrics", [
+    "episode_length",
+    "episode_reward",
+    "agent_rewards",
+    "custom_metrics",
+    "perf_stats",
+    "hist_data",
+    "media",
+])
+RolloutMetrics.__new__.__defaults__ = (0, 0, {}, {}, {}, {}, {})
+
+
 def extract_stats(stats: Dict, key: str) -> Dict[str, Any]:
     if key in stats:
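The `RolloutMetrics` namedtuple (moved here from the deleted `rollout_metrics.py` shown further below) defaults every field, so callers can build partially filled records. A quick sketch of how that behaves; the values are made up:

import collections

# Same shape as RLlib's RolloutMetrics after this change.
RolloutMetrics = collections.namedtuple("RolloutMetrics", [
    "episode_length", "episode_reward", "agent_rewards",
    "custom_metrics", "perf_stats", "hist_data", "media",
])
RolloutMetrics.__new__.__defaults__ = (0, 0, {}, {}, {}, {}, {})

# Only the fields you care about need to be passed; the rest default.
m = RolloutMetrics(episode_length=200, episode_reward=195.0)
print(m.agent_rewards)                       # -> {} (default)
print(m._replace(episode_reward=200.0))      # namedtuples are immutable
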
@@ -2,7 +2,7 @@ from typing import Dict

 from ray.rllib.env import BaseEnv
 from ray.rllib.policy import Policy
-from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
+from ray.rllib.evaluation import Episode, RolloutWorker
 from ray.rllib.utils.framework import TensorType
 from ray.rllib.utils.typing import AgentID, PolicyID

@@ -22,7 +22,7 @@ class ObservationFunction:

     def __call__(self, agent_obs: Dict[AgentID, TensorType],
                  worker: RolloutWorker, base_env: BaseEnv,
-                 policies: Dict[PolicyID, Policy], episode: MultiAgentEpisode,
+                 policies: Dict[PolicyID, Policy], episode: Episode,
                  **kw) -> Dict[AgentID, TensorType]:
        """Callback run on each environment step to observe the environment.

@@ -45,7 +45,7 @@ class ObservationFunction:
                retrieved by calling `base_env.get_sub_environments()`.
            policies (dict): Mapping of policy id to policy objects. In single
                agent mode there will only be a single "default" policy.
-           episode (MultiAgentEpisode): Episode state object.
+           episode (Episode): Episode state object.
            kwargs: Forward compatibility placeholder.

            Returns:
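For user code, the only change required by this rename is the type hint. A hedged sketch of a custom observation function written against the new signature; the pass-through behavior is just an illustrative choice:

from typing import Dict

from ray.rllib.env import BaseEnv
from ray.rllib.evaluation import Episode, RolloutWorker
from ray.rllib.evaluation.observation_function import ObservationFunction
from ray.rllib.policy import Policy
from ray.rllib.utils.framework import TensorType
from ray.rllib.utils.typing import AgentID, PolicyID


class PassThroughObs(ObservationFunction):
    """Illustrative: returns observations unchanged, but has full access
    to the (renamed) Episode object for per-episode bookkeeping."""

    def __call__(self, agent_obs: Dict[AgentID, TensorType],
                 worker: RolloutWorker, base_env: BaseEnv,
                 policies: Dict[PolicyID, Policy], episode: Episode,
                 **kw) -> Dict[AgentID, TensorType]:
        # e.g. episode.length and episode.episode_id are available here.
        return agent_obs
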
@@ -2,7 +2,7 @@ import numpy as np
 import scipy.signal
 from typing import Dict, Optional

-from ray.rllib.evaluation.episode import MultiAgentEpisode
+from ray.rllib.evaluation.episode import Episode
 from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.annotations import DeveloperAPI
@@ -135,7 +135,7 @@ def compute_gae_for_sample_batch(
         policy: Policy,
         sample_batch: SampleBatch,
         other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
-        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
+        episode: Optional[Episode] = None) -> SampleBatch:
     """Adds GAE (generalized advantage estimations) to a trajectory.

     The trajectory contains only data from one episode and from one agent.
@@ -153,7 +153,7 @@ def compute_gae_for_sample_batch(
         other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
             dict of AgentIDs mapping to other agents' trajectory data (from the
             same episode). NOTE: The other agents use the same policy.
-        episode (Optional[MultiAgentEpisode]): Optional multi-agent episode
+        episode (Optional[Episode]): Optional multi-agent episode
            object in which the agents operated.

        Returns:
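For reference, the quantity this function adds to the batch (under `Postprocessing.ADVANTAGES`) is the standard GAE estimator; `gamma` and `lambda` are the usual values from the policy config, and when GAE is used the value targets are simply advantages plus the value predictions:

\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t), \qquad
\hat{A}_t = \delta_t + \gamma \lambda \, \hat{A}_{t+1}, \qquad
\texttt{advantages}[t] = \hat{A}_t, \quad
\texttt{value\_targets}[t] = \hat{A}_t + V(s_t)
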
@@ -1,13 +0,0 @@
-import collections
-
-# Define this in its own file, see #5125
-RolloutMetrics = collections.namedtuple("RolloutMetrics", [
-    "episode_length",
-    "episode_reward",
-    "agent_rewards",
-    "custom_metrics",
-    "perf_stats",
-    "hist_data",
-    "media",
-])
-RolloutMetrics.__new__.__defaults__ = (0, 0, {}, {}, {}, {}, {})
File diff suppressed because it is too large
@@ -4,7 +4,7 @@ import numpy as np
 from typing import List, Any, Dict, Optional, TYPE_CHECKING

 from ray.rllib.env.base_env import _DUMMY_AGENT_ID
-from ray.rllib.evaluation.episode import MultiAgentEpisode
+from ray.rllib.evaluation.episode import Episode
 from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
 from ray.rllib.utils.annotations import Deprecated, DeveloperAPI
@@ -152,15 +152,15 @@ class MultiAgentSampleBatchBuilder:

         self.agent_builders[agent_id].add_values(**values)

-    def postprocess_batch_so_far(
-            self, episode: Optional[MultiAgentEpisode] = None) -> None:
+    def postprocess_batch_so_far(self,
+                                 episode: Optional[Episode] = None) -> None:
         """Apply policy postprocessors to any unprocessed rows.

         This pushes the postprocessed per-agent batches onto the per-policy
         builders, clearing per-agent state.

         Args:
-            episode (Optional[MultiAgentEpisode]): The Episode object that
+            episode (Optional[Episode]): The Episode object that
                 holds this MultiAgentBatchBuilder object.
         """

@@ -234,15 +234,15 @@ class MultiAgentSampleBatchBuilder:
                 "Alternatively, set no_done_at_end=True to allow this.")

     @DeveloperAPI
-    def build_and_reset(self, episode: Optional[MultiAgentEpisode] = None
-                        ) -> MultiAgentBatch:
+    def build_and_reset(self,
+                        episode: Optional[Episode] = None) -> MultiAgentBatch:
         """Returns the accumulated sample batches for each policy.

         Any unprocessed rows will be first postprocessed with a policy
         postprocessor. The internal state of this builder will be reset.

         Args:
-            episode (Optional[MultiAgentEpisode]): The Episode object that
+            episode (Optional[Episode]): The Episode object that
                 holds this MultiAgentBatchBuilder object or None.

         Returns:
@@ -14,8 +14,8 @@ from ray.rllib.evaluation.collectors.sample_collector import \
     SampleCollector
 from ray.rllib.evaluation.collectors.simple_list_collector import \
     SimpleListCollector
-from ray.rllib.evaluation.episode import MultiAgentEpisode
-from ray.rllib.evaluation.rollout_metrics import RolloutMetrics
+from ray.rllib.evaluation.episode import Episode
+from ray.rllib.evaluation.metrics import RolloutMetrics
 from ray.rllib.evaluation.sample_batch_builder import \
     MultiAgentSampleBatchBuilder
 from ray.rllib.env.base_env import BaseEnv, ASYNC_RESET_RETURN
@@ -110,16 +110,45 @@ class SamplerInput(InputReader, metaclass=ABCMeta):
     @abstractmethod
     @DeveloperAPI
     def get_data(self) -> SampleBatchType:
+        """Called by `self.next()` to return the next batch of data.
+
+        Override this in child classes.
+
+        Returns:
+            The next batch of data.
+        """
         raise NotImplementedError

     @abstractmethod
     @DeveloperAPI
     def get_metrics(self) -> List[RolloutMetrics]:
+        """Returns list of episode metrics since the last call to this method.
+
+        The list will contain one RolloutMetrics object per completed episode.
+
+        Returns:
+            List of RolloutMetrics objects, one per completed episode since
+            the last call to this method.
+        """
         raise NotImplementedError

     @abstractmethod
     @DeveloperAPI
     def get_extra_batches(self) -> List[SampleBatchType]:
+        """Returns list of extra batches since the last call to this method.
+
+        The list will contain all SampleBatches or
+        MultiAgentBatches that the user has provided thus-far. Users can
+        add these "extra batches" to an episode by calling the episode's
+        `add_extra_batch([SampleBatchType])` method. This can be done from
+        inside an overridden `Policy.compute_actions_from_input_dict(...,
+        episodes)` or from a custom callback's `on_episode_[start|step|end]()`
+        methods.
+
+        Returns:
+            List of SamplesBatches or MultiAgentBatches provided thus-far by
+            the user since the last call to this method.
+        """
         raise NotImplementedError

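As the new `get_extra_batches()` docstring describes, user code can attach additional batches to an episode from a callback, and the sampler then surfaces them through this method. A hedged sketch; the column names and values are placeholders, and whether a given algorithm actually consumes such extra batches depends on its execution plan:

import numpy as np

from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.policy.sample_batch import SampleBatch


class AddExtraBatchCallbacks(DefaultCallbacks):
    def on_episode_end(self, *, worker, base_env, policies, episode,
                       env_index, **kwargs):
        # Hypothetical extra data fed back into the sampler's
        # `get_extra_batches()` stream.
        extra = SampleBatch({
            "obs": np.zeros((1, 4), dtype=np.float32),
            "custom_flag": np.array([1.0], dtype=np.float32),
        })
        episode.add_extra_batch(extra)
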
@@ -143,7 +172,7 @@ class SyncSampler(SamplerInput):
             clip_actions: bool = False,
             soft_horizon: bool = False,
             no_done_at_end: bool = False,
-            observation_fn: "ObservationFunction" = None,
+            observation_fn: Optional["ObservationFunction"] = None,
             sample_collector_class: Optional[Type[SampleCollector]] = None,
             render: bool = False,
             # Obsolete.
@@ -153,41 +182,44 @@ class SyncSampler(SamplerInput):
             obs_filters=None,
             tf_sess=None,
     ):
-        """Initializes a SyncSampler object.
+        """Initializes a SyncSampler instance.

         Args:
-            worker (RolloutWorker): The RolloutWorker that will use this
-                Sampler for sampling.
-            env (Env): Any Env object. Will be converted into an RLlib BaseEnv.
-            clip_rewards (Union[bool, float]): True for +/-1.0 clipping,
+            worker: The RolloutWorker that will use this Sampler for sampling.
+            env: Any Env object. Will be converted into an RLlib BaseEnv.
+            clip_rewards: True for +/-1.0 clipping,
                 actual float value for +/- value clipping. False for no
                 clipping.
-            rollout_fragment_length (int): The length of a fragment to collect
+            rollout_fragment_length: The length of a fragment to collect
                 before building a SampleBatch from the data and resetting
                 the SampleBatchBuilder object.
-            callbacks (Callbacks): The Callbacks object to use when episode
+            count_steps_by: One of "env_steps" (default) or "agent_steps".
+                Use "agent_steps", if you want rollout lengths to be counted
+                by individual agent steps. In a multi-agent env,
+                a single env_step contains one or more agent_steps, depending
+                on how many agents are present at any given time in the
+                ongoing episode.
+            callbacks: The Callbacks object to use when episode
                 events happen during rollout.
-            horizon (Optional[int]): Hard-reset the Env
-            multiple_episodes_in_batch (bool): Whether to pack multiple
+            horizon: Hard-reset the Env after this many timesteps.
+            multiple_episodes_in_batch: Whether to pack multiple
                 episodes into each batch. This guarantees batches will be
                 exactly `rollout_fragment_length` in size.
-            normalize_actions (bool): Whether to normalize actions to the
+            normalize_actions: Whether to normalize actions to the
                 action space's bounds.
-            clip_actions (bool): Whether to clip actions according to the
+            clip_actions: Whether to clip actions according to the
                 given action_space's bounds.
-            soft_horizon (bool): If True, calculate bootstrapped values as if
+            soft_horizon: If True, calculate bootstrapped values as if
                 episode had ended, but don't physically reset the environment
                 when the horizon is hit.
-            no_done_at_end (bool): Ignore the done=True at the end of the
+            no_done_at_end: Ignore the done=True at the end of the
                 episode and instead record done=False.
-            observation_fn (Optional[ObservationFunction]): Optional
-                multi-agent observation func to use for preprocessing
-                observations.
-            sample_collector_class (Optional[Type[SampleCollector]]): An
-                optional Samplecollector sub-class to use to collect, store,
-                and retrieve environment-, model-, and sampler data.
-            render (bool): Whether to try to render the environment after each
-                step.
+            observation_fn: Optional multi-agent observation func to use for
+                preprocessing observations.
+            sample_collector_class: An optional Samplecollector sub-class to
+                use to collect, store, and retrieve environment-, model-,
+                and sampler data.
+            render: Whether to try to render the environment after each step.
         """
         # All of the following arguments are deprecated. They will instead be
         # provided via the passed in `worker` arg, e.g. `worker.policy_map`.
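The newly documented `count_steps_by` argument deserves a concrete illustration: counting fragments in agent steps fills a fragment much faster in multi-agent envs than counting env steps. A tiny sketch of the arithmetic; the numbers are made up:

# Suppose 3 agents act on every env step and rollout_fragment_length = 200.
num_agents = 3
rollout_fragment_length = 200

# count_steps_by="env_steps": the fragment is full after 200 env steps,
# i.e. roughly 600 agent transitions end up in the batch.
env_steps_needed = rollout_fragment_length               # 200
agent_steps_collected = env_steps_needed * num_agents    # 600

# count_steps_by="agent_steps": the fragment is full after ~200 agent
# transitions, i.e. only ~67 env steps.
agent_steps_needed = rollout_fragment_length             # 200
env_steps_roughly = -(-agent_steps_needed // num_agents)  # ceil -> 67
print(agent_steps_collected, env_steps_roughly)
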
@@ -262,8 +294,9 @@ class SyncSampler(SamplerInput):
 class AsyncSampler(threading.Thread, SamplerInput):
     """Async SamplerInput that collects experiences in thread and queues them.

-    Once started, experiences are continuously collected and put into a Queue,
-    from where they can be unqueued by the caller of `get_data()`.
+    Once started, experiences are continuously collected in the background
+    and put into a Queue, from where they can be unqueued by the caller
+    of `get_data()`.
     """

     def __init__(
@@ -279,12 +312,12 @@ class AsyncSampler(threading.Thread, SamplerInput):
             multiple_episodes_in_batch: bool = False,
             normalize_actions: bool = True,
             clip_actions: bool = False,
-            blackhole_outputs: bool = False,
             soft_horizon: bool = False,
             no_done_at_end: bool = False,
             observation_fn: Optional["ObservationFunction"] = None,
             sample_collector_class: Optional[Type[SampleCollector]] = None,
             render: bool = False,
+            blackhole_outputs: bool = False,
             # Obsolete.
             policies=None,
             policy_mapping_fn=None,
@@ -292,45 +325,44 @@ class AsyncSampler(threading.Thread, SamplerInput):
             obs_filters=None,
             tf_sess=None,
     ):
-        """Initializes a AsyncSampler object.
+        """Initializes an AsyncSampler instance.

         Args:
-            worker (RolloutWorker): The RolloutWorker that will use this
-                Sampler for sampling.
-            env (Env): Any Env object. Will be converted into an RLlib BaseEnv.
-            clip_rewards (Union[bool, float]): True for +/-1.0 clipping,
+            worker: The RolloutWorker that will use this Sampler for sampling.
+            env: Any Env object. Will be converted into an RLlib BaseEnv.
+            clip_rewards: True for +/-1.0 clipping,
                 actual float value for +/- value clipping. False for no
                 clipping.
-            rollout_fragment_length (int): The length of a fragment to collect
+            rollout_fragment_length: The length of a fragment to collect
                 before building a SampleBatch from the data and resetting
                 the SampleBatchBuilder object.
-            count_steps_by (str): Either "env_steps" or "agent_steps".
-                Refers to the unit of `rollout_fragment_length`.
-            callbacks (Callbacks): The Callbacks object to use when episode
-                events happen during rollout.
+            count_steps_by: One of "env_steps" (default) or "agent_steps".
+                Use "agent_steps", if you want rollout lengths to be counted
+                by individual agent steps. In a multi-agent env,
+                a single env_step contains one or more agent_steps, depending
+                on how many agents are present at any given time in the
+                ongoing episode.
             horizon: Hard-reset the Env after this many timesteps.
-            multiple_episodes_in_batch (bool): Whether to pack multiple
+            multiple_episodes_in_batch: Whether to pack multiple
                 episodes into each batch. This guarantees batches will be
                 exactly `rollout_fragment_length` in size.
-            normalize_actions (bool): Whether to normalize actions to the
+            normalize_actions: Whether to normalize actions to the
                 action space's bounds.
-            clip_actions (bool): Whether to clip actions according to the
+            clip_actions: Whether to clip actions according to the
                 given action_space's bounds.
-            blackhole_outputs (bool): Whether to collect samples, but then
+            blackhole_outputs: Whether to collect samples, but then
                 not further process or store them (throw away all samples).
-            soft_horizon (bool): If True, calculate bootstrapped values as if
+            soft_horizon: If True, calculate bootstrapped values as if
                 episode had ended, but don't physically reset the environment
                 when the horizon is hit.
-            no_done_at_end (bool): Ignore the done=True at the end of the
+            no_done_at_end: Ignore the done=True at the end of the
                 episode and instead record done=False.
-            observation_fn (Optional[ObservationFunction]): Optional
-                multi-agent observation func to use for preprocessing
-                observations.
-            sample_collector_class (Optional[Type[SampleCollector]]): An
-                optional Samplecollector sub-class to use to collect, store,
-                and retrieve environment-, model-, and sampler data.
-            render (bool): Whether to try to render the environment after each
-                step.
+            observation_fn: Optional multi-agent observation func to use for
+                preprocessing observations.
+            sample_collector_class: An optional SampleCollector sub-class to
+                use to collect, store, and retrieve environment-, model-,
+                and sampler data.
+            render: Whether to try to render the environment after each step.
         """
         # All of the following arguments are deprecated. They will instead be
         # provided via the passed in `worker` arg, e.g. `worker.policy_map`.
@@ -467,32 +499,32 @@ def _env_runner(
     """This implements the common experience collection logic.

     Args:
-        worker (RolloutWorker): Reference to the current rollout worker.
-        base_env (BaseEnv): Env implementing BaseEnv.
-        extra_batch_callback (fn): function to send extra batch data to.
+        worker: Reference to the current rollout worker.
+        base_env: Env implementing BaseEnv.
+        extra_batch_callback: function to send extra batch data to.
         horizon: Horizon of the episode.
-        multiple_episodes_in_batch (bool): Whether to pack multiple
+        multiple_episodes_in_batch: Whether to pack multiple
             episodes into each batch. This guarantees batches will be exactly
             `rollout_fragment_length` in size.
-        normalize_actions (bool): Whether to normalize actions to the action
+        normalize_actions: Whether to normalize actions to the action
             space's bounds.
-        clip_actions (bool): Whether to clip actions to the space range.
-        callbacks (DefaultCallbacks): User callbacks to run on episode events.
-        perf_stats (_PerfStats): Record perf stats into this object.
-        soft_horizon (bool): Calculate rewards but don't reset the
+        clip_actions: Whether to clip actions to the space range.
+        callbacks: User callbacks to run on episode events.
+        perf_stats: Record perf stats into this object.
+        soft_horizon: Calculate rewards but don't reset the
             environment when the horizon is hit.
-        no_done_at_end (bool): Ignore the done=True at the end of the episode
+        no_done_at_end: Ignore the done=True at the end of the episode
             and instead record done=False.
-        observation_fn (ObservationFunction): Optional multi-agent
+        observation_fn: Optional multi-agent
             observation func to use for preprocessing observations.
-        sample_collector (Optional[SampleCollector]): An optional
+        sample_collector: An optional
             SampleCollector object to use.
-        render (bool): Whether to try to render the environment after each
+        render: Whether to try to render the environment after each
             step.

     Yields:
-        rollout (SampleBatch): Object containing state, action, reward,
-            terminal condition, and other fields as dictated by `policy`.
+        Object containing state, action, reward, terminal condition,
+        and other fields as dictated by `policy`.
     """

     # May be populated with used for image rendering
@@ -546,7 +578,7 @@ def _env_runner(
         return None

     def new_episode(env_id):
-        episode = MultiAgentEpisode(
+        episode = Episode(
             worker.policy_map,
             worker.policy_mapping_fn,
             get_batch_builder,
@@ -576,7 +608,7 @@ def _env_runner(
         )
         return episode

-    active_episodes: Dict[EnvID, MultiAgentEpisode] = \
+    active_episodes: Dict[EnvID, Episode] = \
         NewEpisodeDefaultDict(new_episode)

     while True:
@@ -684,7 +716,7 @@ def _process_observations(
         *,
         worker: "RolloutWorker",
         base_env: BaseEnv,
-        active_episodes: Dict[EnvID, MultiAgentEpisode],
+        active_episodes: Dict[EnvID, Episode],
         unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
         rewards: Dict[EnvID, Dict[AgentID, float]],
         dones: Dict[EnvID, Dict[AgentID, bool]],
@@ -701,38 +733,37 @@ def _process_observations(
     """Record new data from the environment and prepare for policy evaluation.

     Args:
-        worker (RolloutWorker): Reference to the current rollout worker.
-        base_env (BaseEnv): Env implementing BaseEnv.
-        active_episodes (Dict[EnvID, MultiAgentEpisode]): Mapping from
-            episode ID to currently ongoing MultiAgentEpisode object.
-        unfiltered_obs (dict): Doubly keyed dict of env-ids -> agent ids
+        worker: Reference to the current rollout worker.
+        base_env: Env implementing BaseEnv.
+        active_episodes: Mapping from
+            episode ID to currently ongoing Episode object.
+        unfiltered_obs: Doubly keyed dict of env-ids -> agent ids
             -> unfiltered observation tensor, returned by a `BaseEnv.poll()`
             call.
-        rewards (dict): Doubly keyed dict of env-ids -> agent ids ->
+        rewards: Doubly keyed dict of env-ids -> agent ids ->
            rewards tensor, returned by a `BaseEnv.poll()` call.
-        dones (dict): Doubly keyed dict of env-ids -> agent ids ->
+        dones: Doubly keyed dict of env-ids -> agent ids ->
            boolean done flags, returned by a `BaseEnv.poll()` call.
-        infos (dict): Doubly keyed dict of env-ids -> agent ids ->
+        infos: Doubly keyed dict of env-ids -> agent ids ->
            info dicts, returned by a `BaseEnv.poll()` call.
-        horizon (int): Horizon of the episode.
-        multiple_episodes_in_batch (bool): Whether to pack multiple
+        horizon: Horizon of the episode.
+        multiple_episodes_in_batch: Whether to pack multiple
            episodes into each batch. This guarantees batches will be exactly
            `rollout_fragment_length` in size.
-        callbacks (DefaultCallbacks): User callbacks to run on episode events.
-        soft_horizon (bool): Calculate rewards but don't reset the
+        callbacks: User callbacks to run on episode events.
+        soft_horizon: Calculate rewards but don't reset the
            environment when the horizon is hit.
-        no_done_at_end (bool): Ignore the done=True at the end of the episode
+        no_done_at_end: Ignore the done=True at the end of the episode
            and instead record done=False.
-        observation_fn (ObservationFunction): Optional multi-agent
+        observation_fn: Optional multi-agent
            observation func to use for preprocessing observations.
-        sample_collector (SampleCollector): The SampleCollector object
+        sample_collector: The SampleCollector object
            used to store and retrieve environment samples.

     Returns:
-        Tuple:
-            - active_envs: Set of non-terminated env ids.
-            - to_eval: Map of policy_id to list of agent PolicyEvalData.
-            - outputs: List of metrics and samples to return from the sampler.
+        Tuple consisting of 1) active_envs: Set of non-terminated env ids.
+        2) to_eval: Map of policy_id to list of agent PolicyEvalData.
+        3) outputs: List of metrics and samples to return from the sampler.
     """

     # Output objects.
@@ -744,7 +775,7 @@ def _process_observations(
     # types: EnvID, Dict[AgentID, EnvObsType]
     for env_id, all_agents_obs in unfiltered_obs.items():
         is_new_episode: bool = env_id not in active_episodes
-        episode: MultiAgentEpisode = active_episodes[env_id]
+        episode: Episode = active_episodes[env_id]

         if not is_new_episode:
             sample_collector.episode_step(episode)
@@ -854,9 +885,11 @@ def _process_observations(
                 # Next observation.
                 SampleBatch.NEXT_OBS: filtered_obs,
             }
-            # Add extra-action-fetches to collectors.
+            # Add extra-action-fetches (policy-inference infos) to
+            # collectors.
             pol = worker.policy_map[policy_id]
-            for key, value in episode.last_pi_info_for(agent_id).items():
+            for key, value in episode.last_extra_action_outs_for(
+                    agent_id).items():
                 if key in pol.view_requirements:
                     values_dict[key] = value
             # Env infos for this agent.
@@ -947,7 +980,7 @@ def _process_observations(
         # Creates a new episode if this is not async return.
         # If reset is async, we will get its result in some future poll.
         elif resetted_obs != ASYNC_RESET_RETURN:
-            new_episode: MultiAgentEpisode = active_episodes[env_id]
+            new_episode: Episode = active_episodes[env_id]
             if observation_fn:
                 resetted_obs: Dict[AgentID, EnvObsType] = observation_fn(
                     agent_obs=resetted_obs,
@@ -995,21 +1028,20 @@ def _do_policy_eval(
         to_eval: Dict[PolicyID, List[PolicyEvalData]],
         policies: PolicyMap,
         sample_collector,
-        active_episodes: Dict[EnvID, MultiAgentEpisode],
+        active_episodes: Dict[EnvID, Episode],
 ) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
     """Call compute_actions on collected episode/model data to get next action.

     Args:
-        to_eval (Dict[PolicyID, List[PolicyEvalData]]): Mapping of policy
-            IDs to lists of PolicyEvalData objects (items in these lists will
-            be the batch's items for the model forward pass).
-        policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy
-            obj.
-        sample_collector (SampleCollector): The SampleCollector object to use.
-        active_episodes (Dict[EnvID, MultiAgentEpisode]): Mapping of
+        to_eval: Mapping of policy IDs to lists of PolicyEvalData objects
+            (items in these lists will be the batch's items for the model
+            forward pass).
+        policies: Mapping from policy ID to Policy obj.
+        sample_collector: The SampleCollector object to use.
+        active_episodes: Mapping of EnvID to its currently active episode.

     Returns:
-        eval_results: dict of policy to compute_action() outputs.
+        Dict mapping PolicyIDs to compute_actions_from_input_dict() outputs.
     """

     eval_results: Dict[PolicyID, TensorStructType] = {}
|
||||||
to_eval: Dict[PolicyID, List[PolicyEvalData]],
|
to_eval: Dict[PolicyID, List[PolicyEvalData]],
|
||||||
eval_results: Dict[PolicyID, Tuple[TensorStructType, StateBatch,
|
eval_results: Dict[PolicyID, Tuple[TensorStructType, StateBatch,
|
||||||
dict]],
|
dict]],
|
||||||
active_episodes: Dict[EnvID, MultiAgentEpisode],
|
active_episodes: Dict[EnvID, Episode],
|
||||||
active_envs: Set[int],
|
active_envs: Set[int],
|
||||||
off_policy_actions: MultiEnvDict,
|
off_policy_actions: MultiEnvDict,
|
||||||
policies: Dict[PolicyID, Policy],
|
policies: Dict[PolicyID, Policy],
|
||||||
|
@ -1065,24 +1097,22 @@ def _process_policy_eval_results(
|
||||||
returns replies to send back to agents in the env.
|
returns replies to send back to agents in the env.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
to_eval (Dict[PolicyID, List[PolicyEvalData]]): Mapping of policy IDs
|
to_eval: Mapping of policy IDs to lists of PolicyEvalData objects.
|
||||||
to lists of PolicyEvalData objects.
|
eval_results: Mapping of policy IDs to list of
|
||||||
eval_results (Dict[PolicyID, List]): Mapping of policy IDs to list of
|
|
||||||
actions, rnn-out states, extra-action-fetches dicts.
|
actions, rnn-out states, extra-action-fetches dicts.
|
||||||
active_episodes (Dict[EnvID, MultiAgentEpisode]): Mapping from
|
active_episodes: Mapping from episode ID to currently ongoing
|
||||||
episode ID to currently ongoing MultiAgentEpisode object.
|
Episode object.
|
||||||
active_envs (Set[int]): Set of non-terminated env ids.
|
active_envs: Set of non-terminated env ids.
|
||||||
off_policy_actions (dict): Doubly keyed dict of env-ids -> agent ids ->
|
off_policy_actions: Doubly keyed dict of env-ids -> agent ids ->
|
||||||
off-policy-action, returned by a `BaseEnv.poll()` call.
|
off-policy-action, returned by a `BaseEnv.poll()` call.
|
||||||
policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy.
|
policies: Mapping from policy ID to Policy.
|
||||||
normalize_actions (bool): Whether to normalize actions to the action
|
normalize_actions: Whether to normalize actions to the action
|
||||||
space's bounds.
|
space's bounds.
|
||||||
clip_actions (bool): Whether to clip actions to the action space's
|
clip_actions: Whether to clip actions to the action space's bounds.
|
||||||
bounds.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
actions_to_send: Nested dict of env id -> agent id -> actions to be
|
Nested dict of env id -> agent id -> actions to be sent to
|
||||||
sent to Env (np.ndarrays).
|
Env (np.ndarrays).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = \
|
actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = \
|
||||||
|
@ -1098,7 +1128,7 @@ def _process_policy_eval_results(
|
||||||
actions = convert_to_numpy(actions)
|
actions = convert_to_numpy(actions)
|
||||||
|
|
||||||
rnn_out_cols: StateBatch = eval_results[policy_id][1]
|
rnn_out_cols: StateBatch = eval_results[policy_id][1]
|
||||||
pi_info_cols: dict = eval_results[policy_id][2]
|
extra_action_out_cols: dict = eval_results[policy_id][2]
|
||||||
|
|
||||||
# In case actions is a list (representing the 0th dim of a batch of
|
# In case actions is a list (representing the 0th dim of a batch of
|
||||||
# primitive actions), try converting it first.
|
# primitive actions), try converting it first.
|
||||||
|
@ -1107,7 +1137,7 @@ def _process_policy_eval_results(
|
||||||
|
|
||||||
# Store RNN state ins/outs and extra-action fetches to episode.
|
# Store RNN state ins/outs and extra-action fetches to episode.
|
||||||
for f_i, column in enumerate(rnn_out_cols):
|
for f_i, column in enumerate(rnn_out_cols):
|
||||||
pi_info_cols["state_out_{}".format(f_i)] = column
|
extra_action_out_cols["state_out_{}".format(f_i)] = column
|
||||||
|
|
||||||
policy: Policy = _get_or_raise(policies, policy_id)
|
policy: Policy = _get_or_raise(policies, policy_id)
|
||||||
# Split action-component batches into single action rows.
|
# Split action-component batches into single action rows.
|
||||||
|
@ -1127,11 +1157,11 @@ def _process_policy_eval_results(
|
||||||
|
|
||||||
env_id: int = eval_data[i].env_id
|
env_id: int = eval_data[i].env_id
|
||||||
agent_id: AgentID = eval_data[i].agent_id
|
agent_id: AgentID = eval_data[i].agent_id
|
||||||
episode: MultiAgentEpisode = active_episodes[env_id]
|
episode: Episode = active_episodes[env_id]
|
||||||
episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
|
episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
|
||||||
episode._set_last_pi_info(
|
episode._set_last_extra_action_outs(
|
||||||
agent_id, {k: v[i]
|
agent_id, {k: v[i]
|
||||||
for k, v in pi_info_cols.items()})
|
for k, v in extra_action_out_cols.items()})
|
||||||
if env_id in off_policy_actions and \
|
if env_id in off_policy_actions and \
|
||||||
agent_id in off_policy_actions[env_id]:
|
agent_id in off_policy_actions[env_id]:
|
||||||
episode._set_last_action(agent_id,
|
episode._set_last_action(agent_id,
|
||||||
|
|
|
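The `pi_info` to `extra_action_outs` rename also shows up on the public `Episode` API: `last_pi_info_for()` is deprecated in favor of `last_extra_action_outs_for()`. A hedged sketch of reading those fetches (e.g. value-function predictions) from a callback; which keys are present depends on the policy and model:

from ray.rllib.agents.callbacks import DefaultCallbacks


class LogVfPredsCallbacks(DefaultCallbacks):
    def on_episode_step(self, *, worker, base_env, policies, episode,
                        env_index, **kwargs):
        for agent_id in episode.get_agents():
            outs = episode.last_extra_action_outs_for(agent_id)
            # "vf_preds" only exists for policies that compute a value
            # function (e.g. PPO); guard accordingly.
            if "vf_preds" in outs:
                episode.user_data.setdefault("vf_preds", []).append(
                    float(outs["vf_preds"]))
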
@@ -81,7 +81,7 @@ class EchoPolicy(Policy):
         return obs_batch.argmax(axis=1), [], {}


-class MultiAgentEpisodeEnv(MultiAgentEnv):
+class EpisodeEnv(MultiAgentEnv):
     def __init__(self, episode_length, num):
         self.agents = [MockEnv3(episode_length) for _ in range(num)]
         self.dones = set()
@@ -123,9 +123,9 @@ class TestEpisodeLastValues(unittest.TestCase):
         ev.sample()

     def test_multiagent_env(self):
-        temp_env = MultiAgentEpisodeEnv(NUM_STEPS, NUM_AGENTS)
+        temp_env = EpisodeEnv(NUM_STEPS, NUM_AGENTS)
         ev = RolloutWorker(
-            env_creator=lambda _: MultiAgentEpisodeEnv(NUM_STEPS, NUM_AGENTS),
+            env_creator=lambda _: EpisodeEnv(NUM_STEPS, NUM_AGENTS),
             policy_spec={
                 str(agent_id): (EchoPolicy, temp_env.observation_space,
                                 temp_env.action_space, {})
@@ -29,9 +29,9 @@ T = TypeVar("T")

 @DeveloperAPI
 class WorkerSet:
-    """Represents a set of RolloutWorkers.
+    """Set of RolloutWorkers with n @ray.remote workers and one local worker.

-    There must be one local worker copy, and zero or more remote workers.
+    Where n may be 0.
     """

     def __init__(self,
@@ -43,20 +43,20 @@ class WorkerSet:
                 num_workers: int = 0,
                 logdir: Optional[str] = None,
                 _setup: bool = True):
-        """Create a new WorkerSet and initialize its workers.
+        """Initializes a WorkerSet instance.

         Args:
-            env_creator (Optional[Callable[[EnvContext], EnvType]]): Function
-                that returns env given env config.
-            validate_env (Optional[Callable[[EnvType], None]]): Optional
-                callable to validate the generated environment (only on
-                worker=0).
-            policy (Optional[Type[Policy]]): A rllib.policy.Policy class.
-            trainer_config (Optional[TrainerConfigDict]): Optional dict that
-                extends the common config of the Trainer class.
-            num_workers (int): Number of remote rollout workers to create.
-            logdir (Optional[str]): Optional logging directory for workers.
-            _setup (bool): Whether to setup workers. This is only for testing.
+            env_creator: Function that returns env given env config.
+            validate_env: Optional callable to validate the generated
+                environment (only on worker=0).
+            policy_class: An optional Policy class. If None, PolicySpecs can be
+                generated automatically by using the Trainer's default class
+                of via a given multi-agent policy config dict.
+            trainer_config: Optional dict that extends the common config of
+                the Trainer class.
+            num_workers: Number of remote rollout workers to create.
+            logdir: Optional logging directory for workers.
+            _setup: Whether to setup workers. This is only for testing.
         """

         if not trainer_config:
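In most user code a `WorkerSet` is not constructed directly; it comes attached to a Trainer as `trainer.workers`. A hedged sketch, assuming the 2021-era `ray.rllib.agents.ppo.PPOTrainer` API (config values are arbitrary):

import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()
trainer = PPOTrainer(env="CartPole-v0", config={"num_workers": 2})

workers = trainer.workers               # -> WorkerSet
print(workers.local_worker())           # local RolloutWorker
print(len(workers.remote_workers()))    # 2 remote ActorHandles
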
@@ -119,15 +119,20 @@ class WorkerSet:
         )

     def local_worker(self) -> RolloutWorker:
-        """Return the local rollout worker."""
+        """Returns the local rollout worker."""
         return self._local_worker

     def remote_workers(self) -> List[ActorHandle]:
-        """Return a list of remote rollout workers."""
+        """Returns a list of remote rollout workers."""
         return self._remote_workers

     def sync_weights(self, policies: Optional[List[PolicyID]] = None) -> None:
-        """Syncs weights from the local worker to all remote workers."""
+        """Syncs model weights from the local worker to all remote workers.
+
+        Args:
+            policies: An optional list of policy IDs to sync for. If None,
+                sync all policies.
+        """
         if self.remote_workers():
             weights = ray.put(self.local_worker().get_weights(policies))
             for e in self.remote_workers():
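Continuing the sketch above, the newly documented `policies` filter restricts which policies get pushed out ("default_policy" is just RLlib's single-agent default policy ID):

# Push only the default policy's weights from the local worker to all
# remote workers; passing None would sync every policy.
workers.sync_weights(policies=["default_policy"])
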
@@ -136,8 +141,11 @@ class WorkerSet:
     def add_workers(self, num_workers: int) -> None:
         """Creates and adds a number of remote workers to this worker set.

+        Can be called several times on the same WorkerSet to add more
+        RolloutWorkers to the set.
+
         Args:
-            num_workers (int): The number of remote Workers to add to this
+            num_workers: The number of remote Workers to add to this
                 WorkerSet.
         """
         remote_args = {
@@ -159,25 +167,36 @@ class WorkerSet:
         ])

     def reset(self, new_remote_workers: List[ActorHandle]) -> None:
-        """Called to change the set of remote workers."""
+        """Hard overrides the remote workers in this set with the given one.
+
+        Args:
+            new_remote_workers: A list of new RolloutWorkers
+                (as `ActorHandles`) to use as remote workers.
+        """
         self._remote_workers = new_remote_workers

     def stop(self) -> None:
-        """Stop all rollout workers."""
+        """Calls `stop` on all rollout workers (including the local one)."""
         try:
             self.local_worker().stop()
             tids = [w.stop.remote() for w in self.remote_workers()]
             ray.get(tids)
         except Exception:
-            logger.exception("Failed to stop workers")
+            logger.exception("Failed to stop workers!")
         finally:
             for w in self.remote_workers():
                 w.__ray_terminate__.remote()

     @DeveloperAPI
     def foreach_worker(self, func: Callable[[RolloutWorker], T]) -> List[T]:
-        """Apply the given function to each worker instance."""
+        """Calls the given function with each worker instance as arg.
+
+        Args:
+            func: The function to call for each worker (as only arg).
+
+        Returns:
+            The list of return values of all calls to `func([worker])`.
+        """
         local_result = [func(self.local_worker())]
         remote_results = ray.get(
             [w.apply.remote(func) for w in self.remote_workers()])
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
def foreach_worker_with_index(
|
def foreach_worker_with_index(
|
||||||
self, func: Callable[[RolloutWorker, int], T]) -> List[T]:
|
self, func: Callable[[RolloutWorker, int], T]) -> List[T]:
|
||||||
"""Apply the given function to each worker instance.
|
"""Calls `func` with each worker instance and worker idx as args.
|
||||||
|
|
||||||
The index will be passed as the second arg to the given function.
|
The index will be passed as the second arg to the given function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The function to call for each worker and its index
|
||||||
|
(as args). The local worker has index 0, all remote workers
|
||||||
|
have indices > 0.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The list of return values of all calls to `func([worker, idx])`.
|
||||||
|
The first entry in this list are the results of the local
|
||||||
|
worker, followed by all remote workers' results.
|
||||||
"""
|
"""
|
||||||
|
# Local worker: Index=0.
|
||||||
local_result = [func(self.local_worker(), 0)]
|
local_result = [func(self.local_worker(), 0)]
|
||||||
|
# Remote workers: Index > 0.
|
||||||
remote_results = ray.get([
|
remote_results = ray.get([
|
||||||
w.apply.remote(func, i + 1)
|
w.apply.remote(func, i + 1)
|
||||||
for i, w in enumerate(self.remote_workers())
|
for i, w in enumerate(self.remote_workers())
|
||||||
|
@ -199,15 +230,22 @@ class WorkerSet:
|
||||||
|
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
def foreach_policy(self, func: Callable[[Policy, PolicyID], T]) -> List[T]:
|
def foreach_policy(self, func: Callable[[Policy, PolicyID], T]) -> List[T]:
|
||||||
"""Apply the given function to each worker's (policy, policy_id) tuple.
|
"""Calls `func` with each worker's (policy, PolicyID) tuple.
|
||||||
|
|
||||||
|
Note that in the multi-agent case, each worker may have more than one
|
||||||
|
policy.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
func (callable): A function - taking a Policy and its ID - that is
|
func: A function - taking a Policy and its ID - that is
|
||||||
called on all workers' Policies.
|
called on all workers' Policies.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[any]: The list of return values of func over all workers'
|
The list of return values of func over all workers' policies. The
|
||||||
policies.
|
length of this list is:
|
||||||
|
(num_workers + 1 (local-worker)) *
|
||||||
|
[num policies in the multi-agent config dict].
|
||||||
|
The local workers' results are first, followed by all remote
|
||||||
|
workers' results
|
||||||
"""
|
"""
|
||||||
results = self.local_worker().foreach_policy(func)
|
results = self.local_worker().foreach_policy(func)
|
||||||
ray_gets = []
|
ray_gets = []
|
||||||
|
@ -221,8 +259,8 @@ class WorkerSet:
|
||||||
|
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
def trainable_policies(self) -> List[PolicyID]:
|
def trainable_policies(self) -> List[PolicyID]:
|
||||||
"""Return the list of trainable policy ids."""
|
"""Returns the list of trainable policy ids."""
|
||||||
return self.local_worker().foreach_trainable_policy(lambda _, pid: pid)
|
return self.local_worker().policies_to_train
|
||||||
|
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
def foreach_trainable_policy(
|
def foreach_trainable_policy(
|
||||||
|
@ -230,7 +268,7 @@ class WorkerSet:
|
||||||
"""Apply `func` to all workers' Policies iff in `policies_to_train`.
|
"""Apply `func` to all workers' Policies iff in `policies_to_train`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
func (callable): A function - taking a Policy and its ID - that is
|
func: A function - taking a Policy and its ID - that is
|
||||||
called on all workers' Policies in `worker.policies_to_train`.
|
called on all workers' Policies in `worker.policies_to_train`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -250,7 +288,7 @@ class WorkerSet:
|
||||||
|
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
def foreach_env(self, func: Callable[[EnvType], List[T]]) -> List[List[T]]:
|
def foreach_env(self, func: Callable[[EnvType], List[T]]) -> List[List[T]]:
|
||||||
"""Apply `func` to all workers' underlying sub environments.
|
"""Calls `func` with all workers' sub environments as args.
|
||||||
|
|
||||||
An "underlying sub environment" is a single clone of an env within
|
An "underlying sub environment" is a single clone of an env within
|
||||||
a vectorized environment.
|
a vectorized environment.
|
||||||
|
@ -258,13 +296,12 @@ class WorkerSet:
|
||||||
gym.Env object.
|
gym.Env object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
func (Callable[[EnvType], T]): A function - taking an EnvType
|
func: A function - taking an EnvType (normally a gym.Env object)
|
||||||
(normally a gym.Env object) as arg and returning a list of
|
as arg and returning a list of lists of return values, one
|
||||||
return values over sub environments for each worker.
|
value per underlying sub-environment per each worker.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[List[T]]: The list (workers) of lists (sub environments) of
|
The list (workers) of lists (sub environments) of results.
|
||||||
results.
|
|
||||||
"""
|
"""
|
||||||
local_results = [self.local_worker().foreach_env(func)]
|
local_results = [self.local_worker().foreach_env(func)]
|
||||||
ray_gets = []
|
ray_gets = []
|
||||||
|
@ -276,7 +313,7 @@ class WorkerSet:
|
||||||
def foreach_env_with_context(
|
def foreach_env_with_context(
|
||||||
self,
|
self,
|
||||||
func: Callable[[BaseEnv, EnvContext], List[T]]) -> List[List[T]]:
|
func: Callable[[BaseEnv, EnvContext], List[T]]) -> List[List[T]]:
|
||||||
"""Apply `func` to all workers' underlying sub environments.
|
"""Call `func` with all workers' sub-environments and env_ctx as args.
|
||||||
|
|
||||||
An "underlying sub environment" is a single clone of an env within
|
An "underlying sub environment" is a single clone of an env within
|
||||||
a vectorized environment.
|
a vectorized environment.
|
||||||
|
@ -284,13 +321,13 @@ class WorkerSet:
|
||||||
as args.
|
as args.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
func (Callable[[BaseEnv], T]): A function - taking a BaseEnv
|
func: A function - taking a BaseEnv object and an EnvContext as
|
||||||
object as arg and returning a list of return values over envs
|
arg - and returning a list of lists of return values over envs
|
||||||
of the worker.
|
of the worker.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[List[T]]: The list (workers) of lists (environments) of
|
The list (1 item per workers) of lists (1 item per sub-environment)
|
||||||
results.
|
of results.
|
||||||
"""
|
"""
|
||||||
local_results = [self.local_worker().foreach_env_with_context(func)]
|
local_results = [self.local_worker().foreach_env_with_context(func)]
|
||||||
ray_gets = []
|
ray_gets = []
|
||||||
|
|
|
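The nested return shape of `foreach_env()` is easy to get wrong, so a small sketch (continuing the earlier WorkerSet example; the attribute accessed is just whatever the gym sub-environment exposes):

# One inner list per worker, one entry per sub-environment of that worker.
env_ids = workers.foreach_env(
    lambda env: env.spec.id if getattr(env, "spec", None) else None)
# e.g. [["CartPole-v0"], ["CartPole-v0"], ["CartPole-v0"]]
#       ^ local worker    ^ remote 1       ^ remote 2
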
@ -13,7 +13,7 @@ import ray
|
||||||
from ray import tune
|
from ray import tune
|
||||||
from ray.rllib.agents.callbacks import DefaultCallbacks
|
from ray.rllib.agents.callbacks import DefaultCallbacks
|
||||||
from ray.rllib.env import BaseEnv
|
from ray.rllib.env import BaseEnv
|
||||||
from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
|
from ray.rllib.evaluation import Episode, RolloutWorker
|
||||||
from ray.rllib.policy import Policy
|
from ray.rllib.policy import Policy
|
||||||
from ray.rllib.policy.sample_batch import SampleBatch
|
from ray.rllib.policy.sample_batch import SampleBatch
|
||||||
|
|
||||||
|
@ -28,8 +28,8 @@ parser.add_argument("--stop-iters", type=int, default=2000)
|
||||||
|
|
||||||
class MyCallbacks(DefaultCallbacks):
|
class MyCallbacks(DefaultCallbacks):
|
||||||
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv,
|
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv,
|
||||||
policies: Dict[str, Policy],
|
policies: Dict[str, Policy], episode: Episode,
|
||||||
episode: MultiAgentEpisode, env_index: int, **kwargs):
|
env_index: int, **kwargs):
|
||||||
# Make sure this episode has just been started (only initial obs
|
# Make sure this episode has just been started (only initial obs
|
||||||
# logged so far).
|
# logged so far).
|
||||||
assert episode.length == 0, \
|
assert episode.length == 0, \
|
||||||
|
@ -41,8 +41,8 @@ class MyCallbacks(DefaultCallbacks):
|
||||||
episode.hist_data["pole_angles"] = []
|
episode.hist_data["pole_angles"] = []
|
||||||
|
|
||||||
def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv,
|
def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv,
|
||||||
policies: Dict[str, Policy],
|
policies: Dict[str, Policy], episode: Episode,
|
||||||
episode: MultiAgentEpisode, env_index: int, **kwargs):
|
env_index: int, **kwargs):
|
||||||
# Make sure this episode is ongoing.
|
# Make sure this episode is ongoing.
|
||||||
assert episode.length > 0, \
|
assert episode.length > 0, \
|
||||||
"ERROR: `on_episode_step()` callback should not be called right " \
|
"ERROR: `on_episode_step()` callback should not be called right " \
|
||||||
|
@@ -53,7 +53,7 @@ class MyCallbacks(DefaultCallbacks):
         episode.user_data["pole_angles"].append(pole_angle)
 
     def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv,
-                       policies: Dict[str, Policy], episode: MultiAgentEpisode,
+                       policies: Dict[str, Policy], episode: Episode,
                        env_index: int, **kwargs):
         # Make sure this episode is really done.
         assert episode.batch_builder.policy_collectors[
@@ -84,8 +84,8 @@ class MyCallbacks(DefaultCallbacks):
             policy, result["sum_actions_in_train_batch"]))
 
     def on_postprocess_trajectory(
-            self, *, worker: RolloutWorker, episode: MultiAgentEpisode,
-            agent_id: str, policy_id: str, policies: Dict[str, Policy],
+            self, *, worker: RolloutWorker, episode: Episode, agent_id: str,
+            policy_id: str, policies: Dict[str, Policy],
             postprocessed_batch: SampleBatch,
             original_batches: Dict[str, SampleBatch], **kwargs):
         print("postprocessed {} steps".format(postprocessed_batch.count))
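
For reference, the callbacks class edited above is wired into training via the `callbacks` config key. A hedged sketch in the style of the example script (env, framework, and stop criteria are placeholders):

import ray
from ray import tune

if __name__ == "__main__":
    ray.init()
    tune.run(
        "PG",
        config={
            "env": "CartPole-v0",
            "callbacks": MyCallbacks,  # pass the class, not an instance
            "framework": "tf",
        },
        stop={"training_iteration": 2},
    )
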
@@ -23,7 +23,7 @@ tf1, tf, tfv = try_import_tf()
 torch, _ = try_import_torch()
 
 if TYPE_CHECKING:
-    from ray.rllib.evaluation import MultiAgentEpisode
+    from ray.rllib.evaluation import Episode
 
 logger = logging.getLogger(__name__)
 
@@ -142,7 +142,7 @@ class Policy(metaclass=ABCMeta):
             prev_reward: Optional[TensorStructType] = None,
             info: dict = None,
             input_dict: Optional[SampleBatch] = None,
-            episode: Optional["MultiAgentEpisode"] = None,
+            episode: Optional["Episode"] = None,
             explore: Optional[bool] = None,
             timestep: Optional[int] = None,
             # Kwars placeholder for future compatibility.
@@ -238,7 +238,7 @@ class Policy(metaclass=ABCMeta):
             input_dict: Union[SampleBatch, Dict[str, TensorStructType]],
             explore: bool = None,
             timestep: Optional[int] = None,
-            episodes: Optional[List["MultiAgentEpisode"]] = None,
+            episodes: Optional[List["Episode"]] = None,
             **kwargs) -> \
             Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
         """Computes actions from collected samples (across multiple-agents).
@@ -300,7 +300,7 @@ class Policy(metaclass=ABCMeta):
             prev_reward_batch: Union[List[TensorStructType],
                                      TensorStructType] = None,
             info_batch: Optional[Dict[str, list]] = None,
-            episodes: Optional[List["MultiAgentEpisode"]] = None,
+            episodes: Optional[List["Episode"]] = None,
             explore: Optional[bool] = None,
             timestep: Optional[int] = None,
             **kwargs) -> \
@@ -313,7 +313,7 @@ class Policy(metaclass=ABCMeta):
             prev_action_batch: Batch of previous action values.
             prev_reward_batch: Batch of previous rewards.
             info_batch: Batch of info objects.
-            episodes: List of MultiAgentEpisodes, one for each obs in
+            episodes: List of Episode objects, one for each obs in
                 obs_batch. This provides access to all of the internal
                 episode state, which may be useful for model-based or
                 multi-agent algorithms.
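
The `episodes` argument documented above gives `compute_actions()` access to per-episode state. A hypothetical, heavily simplified policy using it (the class name, the extra-fetch key, and the random action choice are all illustrative, not part of this PR; assumes the post-rename Episode import):

from typing import List, Optional

import numpy as np

from ray.rllib.evaluation.episode import Episode
from ray.rllib.policy.policy import Policy


class EpisodeAwareRandomPolicy(Policy):
    """Acts randomly, but logs each episode's current length as an extra fetch."""

    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes: Optional[List[Episode]] = None,
                        explore=None,
                        timestep=None,
                        **kwargs):
        # `episodes` may be None, e.g. when called outside of sampling.
        if episodes:
            lengths = [e.length for e in episodes]
        else:
            lengths = [0] * len(obs_batch)
        actions = np.array([self.action_space.sample() for _ in obs_batch])
        return actions, [], {"episode_length": np.array(lengths)}

    def learn_on_batch(self, samples):
        return {}  # no learning in this sketch

    def get_weights(self):
        return {}

    def set_weights(self, weights):
        pass
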
@@ -377,7 +377,7 @@ class Policy(metaclass=ABCMeta):
             sample_batch: SampleBatch,
             other_agent_batches: Optional[Dict[AgentID, Tuple[
                 "Policy", SampleBatch]]] = None,
-            episode: Optional["MultiAgentEpisode"] = None) -> SampleBatch:
+            episode: Optional["Episode"] = None) -> SampleBatch:
         """Implements algorithm-specific trajectory postprocessing.
 
         This will be called on each trajectory fragment computed during policy
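
The `postprocess_trajectory()` hook shown above (and the matching `postprocess_fn` arguments of the policy-template builders further below) now take the renamed Episode type. A minimal, hypothetical sketch of such a hook; the function name and the added batch key are illustrative, not part of this PR:

from typing import Dict, Optional, Tuple

import numpy as np

from ray.rllib.evaluation.episode import Episode
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import AgentID


def tag_episode_len(
        policy: Policy,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[AgentID, Tuple[
            Policy, SampleBatch]]] = None,
        episode: Optional[Episode] = None) -> SampleBatch:
    # Attach the episode's length (so far) to every timestep of the fragment.
    # `episode` can be None, e.g. during loss initialization or offline use.
    ep_len = episode.length if episode is not None else 0
    sample_batch["episode_len"] = np.full(sample_batch.count, ep_len)
    return sample_batch
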
@@ -19,7 +19,7 @@ from ray.rllib.utils.typing import ModelGradients, TensorType, \
     TrainerConfigDict
 
 if TYPE_CHECKING:
-    from ray.rllib.evaluation import MultiAgentEpisode  # noqa
+    from ray.rllib.evaluation.episode import Episode  # noqa
 
 jax, _ = try_import_jax()
 torch, _ = try_import_torch()
@@ -39,7 +39,7 @@ def build_policy_class(
             str, TensorType]]] = None,
         postprocess_fn: Optional[Callable[[
             Policy, SampleBatch, Optional[Dict[Any, SampleBatch]], Optional[
-                "MultiAgentEpisode"]
+                "Episode"]
         ], SampleBatch]] = None,
         extra_action_out_fn: Optional[Callable[[
             Policy, Dict[str, TensorType], List[TensorType], ModelV2,
@@ -98,7 +98,7 @@ def build_policy_class(
             overrides. If None, uses only(!) the user-provided
             PartialTrainerConfigDict as dict for this Policy.
         postprocess_fn (Optional[Callable[[Policy, SampleBatch,
-            Optional[Dict[Any, SampleBatch]], Optional["MultiAgentEpisode"]],
+            Optional[Dict[Any, SampleBatch]], Optional["Episode"]],
            SampleBatch]]): Optional callable for post-processing experience
            batches (called after the super's `postprocess_trajectory` method).
         stats_fn (Optional[Callable[[Policy, SampleBatch],
 
@@ -28,7 +28,7 @@ from ray.rllib.utils.typing import LocalOptimizer, ModelGradients, \
     TensorType, TrainerConfigDict
 
 if TYPE_CHECKING:
-    from ray.rllib.evaluation import MultiAgentEpisode
+    from ray.rllib.evaluation import Episode
 
 tf1, tf, tfv = try_import_tf()
 logger = logging.getLogger(__name__)
@@ -277,7 +277,7 @@ class TFPolicy(Policy):
             input_dict: Union[SampleBatch, Dict[str, TensorType]],
             explore: bool = None,
             timestep: Optional[int] = None,
-            episodes: Optional[List["MultiAgentEpisode"]] = None,
+            episodes: Optional[List["Episode"]] = None,
             **kwargs) -> \
             Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
 
@@ -308,7 +308,7 @@ class TFPolicy(Policy):
             prev_action_batch: Union[List[TensorType], TensorType] = None,
             prev_reward_batch: Union[List[TensorType], TensorType] = None,
             info_batch: Optional[Dict[str, list]] = None,
-            episodes: Optional[List["MultiAgentEpisode"]] = None,
+            episodes: Optional[List["Episode"]] = None,
             explore: Optional[bool] = None,
             timestep: Optional[int] = None,
             **kwargs):
 
@@ -18,7 +18,7 @@ from ray.rllib.utils.typing import AgentID, ModelGradients, TensorType, \
     TrainerConfigDict
 
 if TYPE_CHECKING:
-    from ray.rllib.evaluation import MultiAgentEpisode
+    from ray.rllib.evaluation import Episode
 
 tf1, tf, tfv = try_import_tf()
 
@@ -34,7 +34,7 @@ def build_tf_policy(
             TrainerConfigDict]] = None,
         postprocess_fn: Optional[Callable[[
             Policy, SampleBatch, Optional[Dict[AgentID, SampleBatch]],
-            Optional["MultiAgentEpisode"]
+            Optional["Episode"]
         ], SampleBatch]] = None,
         stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[
             str, TensorType]]] = None,
@@ -107,7 +107,7 @@ def build_tf_policy(
             overrides. If None, uses only(!) the user-provided
             PartialTrainerConfigDict as dict for this Policy.
         postprocess_fn (Optional[Callable[[Policy, SampleBatch,
-            Optional[Dict[AgentID, SampleBatch]], MultiAgentEpisode], None]]):
+            Optional[Dict[AgentID, SampleBatch]], Episode], None]]):
             Optional callable for post-processing experience batches (called
             after the parent class' `postprocess_trajectory` method).
         stats_fn (Optional[Callable[[Policy, SampleBatch],
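
A hedged wiring sketch for the `postprocess_fn` argument described above: policy classes created with `build_tf_policy()` expose a `with_updates()` helper that accepts the same keyword arguments, so a hook such as `tag_episode_len` from the earlier sketch could be swapped in without rebuilding the whole policy. The derived class name is illustrative and this assumes the RLlib version at this commit:

from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy

# Re-use PG's TF policy, but run the custom postprocessing hook on every
# collected trajectory fragment.
MyPGTFPolicy = PGTFPolicy.with_updates(
    name="MyPGTFPolicy",
    postprocess_fn=tag_episode_len,
)
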
@@ -31,7 +31,7 @@ from ray.rllib.utils.typing import ModelGradients, ModelWeights, TensorType, \
     TensorStructType, TrainerConfigDict
 
 if TYPE_CHECKING:
-    from ray.rllib.evaluation import MultiAgentEpisode  # noqa
+    from ray.rllib.evaluation import Episode  # noqa
 
 torch, nn = try_import_torch()
 
@@ -279,7 +279,7 @@ class TorchPolicy(Policy):
             prev_reward_batch: Union[List[TensorStructType],
                                      TensorStructType] = None,
             info_batch: Optional[Dict[str, list]] = None,
-            episodes: Optional[List["MultiAgentEpisode"]] = None,
+            episodes: Optional[List["Episode"]] = None,
             explore: Optional[bool] = None,
             timestep: Optional[int] = None,
             **kwargs) -> \
 
@@ -7,7 +7,7 @@ import ray
 from ray.tune.registry import register_env
 from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy
 from ray.rllib.agents.pg import PGTrainer
-from ray.rllib.evaluation.episode import MultiAgentEpisode
+from ray.rllib.evaluation.episode import Episode
 from ray.rllib.evaluation.rollout_worker import get_global_worker
 from ray.rllib.examples.policy.random_policy import RandomPolicy
 from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \
@@ -336,8 +336,8 @@ class TestMultiAgentEnv(unittest.TestCase):
                 # Pretend we did a model-based rollout and want to return
                 # the extra trajectory.
                 env_id = episodes[0].env_id
-                fake_eps = MultiAgentEpisode(
-                    episodes[0].policy_map, episodes[0].policy_mapping_fn,
+                fake_eps = Episode(episodes[0].policy_map,
+                                   episodes[0].policy_mapping_fn,
                     lambda: None, lambda x: None, env_id)
                 builder = get_global_worker().sampler.sample_collector
                 agent_id = "extra_0"
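
A short recap of the rename for downstream user code (hedged; this recap is not part of the diff, and whether a backward-compatibility alias for the old name is kept is not shown in this excerpt):

# Old import (pre #19783):
# from ray.rllib.evaluation import MultiAgentEpisode
# New import, as used throughout the hunks above (the explicit module path
# `from ray.rllib.evaluation.episode import Episode` works as well):
from ray.rllib.evaluation import Episode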