[RLlib] Batch-size for truncate_episode batch_mode should be configurable in agent-steps (rather than env-steps), if needed. (#12420)

Sven Mika authored on 2020-12-09 01:41:45 +01:00 (committed by GitHub)
commit e40b14d255 (parent fd4e025da6)
28 changed files with 384 additions and 132 deletions

View file

@ -1108,6 +1108,13 @@ py_test(
# srcs = ["evaluation/tests/test_trajectory_view_api.py"] # srcs = ["evaluation/tests/test_trajectory_view_api.py"]
#) #)
py_test(
name = "evaluation/tests/test_rollout_worker",
tags = ["evaluation"],
size = "medium",
srcs = ["evaluation/tests/test_rollout_worker.py"]
)
# -------------------------------------------------------------------- # --------------------------------------------------------------------
# Optimizers and Memories # Optimizers and Memories
@ -1411,13 +1418,6 @@ py_test(
args = ["TestRolloutLearntPolicy"] args = ["TestRolloutLearntPolicy"]
) )
py_test(
name = "tests/test_rollout_worker",
tags = ["tests_dir", "tests_dir_R"],
size = "medium",
srcs = ["tests/test_rollout_worker.py"]
)
py_test( py_test(
name = "tests/test_supported_multi_agent_pg", name = "tests/test_supported_multi_agent_pg",
main = "tests/test_supported_multi_agent.py", main = "tests/test_supported_multi_agent.py",

View file

@ -38,17 +38,20 @@ def execution_plan(workers, config):
# allowing for extremely large experience batches to be used. # allowing for extremely large experience batches to be used.
train_op = ( train_op = (
rollouts.combine( rollouts.combine(
ConcatBatches(min_batch_size=config["microbatch_size"])) ConcatBatches(
min_batch_size=config["microbatch_size"],
count_steps_by=config["multiagent"]["count_steps_by"]))
.for_each(ComputeGradients(workers)) # (grads, info) .for_each(ComputeGradients(workers)) # (grads, info)
.batch(num_microbatches) # List[(grads, info)] .batch(num_microbatches) # List[(grads, info)]
.for_each(AverageGradients()) # (avg_grads, info) .for_each(AverageGradients()) # (avg_grads, info)
.for_each(ApplyGradients(workers))) .for_each(ApplyGradients(workers)))
else: else:
# In normal mode, we execute one SGD step per each train batch. # In normal mode, we execute one SGD step per each train batch.
train_op = rollouts \ train_op = rollouts.combine(
.combine(ConcatBatches( ConcatBatches(
min_batch_size=config["train_batch_size"])) \ min_batch_size=config["train_batch_size"],
.for_each(TrainOneStep(workers)) count_steps_by=config["multiagent"][
"count_steps_by"])).for_each(TrainOneStep(workers))
return StandardMetricsReporting(train_op, workers, config) return StandardMetricsReporting(train_op, workers, config)

View file

@ -221,7 +221,10 @@ def gather_experiences_directly(workers, config):
replay_proportion=config["replay_proportion"])) \ replay_proportion=config["replay_proportion"])) \
.flatten() \ .flatten() \
.combine( .combine(
ConcatBatches(min_batch_size=config["train_batch_size"])) ConcatBatches(
min_batch_size=config["train_batch_size"],
count_steps_by=config["multiagent"]["count_steps_by"],
))
return train_batches return train_batches

View file

@ -56,7 +56,10 @@ def execution_plan(workers, config):
replay_op = Replay(local_buffer=replay_buffer) \ replay_op = Replay(local_buffer=replay_buffer) \
.combine( .combine(
ConcatBatches(min_batch_size=config["train_batch_size"])) \ ConcatBatches(
min_batch_size=config["train_batch_size"],
count_steps_by=config["multiagent"]["count_steps_by"],
)) \
.for_each(TrainOneStep(workers)) .for_each(TrainOneStep(workers))
train_op = Concurrently( train_op = Concurrently(

View file

@ -244,7 +244,10 @@ def execution_plan(workers: WorkerSet,
SelectExperiences(workers.trainable_policies())) SelectExperiences(workers.trainable_policies()))
# Concatenate the SampleBatches into one. # Concatenate the SampleBatches into one.
rollouts = rollouts.combine( rollouts = rollouts.combine(
ConcatBatches(min_batch_size=config["train_batch_size"])) ConcatBatches(
min_batch_size=config["train_batch_size"],
count_steps_by=config["multiagent"]["count_steps_by"],
))
# Standardize advantages. # Standardize advantages.
rollouts = rollouts.for_each(StandardizeFields(["advantages"])) rollouts = rollouts.for_each(StandardizeFields(["advantages"]))

View file

@ -73,7 +73,7 @@ class TestPPO(unittest.TestCase):
def test_ppo_compilation_and_lr_schedule(self): def test_ppo_compilation_and_lr_schedule(self):
"""Test whether a PPOTrainer can be built with all frameworks.""" """Test whether a PPOTrainer can be built with all frameworks."""
config = copy.deepcopy(ppo.DEFAULT_CONFIG) config = copy.deepcopy(ppo.DEFAULT_CONFIG)
# for checking lr-schedule correctness # For checking lr-schedule correctness.
config["callbacks"] = MyCallbacks config["callbacks"] = MyCallbacks
config["num_workers"] = 1 config["num_workers"] = 1

View file

@ -109,7 +109,10 @@ def execution_plan(workers, config):
train_op = Replay(local_buffer=replay_buffer) \ train_op = Replay(local_buffer=replay_buffer) \
.combine( .combine(
ConcatBatches(min_batch_size=config["train_batch_size"])) \ ConcatBatches(
min_batch_size=config["train_batch_size"],
count_steps_by=config["multiagent"]["count_steps_by"]
)) \
.for_each(TrainOneStep(workers)) \ .for_each(TrainOneStep(workers)) \
.for_each(UpdateTargetNetwork( .for_each(UpdateTargetNetwork(
workers, config["target_network_update_freq"])) workers, config["target_network_update_freq"]))

View file

@ -75,10 +75,18 @@ COMMON_CONFIG: TrainerConfigDict = {
# The dataflow here can vary per algorithm. For example, PPO further # The dataflow here can vary per algorithm. For example, PPO further
# divides the train batch into minibatches for multi-epoch SGD. # divides the train batch into minibatches for multi-epoch SGD.
"rollout_fragment_length": 200, "rollout_fragment_length": 200,
- # Whether to rollout "complete_episodes" or "truncate_episodes" to
- # `rollout_fragment_length` length unrolls. Episode truncation guarantees
- # evenly sized batches, but increases variance as the reward-to-go will
- # need to be estimated at truncation boundaries.
+ # How to build per-Sampler (RolloutWorker) batches, which are then
+ # usually concat'd to form the train batch. Note that "steps" below can
+ # mean different things (either env- or agent-steps) and depends on the
+ # `count_steps_by` (multiagent) setting below.
+ # truncate_episodes: Each produced batch (when calling
+ # RolloutWorker.sample()) will contain exactly `rollout_fragment_length`
+ # steps. This mode guarantees evenly sized batches, but increases
+ # variance as the future return must now be estimated at truncation
+ # boundaries.
+ # complete_episodes: Each unroll happens exactly over one episode, from
+ # beginning to end. Data collection will not stop unless the episode
+ # terminates or a configured horizon (hard or soft) is hit.
"batch_mode": "truncate_episodes", "batch_mode": "truncate_episodes",
# === Settings for the Trainer process === # === Settings for the Trainer process ===
@ -357,6 +365,13 @@ COMMON_CONFIG: TrainerConfigDict = {
# agents it controls at that timestep. When replay_mode=independent, # agents it controls at that timestep. When replay_mode=independent,
# transitions are replayed independently per policy. # transitions are replayed independently per policy.
"replay_mode": "independent", "replay_mode": "independent",
# Which metric to use as the "batch size" when building a
# MultiAgentBatch. The two supported values are:
# env_steps: Count each time the env is "stepped" (no matter how many
# multi-agent actions are passed/how many multi-agent observations
# have been returned in the previous step).
# agent_steps: Count each individual agent step as one step.
"count_steps_by": "env_steps",
}, },
# === Logger === # === Logger ===
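Taken together, the new `batch_mode` documentation and the `count_steps_by` switch are used as follows. This is a usage sketch, not part of the diff; it closely mirrors the `test_counting_by_agent_steps` test added further down in this commit. With `count_steps_by="agent_steps"`, both `rollout_fragment_length` and `train_batch_size` are interpreted in agent-steps rather than env-steps.

import copy
import numpy as np
from gym.spaces import Box, Discrete

import ray
from ray import tune
import ray.rllib.agents.ppo as ppo
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole

ray.init()
tune.register_env("ma_cartpole",
                  lambda _: MultiAgentCartPole({"num_agents": 2}))

obs_space = Box(float("-inf"), float("inf"), (4, ), dtype=np.float32)
config = copy.deepcopy(ppo.DEFAULT_CONFIG)
config["framework"] = "torch"
config["num_workers"] = 2
config["num_sgd_iter"] = 2
config["batch_mode"] = "truncate_episodes"
# Both sizes below are counted in agent-steps due to `count_steps_by`.
config["rollout_fragment_length"] = 200
config["train_batch_size"] = 400
config["multiagent"] = {
    "policies": {
        "p0": (None, obs_space, Discrete(2), {}),
        "p1": (None, obs_space, Discrete(2), {}),
    },
    "policy_mapping_fn": lambda agent_id: "p{}".format(agent_id),
    "count_steps_by": "agent_steps",
}

trainer = ppo.PPOTrainer(config=config, env="ma_cartpole")
print(trainer.train()["timesteps_total"])
trainer.stop()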
@ -1081,6 +1096,20 @@ class Trainer(Trainable):
config["model"]["lstm_use_prev_action"] = prev_a_r config["model"]["lstm_use_prev_action"] = prev_a_r
config["model"]["lstm_use_prev_reward"] = prev_a_r config["model"]["lstm_use_prev_reward"] = prev_a_r
# Check batching/sample collection settings.
if config["batch_mode"] not in [
"truncate_episodes", "complete_episodes"
]:
raise ValueError("`batch_mode` must be one of [truncate_episodes|"
"complete_episodes]! Got {}".format(
config["batch_mode"]))
if config["multiagent"].get("count_steps_by", "env_steps") not in \
["env_steps", "agent_steps"]:
raise ValueError(
"`count_steps_by` must be one of [env_steps|agent_steps]! "
"Got {}".format(config["multiagent"]["count_steps_by"]))
def _try_recover(self): def _try_recover(self):
"""Try to identify and remove any unhealthy workers. """Try to identify and remove any unhealthy workers.

View file

@ -22,10 +22,11 @@ def default_execution_plan(workers: WorkerSet, config: TrainerConfigDict):
# Combine experiences batches until we hit `train_batch_size` in size. # Combine experiences batches until we hit `train_batch_size` in size.
# Then, train the policy on those experiences and update the workers. # Then, train the policy on those experiences and update the workers.
train_op = rollouts \ train_op = rollouts.combine(
.combine(ConcatBatches( ConcatBatches(
min_batch_size=config["train_batch_size"])) \ min_batch_size=config["train_batch_size"],
.for_each(TrainOneStep(workers)) count_steps_by=config["multiagent"]["count_steps_by"],
)).for_each(TrainOneStep(workers))
# Add on the standard episode reward, etc. metrics reporting. This returns # Add on the standard episode reward, etc. metrics reporting. This returns
# a LocalIterator[metrics_dict] representing metrics for each train step. # a LocalIterator[metrics_dict] representing metrics for each train step.

View file

@ -164,11 +164,12 @@ def execution_plan(workers, config):
rollouts = ParallelRollouts(workers, mode="bulk_sync") rollouts = ParallelRollouts(workers, mode="bulk_sync")
if config["simple_optimizer"]: if config["simple_optimizer"]:
train_op = rollouts \ train_op = rollouts.combine(
.combine(ConcatBatches( ConcatBatches(
min_batch_size=config["train_batch_size"])) \ min_batch_size=config["train_batch_size"],
.for_each(TrainOneStep( count_steps_by=config["multiagent"]["count_steps_by"],
workers, num_sgd_iter=config["num_sgd_iter"])) )).for_each(
TrainOneStep(workers, num_sgd_iter=config["num_sgd_iter"]))
else: else:
replay_buffer = SimpleReplayBuffer(config["buffer_size"]) replay_buffer = SimpleReplayBuffer(config["buffer_size"])
@ -178,7 +179,10 @@ def execution_plan(workers, config):
replay_op = Replay(local_buffer=replay_buffer) \ replay_op = Replay(local_buffer=replay_buffer) \
.filter(WaitUntilTimestepsElapsed(config["learning_starts"])) \ .filter(WaitUntilTimestepsElapsed(config["learning_starts"])) \
.combine( .combine(
ConcatBatches(min_batch_size=config["train_batch_size"])) \ ConcatBatches(
min_batch_size=config["train_batch_size"],
count_steps_by=config["multiagent"]["count_steps_by"],
)) \
.for_each(TrainOneStep( .for_each(TrainOneStep(
workers, num_sgd_iter=config["num_sgd_iter"])) workers, num_sgd_iter=config["num_sgd_iter"]))

View file

@ -110,11 +110,37 @@ class _SampleCollector(metaclass=ABCMeta):
@abstractmethod @abstractmethod
def total_env_steps(self) -> int: def total_env_steps(self) -> int:
"""Returns total number of steps taken in the env (sum of all agents). """Returns total number of env-steps taken so far.
Thereby, a step in an N-agent multi-agent environment counts as only 1
for this metric. The returned count contains everything that has not
been built yet (and returned as MultiAgentBatches by the
`try_build_truncated_episode_multi_agent_batch` or
`postprocess_episode(build=True)` methods). After such build, this
counter is reset to 0.
Returns: Returns:
int: The number of steps taken in total in the environment over all int: The number of env-steps taken in total in the environment(s)
agents. so far.
"""
raise NotImplementedError
@abstractmethod
def total_agent_steps(self) -> int:
"""Returns total number of (individual) agent-steps taken so far.
Thereby, a step in an N-agent multi-agent environment counts as N.
If less than N agents have stepped (because some agents were not
required to send actions), the count will be increased by less than N.
The returned count contains everything that has not been built yet
(and returned as MultiAgentBatches by the
`try_build_truncated_episode_multi_agent_batch` or
`postprocess_episode(build=True)` methods). After such build, this
counter is reset to 0.
Returns:
int: The number of (individual) agent-steps taken in total in the
environment(s) so far.
""" """
raise NotImplementedError raise NotImplementedError
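To make the two counters concrete: a single call to `env.step()` in an N-agent environment adds 1 to the env-step count and up to N to the agent-step count (fewer if some agents did not act). A plain-Python illustration, not RLlib code:

# 5 env-steps in a 2-agent env; on the 4th step only one agent acted.
agents_acting_per_env_step = [2, 2, 2, 1, 2]

total_env_steps = len(agents_acting_per_env_step)    # -> 5
total_agent_steps = sum(agents_acting_per_env_step)  # -> 9

print(total_env_steps, total_agent_steps)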

View file

@ -51,7 +51,7 @@ class _AgentCollector:
self.episode_id = None self.episode_id = None
# The simple timestep count for this agent. Gets increased by one # The simple timestep count for this agent. Gets increased by one
# each time a (non-initial!) observation is added. # each time a (non-initial!) observation is added.
self.count = 0 self.agent_steps = 0
def add_init_obs(self, episode_id: EpisodeID, agent_index: int, def add_init_obs(self, episode_id: EpisodeID, agent_index: int,
env_id: EnvID, t: int, init_obs: TensorType) -> None: env_id: EnvID, t: int, init_obs: TensorType) -> None:
@ -105,7 +105,7 @@ class _AgentCollector:
if k not in self.buffers: if k not in self.buffers:
self._build_buffers(single_row=values) self._build_buffers(single_row=values)
self.buffers[k].append(v) self.buffers[k].append(v)
self.count += 1 self.agent_steps += 1
def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch: def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch:
"""Builds a SampleBatch from the thus-far collected agent data. """Builds a SampleBatch from the thus-far collected agent data.
@ -183,7 +183,7 @@ class _AgentCollector:
if self.shift_before > 0: if self.shift_before > 0:
for k, data in self.buffers.items(): for k, data in self.buffers.items():
self.buffers[k] = data[-self.shift_before:] self.buffers[k] = data[-self.shift_before:]
self.count = 0 self.agent_steps = 0
return batch return batch
@ -238,7 +238,7 @@ class _PolicyCollector:
# NOTE: This is not an env-step count (across n agents). AgentA and # NOTE: This is not an env-step count (across n agents). AgentA and
# agentB, both using this policy, acting in the same episode and both # agentB, both using this policy, acting in the same episode and both
# doing n steps would increase the count by 2*n. # doing n steps would increase the count by 2*n.
self.count = 0 self.agent_steps = 0
def add_postprocessed_batch_for_training( def add_postprocessed_batch_for_training(
self, batch: SampleBatch, self, batch: SampleBatch,
@ -246,9 +246,9 @@ class _PolicyCollector:
"""Adds a postprocessed SampleBatch (single agent) to our buffers. """Adds a postprocessed SampleBatch (single agent) to our buffers.
Args: Args:
batch (SampleBatch): A single agent (one trajectory) SampleBatch batch (SampleBatch): An individual agent's (one trajectory)
to be added to the Policy's buffers. SampleBatch to be added to the Policy's buffers.
view_requirements (DViewRequirementsDict): The view view_requirements (ViewRequirementsDict): The view
requirements for the policy. This is so we know, whether a requirements for the policy. This is so we know, whether a
view-column needs to be copied at all (not needed for view-column needs to be copied at all (not needed for
training). training).
@ -261,7 +261,7 @@ class _PolicyCollector:
view_requirements[view_col].used_for_training: view_requirements[view_col].used_for_training:
self.buffers[view_col].extend(data) self.buffers[view_col].extend(data)
# Add the agent's trajectory length to our count. # Add the agent's trajectory length to our count.
self.count += batch.count self.agent_steps += batch.count
def build(self): def build(self):
"""Builds a SampleBatch for this policy from the collected data. """Builds a SampleBatch for this policy from the collected data.
@ -277,8 +277,8 @@ class _PolicyCollector:
assert SampleBatch.UNROLL_ID in batch.data assert SampleBatch.UNROLL_ID in batch.data
# Clear buffers for future samples. # Clear buffers for future samples.
self.buffers.clear() self.buffers.clear()
# Reset count to 0. # Reset agent steps to 0.
self.count = 0 self.agent_steps = 0
return batch return batch
@ -288,7 +288,11 @@ class _PolicyCollectorGroup:
pid: _PolicyCollector() pid: _PolicyCollector()
for pid in policy_map.keys() for pid in policy_map.keys()
} }
self.count = 0 # Total env-steps (1 env-step=up to N agents stepped).
self.env_steps = 0
# Total agent steps (1 agent-step=1 individual agent (out of N)
# stepped).
self.agent_steps = 0
class _SimpleListCollector(_SampleCollector): class _SimpleListCollector(_SampleCollector):
@ -305,7 +309,8 @@ class _SimpleListCollector(_SampleCollector):
clip_rewards: Union[bool, float], clip_rewards: Union[bool, float],
callbacks: "DefaultCallbacks", callbacks: "DefaultCallbacks",
multiple_episodes_in_batch: bool = True, multiple_episodes_in_batch: bool = True,
rollout_fragment_length: int = 200): rollout_fragment_length: int = 200,
count_steps_by: str = "env_steps"):
"""Initializes a _SimpleListCollector instance. """Initializes a _SimpleListCollector instance.
Args: Args:
@ -314,6 +319,10 @@ class _SimpleListCollector(_SampleCollector):
clip_rewards (Union[bool, float]): Whether to clip rewards before clip_rewards (Union[bool, float]): Whether to clip rewards before
postprocessing (at +/-1.0) or the actual value to +/- clip. postprocessing (at +/-1.0) or the actual value to +/- clip.
callbacks (DefaultCallbacks): RLlib callbacks. callbacks (DefaultCallbacks): RLlib callbacks.
multiple_episodes_in_batch (bool): Whether it's allowed to pack
multiple episodes into the same built batch.
rollout_fragment_length (int): The length of a fragment to collect
before building a SampleBatch from the collected data.
count_steps_by (str): Either "env_steps" or "agent_steps".
Refers to the unit of `rollout_fragment_length`.
""" """
self.policy_map = policy_map self.policy_map = policy_map
@ -321,6 +330,7 @@ class _SimpleListCollector(_SampleCollector):
self.callbacks = callbacks self.callbacks = callbacks
self.multiple_episodes_in_batch = multiple_episodes_in_batch self.multiple_episodes_in_batch = multiple_episodes_in_batch
self.rollout_fragment_length = rollout_fragment_length self.rollout_fragment_length = rollout_fragment_length
self.count_steps_by = count_steps_by
self.large_batch_threshold: int = max( self.large_batch_threshold: int = max(
1000, rollout_fragment_length * 1000, rollout_fragment_length *
10) if rollout_fragment_length != float("inf") else 5000 10) if rollout_fragment_length != float("inf") else 5000
@ -340,8 +350,10 @@ class _SimpleListCollector(_SampleCollector):
self.forward_pass_size = {pid: 0 for pid in policy_map.keys()} self.forward_pass_size = {pid: 0 for pid in policy_map.keys()}
# Maps episode ID to the (non-built) env steps taken in this episode. # Maps episode ID to the (non-built) env steps taken in this episode.
self.episode_steps: Dict[EpisodeID, int] = \ self.episode_steps: Dict[EpisodeID, int] = collections.defaultdict(int)
collections.defaultdict(int) # Maps episode ID to the (non-built) individual agent steps in this
# episode.
self.agent_steps: Dict[EpisodeID, int] = collections.defaultdict(int)
# Maps episode ID to MultiAgentEpisode. # Maps episode ID to MultiAgentEpisode.
self.episodes: Dict[EpisodeID, MultiAgentEpisode] = {} self.episodes: Dict[EpisodeID, MultiAgentEpisode] = {}
@ -351,15 +363,17 @@ class _SimpleListCollector(_SampleCollector):
self.episode_steps[episode_id] += 1 self.episode_steps[episode_id] += 1
episode.length += 1 episode.length += 1
assert episode.batch_builder is not None assert episode.batch_builder is not None
env_steps = episode.batch_builder.count env_steps = episode.batch_builder.env_steps
num_observations = sum( num_individual_observations = sum(
c.count for c in episode.batch_builder.policy_collectors.values()) c.agent_steps
for c in episode.batch_builder.policy_collectors.values())
if num_observations > self.large_batch_threshold and \ if num_individual_observations > self.large_batch_threshold and \
log_once("large_batch_warning"): log_once("large_batch_warning"):
logger.warning( logger.warning(
"More than {} observations in {} env steps for " "More than {} observations in {} env steps for "
"episode {} ".format(num_observations, env_steps, episode_id) + "episode {} ".format(num_individual_observations, env_steps,
episode_id) +
"are buffered in the sampler. If this is more than you " "are buffered in the sampler. If this is more than you "
"expected, check that that you set a horizon on your " "expected, check that that you set a horizon on your "
"environment correctly and that it terminates at some point. " "environment correctly and that it terminates at some point. "
@ -412,6 +426,8 @@ class _SimpleListCollector(_SampleCollector):
assert self.agent_key_to_policy_id[agent_key] == policy_id assert self.agent_key_to_policy_id[agent_key] == policy_id
assert agent_key in self.agent_collectors assert agent_key in self.agent_collectors
self.agent_steps[episode_id] += 1
# Include the current agent id for multi-agent algorithms. # Include the current agent id for multi-agent algorithms.
if agent_id != _DUMMY_AGENT_ID: if agent_id != _DUMMY_AGENT_ID:
values["agent_id"] = agent_id values["agent_id"] = agent_id
@ -424,7 +440,18 @@ class _SimpleListCollector(_SampleCollector):
@override(_SampleCollector) @override(_SampleCollector)
def total_env_steps(self) -> int: def total_env_steps(self) -> int:
return sum(a.count for a in self.agent_collectors.values()) # Add the non-built ongoing-episode env steps + the already built
# env-steps.
return sum(self.episode_steps.values()) + sum(
pg.env_steps for pg in self.policy_collector_groups.values())
@override(_SampleCollector)
def total_agent_steps(self) -> int:
# Add the non-built ongoing-episode agent steps (still in the agent
# collectors) + the already built agent steps.
return sum(a.agent_steps for a in self.agent_collectors.values()) + \
sum(pg.agent_steps for pg in
self.policy_collector_groups.values())
@override(_SampleCollector) @override(_SampleCollector)
def get_inference_input_dict(self, policy_id: PolicyID) -> \ def get_inference_input_dict(self, policy_id: PolicyID) -> \
@ -463,11 +490,12 @@ class _SimpleListCollector(_SampleCollector):
return input_dict return input_dict
@override(_SampleCollector) @override(_SampleCollector)
def postprocess_episode(self, def postprocess_episode(
episode: MultiAgentEpisode, self,
is_done: bool = False, episode: MultiAgentEpisode,
check_dones: bool = False, is_done: bool = False,
build: bool = False) -> None: check_dones: bool = False,
build: bool = False) -> Union[None, SampleBatch, MultiAgentBatch]:
episode_id = episode.episode_id episode_id = episode.episode_id
policy_collector_group = episode.batch_builder policy_collector_group = episode.batch_builder
@ -478,7 +506,7 @@ class _SimpleListCollector(_SampleCollector):
pre_batches = {} pre_batches = {}
for (eps_id, agent_id), collector in self.agent_collectors.items(): for (eps_id, agent_id), collector in self.agent_collectors.items():
# Build only if there is data and agent is part of given episode. # Build only if there is data and agent is part of given episode.
if collector.count == 0 or eps_id != episode_id: if collector.agent_steps == 0 or eps_id != episode_id:
continue continue
pid = self.agent_key_to_policy_id[(eps_id, agent_id)] pid = self.agent_key_to_policy_id[(eps_id, agent_id)]
policy = self.policy_map[pid] policy = self.policy_map[pid]
@ -559,16 +587,19 @@ class _SimpleListCollector(_SampleCollector):
post_batch, policy.view_requirements) post_batch, policy.view_requirements)
env_steps = self.episode_steps[episode_id] env_steps = self.episode_steps[episode_id]
policy_collector_group.count += env_steps policy_collector_group.env_steps += env_steps
agent_steps = self.agent_steps[episode_id]
policy_collector_group.agent_steps += agent_steps
if is_done: if is_done:
del self.episode_steps[episode_id] del self.episode_steps[episode_id]
del self.agent_steps[episode_id]
del self.episodes[episode_id] del self.episodes[episode_id]
# Make PolicyCollectorGroup available for more agent batches in # Make PolicyCollectorGroup available for more agent batches in
# other episodes. Do not reset count to 0. # other episodes. Do not reset count to 0.
self.policy_collector_groups.append(policy_collector_group) self.policy_collector_groups.append(policy_collector_group)
else: else:
self.episode_steps[episode_id] = 0 self.episode_steps[episode_id] = self.agent_steps[episode_id] = 0
# Build a MultiAgentBatch from the episode and return. # Build a MultiAgentBatch from the episode and return.
if build: if build:
@ -579,14 +610,15 @@ class _SimpleListCollector(_SampleCollector):
ma_batch = {} ma_batch = {}
for pid, collector in episode.batch_builder.policy_collectors.items(): for pid, collector in episode.batch_builder.policy_collectors.items():
if collector.count > 0: if collector.agent_steps > 0:
ma_batch[pid] = collector.build() ma_batch[pid] = collector.build()
# Create the batch. # Create the batch.
ma_batch = MultiAgentBatch.wrap_as_needed( ma_batch = MultiAgentBatch.wrap_as_needed(
ma_batch, env_steps=episode.batch_builder.count) ma_batch, env_steps=episode.batch_builder.env_steps)
# PolicyCollectorGroup is empty. # PolicyCollectorGroup is empty.
episode.batch_builder.count = 0 episode.batch_builder.env_steps = 0
episode.batch_builder.agent_steps = 0
return ma_batch return ma_batch
@ -595,16 +627,26 @@ class _SimpleListCollector(_SampleCollector):
List[Union[MultiAgentBatch, SampleBatch]]: List[Union[MultiAgentBatch, SampleBatch]]:
batches = [] batches = []
# Loop through ongoing episodes and see whether their length plus # Loop through ongoing episodes and see whether their length plus
# what's already in the policy collectors reaches the fragment-len. # what's already in the policy collectors reaches the fragment-len
# (abiding to the unit used: env-steps or agent-steps).
for episode_id, episode in self.episodes.items(): for episode_id, episode in self.episodes.items():
env_steps = episode.batch_builder.count + \ # Measure batch size in env-steps.
self.episode_steps[episode_id] if self.count_steps_by == "env_steps":
built_steps = episode.batch_builder.env_steps
ongoing_steps = self.episode_steps[episode_id]
# Measure batch-size in agent-steps.
else:
built_steps = episode.batch_builder.agent_steps
ongoing_steps = self.agent_steps[episode_id]
# Reached the fragment-len -> We should build an MA-Batch. # Reached the fragment-len -> We should build an MA-Batch.
if env_steps >= self.rollout_fragment_length: if built_steps + ongoing_steps >= self.rollout_fragment_length:
assert env_steps == self.rollout_fragment_length if self.count_steps_by != "agent_steps":
assert built_steps + ongoing_steps == \
self.rollout_fragment_length
# If we reached the fragment-len only because of `episode_id` # If we reached the fragment-len only because of `episode_id`
# (still ongoing) -> postprocess `episode_id` first. # (still ongoing) -> postprocess `episode_id` first.
if episode.batch_builder.count < self.rollout_fragment_length: if built_steps < self.rollout_fragment_length:
self.postprocess_episode(episode, is_done=False) self.postprocess_episode(episode, is_done=False)
# Build the MA-batch and return. # Build the MA-batch and return.
batch = self._build_multi_agent_batch(episode=episode) batch = self._build_multi_agent_batch(episode=episode)
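Note the relaxed assertion: exact equality with `rollout_fragment_length` is only guaranteed for env-step counting. With agent-step counting the check runs once per env-step, and one env-step may have added several agent-steps, so the built fragment can overshoot slightly (the new rollout-worker test below asserts exactly this). A plain-Python illustration of the overshoot:

rollout_fragment_length = 301  # target size, counted in agent-steps
num_agents = 4                 # all agents act on every env-step

env_steps = agent_steps = 0
while agent_steps < rollout_fragment_length:
    # One env-step: all 4 agents act before the build-check runs again.
    env_steps += 1
    agent_steps += num_agents

print(env_steps, agent_steps)  # -> 76 304 (>= 301, slight overshoot)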

View file

@ -143,6 +143,7 @@ class RolloutWorker(ParallelIteratorWorker):
policies_to_train: Optional[List[PolicyID]] = None, policies_to_train: Optional[List[PolicyID]] = None,
tf_session_creator: Optional[Callable[[], "tf1.Session"]] = None, tf_session_creator: Optional[Callable[[], "tf1.Session"]] = None,
rollout_fragment_length: int = 100, rollout_fragment_length: int = 100,
count_steps_by: str = "env_steps",
batch_mode: str = "truncate_episodes", batch_mode: str = "truncate_episodes",
episode_horizon: int = None, episode_horizon: int = None,
preprocessor_pref: str = "deepmind", preprocessor_pref: str = "deepmind",
@ -208,8 +209,11 @@ class RolloutWorker(ParallelIteratorWorker):
tf_session_creator (Optional[Callable[[], tf1.Session]]): A tf_session_creator (Optional[Callable[[], tf1.Session]]): A
function that returns a TF session. This is optional and only function that returns a TF session. This is optional and only
useful with TFPolicy. useful with TFPolicy.
- rollout_fragment_length (int): The target number of env transitions
- to include in each sample batch returned from this worker.
+ rollout_fragment_length (int): The target number of steps
+ (measured in `count_steps_by`) to include in each sample
+ batch returned from this worker.
+ count_steps_by (str): The unit in which to count fragment
+ lengths. One of env_steps or agent_steps.
batch_mode (str): One of the following batch modes: batch_mode (str): One of the following batch modes:
"truncate_episodes": Each call to sample() will return a batch "truncate_episodes": Each call to sample() will return a batch
of at most `rollout_fragment_length * num_envs` in size. of at most `rollout_fragment_length * num_envs` in size.
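A minimal construction sketch for the new argument (mirroring the updated `test_truncate_episodes` further down in this commit; `MockPolicy` is the test policy from `rllib/evaluation/tests/test_rollout_worker.py`):

import ray
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.tests.test_rollout_worker import MockPolicy
from ray.rllib.examples.env.mock_env import MockEnv

ray.init()
worker = RolloutWorker(
    env_creator=lambda _: MockEnv(10),
    policy_spec=MockPolicy,
    rollout_fragment_length=15,
    # Default unit; "agent_steps" only makes sense for multi-agent envs.
    count_steps_by="env_steps",
    batch_mode="truncate_episodes")
batch = worker.sample()
assert batch.count == 15
worker.stop()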
@ -356,6 +360,7 @@ class RolloutWorker(ParallelIteratorWorker):
raise ValueError("Policy mapping function not callable?") raise ValueError("Policy mapping function not callable?")
self.env_creator: Callable[[EnvContext], EnvType] = env_creator self.env_creator: Callable[[EnvContext], EnvType] = env_creator
self.rollout_fragment_length: int = rollout_fragment_length * num_envs self.rollout_fragment_length: int = rollout_fragment_length * num_envs
self.count_steps_by: str = count_steps_by
self.batch_mode: str = batch_mode self.batch_mode: str = batch_mode
self.compress_observations: bool = compress_observations self.compress_observations: bool = compress_observations
self.preprocessing_enabled: bool = True self.preprocessing_enabled: bool = True
@ -570,6 +575,7 @@ class RolloutWorker(ParallelIteratorWorker):
obs_filters=self.filters, obs_filters=self.filters,
clip_rewards=clip_rewards, clip_rewards=clip_rewards,
rollout_fragment_length=rollout_fragment_length, rollout_fragment_length=rollout_fragment_length,
count_steps_by=count_steps_by,
callbacks=self.callbacks, callbacks=self.callbacks,
horizon=episode_horizon, horizon=episode_horizon,
multiple_episodes_in_batch=pack, multiple_episodes_in_batch=pack,
@ -593,6 +599,7 @@ class RolloutWorker(ParallelIteratorWorker):
obs_filters=self.filters, obs_filters=self.filters,
clip_rewards=clip_rewards, clip_rewards=clip_rewards,
rollout_fragment_length=rollout_fragment_length, rollout_fragment_length=rollout_fragment_length,
count_steps_by=count_steps_by,
callbacks=self.callbacks, callbacks=self.callbacks,
horizon=episode_horizon, horizon=episode_horizon,
multiple_episodes_in_batch=pack, multiple_episodes_in_batch=pack,
@ -636,7 +643,9 @@ class RolloutWorker(ParallelIteratorWorker):
self.rollout_fragment_length)) self.rollout_fragment_length))
batches = [self.input_reader.next()] batches = [self.input_reader.next()]
steps_so_far = batches[0].count steps_so_far = batches[0].count if \
self.count_steps_by == "env_steps" else \
batches[0].agent_steps()
# In truncate_episodes mode, never pull more than 1 batch per env. # In truncate_episodes mode, never pull more than 1 batch per env.
# This avoids over-running the target batch size. # This avoids over-running the target batch size.
@ -648,7 +657,9 @@ class RolloutWorker(ParallelIteratorWorker):
while (steps_so_far < self.rollout_fragment_length while (steps_so_far < self.rollout_fragment_length
and len(batches) < max_batches): and len(batches) < max_batches):
batch = self.input_reader.next() batch = self.input_reader.next()
steps_so_far += batch.count steps_so_far += batch.count if \
self.count_steps_by == "env_steps" else \
batch.agent_steps()
batches.append(batch) batches.append(batch)
batch = batches[0].concat_samples(batches) if len(batches) > 1 else \ batch = batches[0].concat_samples(batches) if len(batches) > 1 else \
batches[0] batches[0]

View file

@ -129,6 +129,7 @@ class SyncSampler(SamplerInput):
obs_filters: Dict[PolicyID, Filter], obs_filters: Dict[PolicyID, Filter],
clip_rewards: bool, clip_rewards: bool,
rollout_fragment_length: int, rollout_fragment_length: int,
count_steps_by: str = "env_steps",
callbacks: "DefaultCallbacks", callbacks: "DefaultCallbacks",
horizon: int = None, horizon: int = None,
multiple_episodes_in_batch: bool = False, multiple_episodes_in_batch: bool = False,
@ -190,8 +191,12 @@ class SyncSampler(SamplerInput):
self.perf_stats = _PerfStats() self.perf_stats = _PerfStats()
if _use_trajectory_view_api: if _use_trajectory_view_api:
self.sample_collector = _SimpleListCollector( self.sample_collector = _SimpleListCollector(
policies, clip_rewards, callbacks, multiple_episodes_in_batch, policies,
rollout_fragment_length) clip_rewards,
callbacks,
multiple_episodes_in_batch,
rollout_fragment_length,
count_steps_by=count_steps_by)
else: else:
self.sample_collector = None self.sample_collector = None
@ -254,6 +259,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
obs_filters: Dict[PolicyID, Filter], obs_filters: Dict[PolicyID, Filter],
clip_rewards: bool, clip_rewards: bool,
rollout_fragment_length: int, rollout_fragment_length: int,
count_steps_by: str = "env_steps",
callbacks: "DefaultCallbacks", callbacks: "DefaultCallbacks",
horizon: int = None, horizon: int = None,
multiple_episodes_in_batch: bool = False, multiple_episodes_in_batch: bool = False,
@ -282,6 +288,8 @@ class AsyncSampler(threading.Thread, SamplerInput):
rollout_fragment_length (int): The length of a fragment to collect rollout_fragment_length (int): The length of a fragment to collect
before building a SampleBatch from the data and resetting before building a SampleBatch from the data and resetting
the SampleBatchBuilder object. the SampleBatchBuilder object.
count_steps_by (str): Either "env_steps" or "agent_steps".
Refers to the unit of `rollout_fragment_length`.
callbacks (Callbacks): The Callbacks object to use when episode callbacks (Callbacks): The Callbacks object to use when episode
events happen during rollout. events happen during rollout.
horizon (Optional[int]): Hard-reset the Env horizon (Optional[int]): Hard-reset the Env
@ -336,8 +344,12 @@ class AsyncSampler(threading.Thread, SamplerInput):
self._use_trajectory_view_api = _use_trajectory_view_api self._use_trajectory_view_api = _use_trajectory_view_api
if _use_trajectory_view_api: if _use_trajectory_view_api:
self.sample_collector = _SimpleListCollector( self.sample_collector = _SimpleListCollector(
policies, clip_rewards, callbacks, multiple_episodes_in_batch, policies,
rollout_fragment_length) clip_rewards,
callbacks,
multiple_episodes_in_batch,
rollout_fragment_length,
count_steps_by=count_steps_by)
else: else:
self.sample_collector = None self.sample_collector = None

View file

View file

@ -1,5 +1,6 @@
from collections import Counter from collections import Counter
import gym import gym
from gym.spaces import Box, Discrete
import numpy as np import numpy as np
import os import os
import random import random
@ -13,9 +14,12 @@ from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.evaluation.postprocessing import compute_advantages from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.examples.env.mock_env import MockEnv, MockEnv2
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.examples.policy.random_policy import RandomPolicy from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.policy.policy import Policy from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, MultiAgentBatch, \
SampleBatch
from ray.rllib.utils.annotations import override from ray.rllib.utils.annotations import override
from ray.rllib.utils.test_utils import check, framework_iterator from ray.rllib.utils.test_utils import check, framework_iterator
from ray.tune.registry import register_env from ray.tune.registry import register_env
@ -71,39 +75,6 @@ class FailOnStepEnv(gym.Env):
raise ValueError("kaboom") raise ValueError("kaboom")
class MockEnv(gym.Env):
def __init__(self, episode_length, config=None):
self.episode_length = episode_length
self.config = config
self.i = 0
self.observation_space = gym.spaces.Discrete(1)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
self.i = 0
return self.i
def step(self, action):
self.i += 1
return 0, 1, self.i >= self.episode_length, {}
class MockEnv2(gym.Env):
def __init__(self, episode_length):
self.episode_length = episode_length
self.i = 0
self.observation_space = gym.spaces.Discrete(100)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
self.i = 0
return self.i
def step(self, action):
self.i += 1
return self.i, 100, self.i >= self.episode_length, {}
class MockVectorEnv(VectorEnv): class MockVectorEnv(VectorEnv):
def __init__(self, episode_length, num_envs): def __init__(self, episode_length, num_envs):
super().__init__( super().__init__(
@ -523,14 +494,57 @@ class TestRolloutWorker(unittest.TestCase):
ev.stop() ev.stop()
def test_truncate_episodes(self): def test_truncate_episodes(self):
ev = RolloutWorker( ev_env_steps = RolloutWorker(
env_creator=lambda _: MockEnv(10), env_creator=lambda _: MockEnv(10),
policy_spec=MockPolicy, policy_spec=MockPolicy,
policy_config={"_use_trajectory_view_api": True},
rollout_fragment_length=15, rollout_fragment_length=15,
batch_mode="truncate_episodes") batch_mode="truncate_episodes")
batch = ev.sample() batch = ev_env_steps.sample()
self.assertEqual(batch.count, 15) self.assertEqual(batch.count, 15)
ev.stop() self.assertTrue(isinstance(batch, SampleBatch))
ev_env_steps.stop()
action_space = Discrete(2)
obs_space = Box(float("-inf"), float("inf"), (4, ), dtype=np.float32)
ev_agent_steps = RolloutWorker(
env_creator=lambda _: MultiAgentCartPole({"num_agents": 4}),
policy_spec={
"pol0": (MockPolicy, obs_space, action_space, {}),
"pol1": (MockPolicy, obs_space, action_space, {}),
},
policy_config={"_use_trajectory_view_api": True},
policy_mapping_fn=lambda ag: "pol0" if ag == 0 else "pol1",
rollout_fragment_length=301,
count_steps_by="env_steps",
batch_mode="truncate_episodes",
)
batch = ev_agent_steps.sample()
self.assertTrue(isinstance(batch, MultiAgentBatch))
self.assertGreater(batch.agent_steps(), 301)
self.assertEqual(batch.env_steps(), 301)
ev_agent_steps.stop()
ev_agent_steps = RolloutWorker(
env_creator=lambda _: MultiAgentCartPole({"num_agents": 4}),
policy_spec={
"pol0": (MockPolicy, obs_space, action_space, {}),
"pol1": (MockPolicy, obs_space, action_space, {}),
},
policy_config={"_use_trajectory_view_api": True},
policy_mapping_fn=lambda ag: "pol0" if ag == 0 else "pol1",
rollout_fragment_length=301,
count_steps_by="agent_steps",
batch_mode="truncate_episodes")
batch = ev_agent_steps.sample()
self.assertTrue(isinstance(batch, MultiAgentBatch))
self.assertLess(batch.env_steps(), 301)
# When counting agent steps, the count may be slightly larger than
# rollout_fragment_length, b/c we have up to N agents stepping in each
# env step and we only check whether we should build after each env
# step.
self.assertGreaterEqual(batch.agent_steps(), 301)
ev_agent_steps.stop()
def test_complete_episodes(self): def test_complete_episodes(self):
ev = RolloutWorker( ev = RolloutWorker(

View file

@ -1,13 +1,16 @@
import copy import copy
import gym import gym
from gym.spaces import Box, Discrete from gym.spaces import Box, Discrete
import numpy as np
import time import time
import unittest import unittest
import ray import ray
from ray import tune
import ray.rllib.agents.dqn as dqn import ray.rllib.agents.dqn as dqn
import ray.rllib.agents.ppo as ppo import ray.rllib.agents.ppo as ppo
from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.examples.policy.episode_env_aware_policy import \ from ray.rllib.examples.policy.episode_env_aware_policy import \
EpisodeEnvAwareLSTMPolicy EpisodeEnvAwareLSTMPolicy
@ -295,6 +298,38 @@ class TestTrajectoryViewAPI(unittest.TestCase):
pol_batch_wo = result.policy_batches["pol0"] pol_batch_wo = result.policy_batches["pol0"]
check(pol_batch_w.data, pol_batch_wo.data) check(pol_batch_w.data, pol_batch_wo.data)
def test_counting_by_agent_steps(self):
"""Test whether a PPOTrainer can be built with all frameworks."""
config = copy.deepcopy(ppo.DEFAULT_CONFIG)
action_space = Discrete(2)
obs_space = Box(float("-inf"), float("inf"), (4, ), dtype=np.float32)
config["num_workers"] = 2
config["num_sgd_iter"] = 2
config["framework"] = "torch"
config["rollout_fragment_length"] = 21
config["train_batch_size"] = 147
config["multiagent"] = {
"policies": {
"p0": (None, obs_space, action_space, {}),
"p1": (None, obs_space, action_space, {}),
},
"policy_mapping_fn": lambda aid: "p{}".format(aid),
"count_steps_by": "agent_steps",
}
tune.register_env(
"ma_cartpole", lambda _: MultiAgentCartPole({"num_agents": 2}))
num_iterations = 2
trainer = ppo.PPOTrainer(config=config, env="ma_cartpole")
results = None
for i in range(num_iterations):
results = trainer.train()
self.assertGreater(results["timesteps_total"],
num_iterations * config["train_batch_size"])
self.assertLess(results["timesteps_total"],
(num_iterations + 1) * config["train_batch_size"])
trainer.stop()
def analyze_rnn_batch(batch, max_seq_len): def analyze_rnn_batch(batch, max_seq_len):
count = batch.count count = batch.count

View file

@ -321,6 +321,7 @@ class WorkerSet:
tf_session_creator=(session_creator tf_session_creator=(session_creator
if config["tf_session_args"] else None), if config["tf_session_args"] else None),
rollout_fragment_length=config["rollout_fragment_length"], rollout_fragment_length=config["rollout_fragment_length"],
count_steps_by=config["multiagent"]["count_steps_by"],
batch_mode=config["batch_mode"], batch_mode=config["batch_mode"],
episode_horizon=config["horizon"], episode_horizon=config["horizon"],
preprocessor_pref=config["preprocessor_pref"], preprocessor_pref=config["preprocessor_pref"],

rllib/examples/env/mock_env.py (new file, 46 lines added)
View file

@ -0,0 +1,46 @@
import gym
class MockEnv(gym.Env):
"""Mock environment for testing purposes.
Observation=0, reward=1.0, episode-len is configurable.
Actions are ignored.
"""
def __init__(self, episode_length, config=None):
self.episode_length = episode_length
self.config = config
self.i = 0
self.observation_space = gym.spaces.Discrete(1)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
self.i = 0
return 0
def step(self, action):
self.i += 1
return 0, 1.0, self.i >= self.episode_length, {}
class MockEnv2(gym.Env):
"""Mock environment for testing purposes.
Observation=ts (discrete space!), reward=100.0, episode-len is
configurable. Actions are ignored.
"""
def __init__(self, episode_length):
self.episode_length = episode_length
self.i = 0
self.observation_space = gym.spaces.Discrete(100)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
self.i = 0
return self.i
def step(self, action):
self.i += 1
return self.i, 100.0, self.i >= self.episode_length, {}

View file

@ -1,8 +1,8 @@
import gym import gym
from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.examples.env.mock_env import MockEnv, MockEnv2
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.rllib.tests.test_rollout_worker import MockEnv, MockEnv2
def make_multiagent(env_name_or_creator): def make_multiagent(env_name_or_creator):

View file

@ -81,7 +81,8 @@ def custom_training_workflow(workers: WorkerSet, config: dict):
# PPO sub-flow. # PPO sub-flow.
ppo_train_op = r2.for_each(SelectExperiences(["ppo_policy"])) \ ppo_train_op = r2.for_each(SelectExperiences(["ppo_policy"])) \
.combine(ConcatBatches(min_batch_size=200)) \ .combine(ConcatBatches(
min_batch_size=200, count_steps_by="env_steps")) \
.for_each(add_ppo_metrics) \ .for_each(add_ppo_metrics) \
.for_each(StandardizeFields(["advantages"])) \ .for_each(StandardizeFields(["advantages"])) \
.for_each(TrainOneStep( .for_each(TrainOneStep(

View file

@ -141,13 +141,15 @@ class ConcatBatches:
Examples: Examples:
>>> rollouts = ParallelRollouts(...) >>> rollouts = ParallelRollouts(...)
>>> rollouts = rollouts.combine(ConcatBatches(min_batch_size=10000)) >>> rollouts = rollouts.combine(ConcatBatches(
... min_batch_size=10000, count_steps_by="env_steps"))
>>> print(next(rollouts).count) >>> print(next(rollouts).count)
10000 10000
""" """
def __init__(self, min_batch_size: int): def __init__(self, min_batch_size: int, count_steps_by: str = "env_steps"):
self.min_batch_size = min_batch_size self.min_batch_size = min_batch_size
self.count_steps_by = count_steps_by
self.buffer = [] self.buffer = []
self.count = 0 self.count = 0
self.batch_start_time = None self.batch_start_time = None
@ -159,7 +161,15 @@ class ConcatBatches:
def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]: def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]:
_check_sample_batch_type(batch) _check_sample_batch_type(batch)
self.buffer.append(batch) self.buffer.append(batch)
self.count += batch.count
if self.count_steps_by == "env_steps":
self.count += batch.count
else:
assert isinstance(batch, MultiAgentBatch), \
"`count_steps_by=agent_steps` only allowed in multi-agent " \
"environments!"
self.count += batch.agent_steps()
if self.count >= self.min_batch_size: if self.count >= self.min_batch_size:
if self.count > self.min_batch_size * 2: if self.count > self.min_batch_size * 2:
logger.info("Collected more training samples than expected " logger.info("Collected more training samples than expected "

View file

@ -51,7 +51,9 @@ class Aggregator(ParallelIteratorWorker):
.flatten() \ .flatten() \
.combine( .combine(
ConcatBatches( ConcatBatches(
min_batch_size=config["train_batch_size"])) min_batch_size=config["train_batch_size"],
count_steps_by=config["multiagent"]["count_steps_by"],
))
for train_batch in it: for train_batch in it:
yield train_batch yield train_batch

View file

@ -417,16 +417,17 @@ class MultiAgentBatch:
Args: Args:
policy_batches (Dict[PolicyID, SampleBatch]): Mapping from policy policy_batches (Dict[PolicyID, SampleBatch]): Mapping from policy
ids to SampleBatches of experiences. ids to SampleBatches of experiences.
env_steps (int): The number of timesteps in the environment this env_steps (int): The number of environment steps in the environment
batch contains. This will be less than the number of this batch contains. This will be less than the number of
transitions this batch contains across all policies in total. transitions this batch contains across all policies in total.
""" """
for v in policy_batches.values(): for v in policy_batches.values():
assert isinstance(v, SampleBatch) assert isinstance(v, SampleBatch)
self.policy_batches = policy_batches self.policy_batches = policy_batches
# Called count for uniformity with SampleBatch. Prefer to access this # Called "count" for uniformity with SampleBatch.
# via the env_steps() method when possible for clarity. # Prefer to access this via the `env_steps()` method when possible
# for clarity.
self.count = env_steps self.count = env_steps
@PublicAPI @PublicAPI
@ -526,7 +527,8 @@ class MultiAgentBatch:
""" """
if len(policy_batches) == 1 and DEFAULT_POLICY_ID in policy_batches: if len(policy_batches) == 1 and DEFAULT_POLICY_ID in policy_batches:
return policy_batches[DEFAULT_POLICY_ID] return policy_batches[DEFAULT_POLICY_ID]
return MultiAgentBatch(policy_batches, env_steps) return MultiAgentBatch(
policy_batches=policy_batches, env_steps=env_steps)
@staticmethod @staticmethod
@PublicAPI @PublicAPI

View file

@ -9,8 +9,9 @@ from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.agents.pg import PGTrainer from ray.rllib.agents.pg import PGTrainer
from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.env.external_env import ExternalEnv from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.tests.test_rollout_worker import (BadPolicy, MockPolicy, from ray.rllib.evaluation.tests.test_rollout_worker import (BadPolicy,
MockEnv) MockPolicy)
from ray.rllib.examples.env.mock_env import MockEnv
from ray.rllib.utils.test_utils import framework_iterator from ray.rllib.utils.test_utils import framework_iterator
from ray.tune.registry import register_env from ray.tune.registry import register_env

View file

@ -5,8 +5,8 @@ import unittest
import ray import ray
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.tests.test_rollout_worker import MockPolicy
from ray.rllib.examples.env.multi_agent import BasicMultiAgent from ray.rllib.examples.env.multi_agent import BasicMultiAgent
from ray.rllib.tests.test_rollout_worker import MockPolicy
from ray.rllib.tests.test_external_env import make_simple_serving from ray.rllib.tests.test_external_env import make_simple_serving
SimpleMultiServing = make_simple_serving(True, ExternalMultiAgentEnv) SimpleMultiServing = make_simple_serving(True, ExternalMultiAgentEnv)

View file

@ -12,8 +12,8 @@ from ray.rllib.evaluation.rollout_worker import get_global_worker
from ray.rllib.examples.policy.random_policy import RandomPolicy from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \ from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \
BasicMultiAgent, EarlyDoneMultiAgent, RoundRobinMultiAgent BasicMultiAgent, EarlyDoneMultiAgent, RoundRobinMultiAgent
from ray.rllib.tests.test_rollout_worker import MockPolicy
from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.tests.test_rollout_worker import MockPolicy
from ray.rllib.env.base_env import _MultiAgentEnvToBaseEnv from ray.rllib.env.base_env import _MultiAgentEnvToBaseEnv
from ray.rllib.utils.numpy import one_hot from ray.rllib.utils.numpy import one_hot
from ray.rllib.utils.test_utils import check from ray.rllib.utils.test_utils import check

View file

@ -4,7 +4,7 @@ import unittest
import ray import ray
from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.tests.test_rollout_worker import MockPolicy from ray.rllib.evaluation.tests.test_rollout_worker import MockPolicy
class TestPerf(unittest.TestCase): class TestPerf(unittest.TestCase):