From e40b14d2558aa9e586e0c5a12f19522d23019364 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Wed, 9 Dec 2020 01:41:45 +0100 Subject: [PATCH] [RLlib] Batch-size for truncate_episode batch_mode should be confgurable in agent-steps (rather than env-steps), if needed. (#12420) --- rllib/BUILD | 14 +-- rllib/agents/a3c/a2c.py | 13 +- rllib/agents/impala/impala.py | 5 +- rllib/agents/marwil/marwil.py | 5 +- rllib/agents/ppo/ppo.py | 5 +- rllib/agents/ppo/tests/test_ppo.py | 2 +- rllib/agents/qmix/qmix.py | 5 +- rllib/agents/trainer.py | 37 +++++- rllib/agents/trainer_template.py | 9 +- .../alpha_zero/core/alpha_zero_trainer.py | 16 ++- .../evaluation/collectors/sample_collector.py | 32 ++++- .../collectors/simple_list_collector.py | 116 ++++++++++++------ rllib/evaluation/rollout_worker.py | 19 ++- rllib/evaluation/sampler.py | 20 ++- rllib/evaluation/tests/__init__.py | 0 .../tests/test_rollout_worker.py | 88 +++++++------ .../tests/test_trajectory_view_api.py | 35 ++++++ rllib/evaluation/worker_set.py | 1 + rllib/examples/env/mock_env.py | 46 +++++++ rllib/examples/env/multi_agent.py | 2 +- rllib/examples/two_trainer_workflow.py | 3 +- rllib/execution/rollout_ops.py | 16 ++- rllib/execution/tree_agg.py | 4 +- rllib/policy/sample_batch.py | 12 +- rllib/tests/test_external_env.py | 5 +- rllib/tests/test_external_multi_agent_env.py | 2 +- rllib/tests/test_multi_agent_env.py | 2 +- rllib/tests/test_perf.py | 2 +- 28 files changed, 384 insertions(+), 132 deletions(-) create mode 100644 rllib/evaluation/tests/__init__.py rename rllib/{ => evaluation}/tests/test_rollout_worker.py (90%) create mode 100644 rllib/examples/env/mock_env.py diff --git a/rllib/BUILD b/rllib/BUILD index 8af609982..daed4b8f7 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1108,6 +1108,13 @@ py_test( # srcs = ["evaluation/tests/test_trajectory_view_api.py"] #) +py_test( + name = "evaluation/tests/test_rollout_worker", + tags = ["evaluation"], + size = "medium", + srcs = ["evaluation/tests/test_rollout_worker.py"] +) + # -------------------------------------------------------------------- # Optimizers and Memories @@ -1411,13 +1418,6 @@ py_test( args = ["TestRolloutLearntPolicy"] ) -py_test( - name = "tests/test_rollout_worker", - tags = ["tests_dir", "tests_dir_R"], - size = "medium", - srcs = ["tests/test_rollout_worker.py"] -) - py_test( name = "tests/test_supported_multi_agent_pg", main = "tests/test_supported_multi_agent.py", diff --git a/rllib/agents/a3c/a2c.py b/rllib/agents/a3c/a2c.py index 0a71a359c..e6ea0c356 100644 --- a/rllib/agents/a3c/a2c.py +++ b/rllib/agents/a3c/a2c.py @@ -38,17 +38,20 @@ def execution_plan(workers, config): # allowing for extremely large experience batches to be used. train_op = ( rollouts.combine( - ConcatBatches(min_batch_size=config["microbatch_size"])) + ConcatBatches( + min_batch_size=config["microbatch_size"], + count_steps_by=config["multiagent"]["count_steps_by"])) .for_each(ComputeGradients(workers)) # (grads, info) .batch(num_microbatches) # List[(grads, info)] .for_each(AverageGradients()) # (avg_grads, info) .for_each(ApplyGradients(workers))) else: # In normal mode, we execute one SGD step per each train batch. 
- train_op = rollouts \ - .combine(ConcatBatches( - min_batch_size=config["train_batch_size"])) \ - .for_each(TrainOneStep(workers)) + train_op = rollouts.combine( + ConcatBatches( + min_batch_size=config["train_batch_size"], + count_steps_by=config["multiagent"][ + "count_steps_by"])).for_each(TrainOneStep(workers)) return StandardMetricsReporting(train_op, workers, config) diff --git a/rllib/agents/impala/impala.py b/rllib/agents/impala/impala.py index 7a09b1f9a..0bcddf4f3 100644 --- a/rllib/agents/impala/impala.py +++ b/rllib/agents/impala/impala.py @@ -221,7 +221,10 @@ def gather_experiences_directly(workers, config): replay_proportion=config["replay_proportion"])) \ .flatten() \ .combine( - ConcatBatches(min_batch_size=config["train_batch_size"])) + ConcatBatches( + min_batch_size=config["train_batch_size"], + count_steps_by=config["multiagent"]["count_steps_by"], + )) return train_batches diff --git a/rllib/agents/marwil/marwil.py b/rllib/agents/marwil/marwil.py index 6aeb373c5..c4f88fdb8 100644 --- a/rllib/agents/marwil/marwil.py +++ b/rllib/agents/marwil/marwil.py @@ -56,7 +56,10 @@ def execution_plan(workers, config): replay_op = Replay(local_buffer=replay_buffer) \ .combine( - ConcatBatches(min_batch_size=config["train_batch_size"])) \ + ConcatBatches( + min_batch_size=config["train_batch_size"], + count_steps_by=config["multiagent"]["count_steps_by"], + )) \ .for_each(TrainOneStep(workers)) train_op = Concurrently( diff --git a/rllib/agents/ppo/ppo.py b/rllib/agents/ppo/ppo.py index 8e988d7e5..026988201 100644 --- a/rllib/agents/ppo/ppo.py +++ b/rllib/agents/ppo/ppo.py @@ -244,7 +244,10 @@ def execution_plan(workers: WorkerSet, SelectExperiences(workers.trainable_policies())) # Concatenate the SampleBatches into one. rollouts = rollouts.combine( - ConcatBatches(min_batch_size=config["train_batch_size"])) + ConcatBatches( + min_batch_size=config["train_batch_size"], + count_steps_by=config["multiagent"]["count_steps_by"], + )) # Standardize advantages. rollouts = rollouts.for_each(StandardizeFields(["advantages"])) diff --git a/rllib/agents/ppo/tests/test_ppo.py b/rllib/agents/ppo/tests/test_ppo.py index c00cd36ba..b4259c144 100644 --- a/rllib/agents/ppo/tests/test_ppo.py +++ b/rllib/agents/ppo/tests/test_ppo.py @@ -73,7 +73,7 @@ class TestPPO(unittest.TestCase): def test_ppo_compilation_and_lr_schedule(self): """Test whether a PPOTrainer can be built with all frameworks.""" config = copy.deepcopy(ppo.DEFAULT_CONFIG) - # for checking lr-schedule correctness + # For checking lr-schedule correctness. config["callbacks"] = MyCallbacks config["num_workers"] = 1 diff --git a/rllib/agents/qmix/qmix.py b/rllib/agents/qmix/qmix.py index 7d64680f5..c2584378d 100644 --- a/rllib/agents/qmix/qmix.py +++ b/rllib/agents/qmix/qmix.py @@ -109,7 +109,10 @@ def execution_plan(workers, config): train_op = Replay(local_buffer=replay_buffer) \ .combine( - ConcatBatches(min_batch_size=config["train_batch_size"])) \ + ConcatBatches( + min_batch_size=config["train_batch_size"], + count_steps_by=config["multiagent"]["count_steps_by"] + )) \ .for_each(TrainOneStep(workers)) \ .for_each(UpdateTargetNetwork( workers, config["target_network_update_freq"])) diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index b5751e264..c57ff0b67 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -75,10 +75,18 @@ COMMON_CONFIG: TrainerConfigDict = { # The dataflow here can vary per algorithm. For example, PPO further # divides the train batch into minibatches for multi-epoch SGD. 
"rollout_fragment_length": 200, - # Whether to rollout "complete_episodes" or "truncate_episodes" to - # `rollout_fragment_length` length unrolls. Episode truncation guarantees - # evenly sized batches, but increases variance as the reward-to-go will - # need to be estimated at truncation boundaries. + # How to build per-Sampler (RolloutWorker) batches, which are then + # usually concat'd to form the train batch. Note that "steps" below can + # mean different things (either env- or agent-steps) and depends on the + # `count_steps_by` (multiagent) setting below. + # truncate_episodes: Each produced batch (when calling + # RolloutWorker.sample()) will contain exactly `rollout_fragment_length` + # steps. This mode guarantees evenly sized batches, but increases + # variance as the future return must now be estimated at truncation + # boundaries. + # complete_episodes: Each unroll happens exactly over one episode, from + # beginning to end. Data collection will not stop unless the episode + # terminates or a configured horizon (hard or soft) is hit. "batch_mode": "truncate_episodes", # === Settings for the Trainer process === @@ -357,6 +365,13 @@ COMMON_CONFIG: TrainerConfigDict = { # agents it controls at that timestep. When replay_mode=independent, # transitions are replayed independently per policy. "replay_mode": "independent", + # Which metric to use as the "batch size" when building a + # MultiAgentBatch. The two supported values are: + # env_steps: Count each time the env is "stepped" (no matter how many + # multi-agent actions are passed/how many multi-agent observations + # have been returned in the previous step). + # agent_steps: Count each individual agent step as one step. + "count_steps_by": "env_steps", }, # === Logger === @@ -1081,6 +1096,20 @@ class Trainer(Trainable): config["model"]["lstm_use_prev_action"] = prev_a_r config["model"]["lstm_use_prev_reward"] = prev_a_r + # Check batching/sample collection settings. + if config["batch_mode"] not in [ + "truncate_episodes", "complete_episodes" + ]: + raise ValueError("`batch_mode` must be one of [truncate_episodes|" + "complete_episodes]! Got {}".format( + config["batch_mode"])) + + if config["multiagent"].get("count_steps_by", "env_steps") not in \ + ["env_steps", "agent_steps"]: + raise ValueError( + "`count_steps_by` must be one of [env_steps|agent_steps]! " + "Got {}".format(config["multiagent"]["count_steps_by"])) + def _try_recover(self): """Try to identify and remove any unhealthy workers. diff --git a/rllib/agents/trainer_template.py b/rllib/agents/trainer_template.py index f3c5d4c1c..cd622631a 100644 --- a/rllib/agents/trainer_template.py +++ b/rllib/agents/trainer_template.py @@ -22,10 +22,11 @@ def default_execution_plan(workers: WorkerSet, config: TrainerConfigDict): # Combine experiences batches until we hit `train_batch_size` in size. # Then, train the policy on those experiences and update the workers. - train_op = rollouts \ - .combine(ConcatBatches( - min_batch_size=config["train_batch_size"])) \ - .for_each(TrainOneStep(workers)) + train_op = rollouts.combine( + ConcatBatches( + min_batch_size=config["train_batch_size"], + count_steps_by=config["multiagent"]["count_steps_by"], + )).for_each(TrainOneStep(workers)) # Add on the standard episode reward, etc. metrics reporting. This returns # a LocalIterator[metrics_dict] representing metrics for each train step. 
diff --git a/rllib/contrib/alpha_zero/core/alpha_zero_trainer.py b/rllib/contrib/alpha_zero/core/alpha_zero_trainer.py index 27315108b..4af65eba6 100644 --- a/rllib/contrib/alpha_zero/core/alpha_zero_trainer.py +++ b/rllib/contrib/alpha_zero/core/alpha_zero_trainer.py @@ -164,11 +164,12 @@ def execution_plan(workers, config): rollouts = ParallelRollouts(workers, mode="bulk_sync") if config["simple_optimizer"]: - train_op = rollouts \ - .combine(ConcatBatches( - min_batch_size=config["train_batch_size"])) \ - .for_each(TrainOneStep( - workers, num_sgd_iter=config["num_sgd_iter"])) + train_op = rollouts.combine( + ConcatBatches( + min_batch_size=config["train_batch_size"], + count_steps_by=config["multiagent"]["count_steps_by"], + )).for_each( + TrainOneStep(workers, num_sgd_iter=config["num_sgd_iter"])) else: replay_buffer = SimpleReplayBuffer(config["buffer_size"]) @@ -178,7 +179,10 @@ def execution_plan(workers, config): replay_op = Replay(local_buffer=replay_buffer) \ .filter(WaitUntilTimestepsElapsed(config["learning_starts"])) \ .combine( - ConcatBatches(min_batch_size=config["train_batch_size"])) \ + ConcatBatches( + min_batch_size=config["train_batch_size"], + count_steps_by=config["multiagent"]["count_steps_by"], + )) \ .for_each(TrainOneStep( workers, num_sgd_iter=config["num_sgd_iter"])) diff --git a/rllib/evaluation/collectors/sample_collector.py b/rllib/evaluation/collectors/sample_collector.py index 7b154f1d8..da188e938 100644 --- a/rllib/evaluation/collectors/sample_collector.py +++ b/rllib/evaluation/collectors/sample_collector.py @@ -110,11 +110,37 @@ class _SampleCollector(metaclass=ABCMeta): @abstractmethod def total_env_steps(self) -> int: - """Returns total number of steps taken in the env (sum of all agents). + """Returns total number of env-steps taken so far. + + Thereby, a step in an N-agent multi-agent environment counts as only 1 + for this metric. The returned count contains everything that has not + been built yet (and returned as MultiAgentBatches by the + `try_build_truncated_episode_multi_agent_batch` or + `postprocess_episode(build=True)` methods). After such build, this + counter is reset to 0. Returns: - int: The number of steps taken in total in the environment over all - agents. + int: The number of env-steps taken in total in the environment(s) + so far. + """ + raise NotImplementedError + + @abstractmethod + def total_agent_steps(self) -> int: + """Returns total number of (individual) agent-steps taken so far. + + Thereby, a step in an N-agent multi-agent environment counts as N. + If less than N agents have stepped (because some agents were not + required to send actions), the count will be increased by less than N. + The returned count contains everything that has not been built yet + (and returned as MultiAgentBatches by the + `try_build_truncated_episode_multi_agent_batch` or + `postprocess_episode(build=True)` methods). After such build, this + counter is reset to 0. + + Returns: + int: The number of (individual) agent-steps taken in total in the + environment(s) so far. """ raise NotImplementedError diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index a5ef0fc9f..efcadf32f 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -51,7 +51,7 @@ class _AgentCollector: self.episode_id = None # The simple timestep count for this agent. Gets increased by one # each time a (non-initial!) observation is added. 
- self.count = 0 + self.agent_steps = 0 def add_init_obs(self, episode_id: EpisodeID, agent_index: int, env_id: EnvID, t: int, init_obs: TensorType) -> None: @@ -105,7 +105,7 @@ class _AgentCollector: if k not in self.buffers: self._build_buffers(single_row=values) self.buffers[k].append(v) - self.count += 1 + self.agent_steps += 1 def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch: """Builds a SampleBatch from the thus-far collected agent data. @@ -183,7 +183,7 @@ class _AgentCollector: if self.shift_before > 0: for k, data in self.buffers.items(): self.buffers[k] = data[-self.shift_before:] - self.count = 0 + self.agent_steps = 0 return batch @@ -238,7 +238,7 @@ class _PolicyCollector: # NOTE: This is not an env-step count (across n agents). AgentA and # agentB, both using this policy, acting in the same episode and both # doing n steps would increase the count by 2*n. - self.count = 0 + self.agent_steps = 0 def add_postprocessed_batch_for_training( self, batch: SampleBatch, @@ -246,9 +246,9 @@ class _PolicyCollector: """Adds a postprocessed SampleBatch (single agent) to our buffers. Args: - batch (SampleBatch): A single agent (one trajectory) SampleBatch - to be added to the Policy's buffers. - view_requirements (DViewRequirementsDict): The view + batch (SampleBatch): An individual agent's (one trajectory) + SampleBatch to be added to the Policy's buffers. + view_requirements (ViewRequirementsDict): The view requirements for the policy. This is so we know, whether a view-column needs to be copied at all (not needed for training). @@ -261,7 +261,7 @@ class _PolicyCollector: view_requirements[view_col].used_for_training: self.buffers[view_col].extend(data) # Add the agent's trajectory length to our count. - self.count += batch.count + self.agent_steps += batch.count def build(self): """Builds a SampleBatch for this policy from the collected data. @@ -277,8 +277,8 @@ class _PolicyCollector: assert SampleBatch.UNROLL_ID in batch.data # Clear buffers for future samples. self.buffers.clear() - # Reset count to 0. - self.count = 0 + # Reset agent steps to 0. + self.agent_steps = 0 return batch @@ -288,7 +288,11 @@ class _PolicyCollectorGroup: pid: _PolicyCollector() for pid in policy_map.keys() } - self.count = 0 + # Total env-steps (1 env-step=up to N agents stepped). + self.env_steps = 0 + # Total agent steps (1 agent-step=1 individual agent (out of N) + # stepped). + self.agent_steps = 0 class _SimpleListCollector(_SampleCollector): @@ -305,7 +309,8 @@ class _SimpleListCollector(_SampleCollector): clip_rewards: Union[bool, float], callbacks: "DefaultCallbacks", multiple_episodes_in_batch: bool = True, - rollout_fragment_length: int = 200): + rollout_fragment_length: int = 200, + count_steps_by: str = "env_steps"): """Initializes a _SimpleListCollector instance. Args: @@ -314,6 +319,10 @@ class _SimpleListCollector(_SampleCollector): clip_rewards (Union[bool, float]): Whether to clip rewards before postprocessing (at +/-1.0) or the actual value to +/- clip. callbacks (DefaultCallbacks): RLlib callbacks. + multiple_episodes_in_batch (bool): Whether it's allowed to pack + multiple episodes into the same built batch. 
+ rollout_fragment_length (int): The + """ self.policy_map = policy_map @@ -321,6 +330,7 @@ class _SimpleListCollector(_SampleCollector): self.callbacks = callbacks self.multiple_episodes_in_batch = multiple_episodes_in_batch self.rollout_fragment_length = rollout_fragment_length + self.count_steps_by = count_steps_by self.large_batch_threshold: int = max( 1000, rollout_fragment_length * 10) if rollout_fragment_length != float("inf") else 5000 @@ -340,8 +350,10 @@ class _SimpleListCollector(_SampleCollector): self.forward_pass_size = {pid: 0 for pid in policy_map.keys()} # Maps episode ID to the (non-built) env steps taken in this episode. - self.episode_steps: Dict[EpisodeID, int] = \ - collections.defaultdict(int) + self.episode_steps: Dict[EpisodeID, int] = collections.defaultdict(int) + # Maps episode ID to the (non-built) individual agent steps in this + # episode. + self.agent_steps: Dict[EpisodeID, int] = collections.defaultdict(int) # Maps episode ID to MultiAgentEpisode. self.episodes: Dict[EpisodeID, MultiAgentEpisode] = {} @@ -351,15 +363,17 @@ class _SimpleListCollector(_SampleCollector): self.episode_steps[episode_id] += 1 episode.length += 1 assert episode.batch_builder is not None - env_steps = episode.batch_builder.count - num_observations = sum( - c.count for c in episode.batch_builder.policy_collectors.values()) + env_steps = episode.batch_builder.env_steps + num_individual_observations = sum( + c.agent_steps + for c in episode.batch_builder.policy_collectors.values()) - if num_observations > self.large_batch_threshold and \ + if num_individual_observations > self.large_batch_threshold and \ log_once("large_batch_warning"): logger.warning( "More than {} observations in {} env steps for " - "episode {} ".format(num_observations, env_steps, episode_id) + + "episode {} ".format(num_individual_observations, env_steps, + episode_id) + "are buffered in the sampler. If this is more than you " "expected, check that that you set a horizon on your " "environment correctly and that it terminates at some point. " @@ -412,6 +426,8 @@ class _SimpleListCollector(_SampleCollector): assert self.agent_key_to_policy_id[agent_key] == policy_id assert agent_key in self.agent_collectors + self.agent_steps[episode_id] += 1 + # Include the current agent id for multi-agent algorithms. if agent_id != _DUMMY_AGENT_ID: values["agent_id"] = agent_id @@ -424,7 +440,18 @@ class _SimpleListCollector(_SampleCollector): @override(_SampleCollector) def total_env_steps(self) -> int: - return sum(a.count for a in self.agent_collectors.values()) + # Add the non-built ongoing-episode env steps + the already built + # env-steps. + return sum(self.episode_steps.values()) + sum( + pg.env_steps for pg in self.policy_collector_groups.values()) + + @override(_SampleCollector) + def total_agent_steps(self) -> int: + # Add the non-built ongoing-episode agent steps (still in the agent + # collectors) + the already built agent steps. 
+ return sum(a.agent_steps for a in self.agent_collectors.values()) + \ + sum(pg.agent_steps for pg in + self.policy_collector_groups.values()) @override(_SampleCollector) def get_inference_input_dict(self, policy_id: PolicyID) -> \ @@ -463,11 +490,12 @@ class _SimpleListCollector(_SampleCollector): return input_dict @override(_SampleCollector) - def postprocess_episode(self, - episode: MultiAgentEpisode, - is_done: bool = False, - check_dones: bool = False, - build: bool = False) -> None: + def postprocess_episode( + self, + episode: MultiAgentEpisode, + is_done: bool = False, + check_dones: bool = False, + build: bool = False) -> Union[None, SampleBatch, MultiAgentBatch]: episode_id = episode.episode_id policy_collector_group = episode.batch_builder @@ -478,7 +506,7 @@ class _SimpleListCollector(_SampleCollector): pre_batches = {} for (eps_id, agent_id), collector in self.agent_collectors.items(): # Build only if there is data and agent is part of given episode. - if collector.count == 0 or eps_id != episode_id: + if collector.agent_steps == 0 or eps_id != episode_id: continue pid = self.agent_key_to_policy_id[(eps_id, agent_id)] policy = self.policy_map[pid] @@ -559,16 +587,19 @@ class _SimpleListCollector(_SampleCollector): post_batch, policy.view_requirements) env_steps = self.episode_steps[episode_id] - policy_collector_group.count += env_steps + policy_collector_group.env_steps += env_steps + agent_steps = self.agent_steps[episode_id] + policy_collector_group.agent_steps += agent_steps if is_done: del self.episode_steps[episode_id] + del self.agent_steps[episode_id] del self.episodes[episode_id] # Make PolicyCollectorGroup available for more agent batches in # other episodes. Do not reset count to 0. self.policy_collector_groups.append(policy_collector_group) else: - self.episode_steps[episode_id] = 0 + self.episode_steps[episode_id] = self.agent_steps[episode_id] = 0 # Build a MultiAgentBatch from the episode and return. if build: @@ -579,14 +610,15 @@ class _SimpleListCollector(_SampleCollector): ma_batch = {} for pid, collector in episode.batch_builder.policy_collectors.items(): - if collector.count > 0: + if collector.agent_steps > 0: ma_batch[pid] = collector.build() # Create the batch. ma_batch = MultiAgentBatch.wrap_as_needed( - ma_batch, env_steps=episode.batch_builder.count) + ma_batch, env_steps=episode.batch_builder.env_steps) # PolicyCollectorGroup is empty. - episode.batch_builder.count = 0 + episode.batch_builder.env_steps = 0 + episode.batch_builder.agent_steps = 0 return ma_batch @@ -595,16 +627,26 @@ class _SimpleListCollector(_SampleCollector): List[Union[MultiAgentBatch, SampleBatch]]: batches = [] # Loop through ongoing episodes and see whether their length plus - # what's already in the policy collectors reaches the fragment-len. + # what's already in the policy collectors reaches the fragment-len + # (abiding to the unit used: env-steps or agent-steps). for episode_id, episode in self.episodes.items(): - env_steps = episode.batch_builder.count + \ - self.episode_steps[episode_id] + # Measure batch size in env-steps. + if self.count_steps_by == "env_steps": + built_steps = episode.batch_builder.env_steps + ongoing_steps = self.episode_steps[episode_id] + # Measure batch-size in agent-steps. + else: + built_steps = episode.batch_builder.agent_steps + ongoing_steps = self.agent_steps[episode_id] + # Reached the fragment-len -> We should build an MA-Batch. 
- if env_steps >= self.rollout_fragment_length: - assert env_steps == self.rollout_fragment_length + if built_steps + ongoing_steps >= self.rollout_fragment_length: + if self.count_steps_by != "agent_steps": + assert built_steps + ongoing_steps == \ + self.rollout_fragment_length # If we reached the fragment-len only because of `episode_id` # (still ongoing) -> postprocess `episode_id` first. - if episode.batch_builder.count < self.rollout_fragment_length: + if built_steps < self.rollout_fragment_length: self.postprocess_episode(episode, is_done=False) # Build the MA-batch and return. batch = self._build_multi_agent_batch(episode=episode) diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index 84b1bb0b2..1579ea0b4 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -143,6 +143,7 @@ class RolloutWorker(ParallelIteratorWorker): policies_to_train: Optional[List[PolicyID]] = None, tf_session_creator: Optional[Callable[[], "tf1.Session"]] = None, rollout_fragment_length: int = 100, + count_steps_by: str = "env_steps", batch_mode: str = "truncate_episodes", episode_horizon: int = None, preprocessor_pref: str = "deepmind", @@ -208,8 +209,11 @@ class RolloutWorker(ParallelIteratorWorker): tf_session_creator (Optional[Callable[[], tf1.Session]]): A function that returns a TF session. This is optional and only useful with TFPolicy. - rollout_fragment_length (int): The target number of env transitions - to include in each sample batch returned from this worker. + rollout_fragment_length (int): The target number of steps + (maesured in `count_steps_by`) to include in each sample + batch returned from this worker. + count_steps_by (str): The unit in which to count fragment + lengths. One of env_steps or agent_steps. batch_mode (str): One of the following batch modes: "truncate_episodes": Each call to sample() will return a batch of at most `rollout_fragment_length * num_envs` in size. @@ -356,6 +360,7 @@ class RolloutWorker(ParallelIteratorWorker): raise ValueError("Policy mapping function not callable?") self.env_creator: Callable[[EnvContext], EnvType] = env_creator self.rollout_fragment_length: int = rollout_fragment_length * num_envs + self.count_steps_by: str = count_steps_by self.batch_mode: str = batch_mode self.compress_observations: bool = compress_observations self.preprocessing_enabled: bool = True @@ -570,6 +575,7 @@ class RolloutWorker(ParallelIteratorWorker): obs_filters=self.filters, clip_rewards=clip_rewards, rollout_fragment_length=rollout_fragment_length, + count_steps_by=count_steps_by, callbacks=self.callbacks, horizon=episode_horizon, multiple_episodes_in_batch=pack, @@ -593,6 +599,7 @@ class RolloutWorker(ParallelIteratorWorker): obs_filters=self.filters, clip_rewards=clip_rewards, rollout_fragment_length=rollout_fragment_length, + count_steps_by=count_steps_by, callbacks=self.callbacks, horizon=episode_horizon, multiple_episodes_in_batch=pack, @@ -636,7 +643,9 @@ class RolloutWorker(ParallelIteratorWorker): self.rollout_fragment_length)) batches = [self.input_reader.next()] - steps_so_far = batches[0].count + steps_so_far = batches[0].count if \ + self.count_steps_by == "env_steps" else \ + batches[0].agent_steps() # In truncate_episodes mode, never pull more than 1 batch per env. # This avoids over-running the target batch size. 
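
The `steps_so_far` accounting above is where the two counting units become concrete: `batch.count` is a batch's env-step size, while `MultiAgentBatch.agent_steps()` returns the total number of individual agent transitions (the sum over all policy batches). A toy illustration, with batch contents made up purely for this example:

    from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch

    # Pretend a 2-agent env was stepped 3 times and both agents acted each
    # time -> 3 env-steps, but 6 individual agent transitions.
    batch = MultiAgentBatch(
        policy_batches={
            "pol0": SampleBatch({"obs": [1, 2, 3]}),  # 3 transitions
            "pol1": SampleBatch({"obs": [4, 5, 6]}),  # 3 transitions
        },
        env_steps=3,
    )
    assert batch.env_steps() == 3    # == batch.count
    assert batch.agent_steps() == 6  # summed over all policy batches

`count_steps_by="env_steps"` keeps the pre-patch behavior (this batch counts as 3), whereas `count_steps_by="agent_steps"` counts it as 6 toward `rollout_fragment_length`.
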
@@ -648,7 +657,9 @@ class RolloutWorker(ParallelIteratorWorker): while (steps_so_far < self.rollout_fragment_length and len(batches) < max_batches): batch = self.input_reader.next() - steps_so_far += batch.count + steps_so_far += batch.count if \ + self.count_steps_by == "env_steps" else \ + batch.agent_steps() batches.append(batch) batch = batches[0].concat_samples(batches) if len(batches) > 1 else \ batches[0] diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index 5703f3b15..a115a0149 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -129,6 +129,7 @@ class SyncSampler(SamplerInput): obs_filters: Dict[PolicyID, Filter], clip_rewards: bool, rollout_fragment_length: int, + count_steps_by: str = "env_steps", callbacks: "DefaultCallbacks", horizon: int = None, multiple_episodes_in_batch: bool = False, @@ -190,8 +191,12 @@ class SyncSampler(SamplerInput): self.perf_stats = _PerfStats() if _use_trajectory_view_api: self.sample_collector = _SimpleListCollector( - policies, clip_rewards, callbacks, multiple_episodes_in_batch, - rollout_fragment_length) + policies, + clip_rewards, + callbacks, + multiple_episodes_in_batch, + rollout_fragment_length, + count_steps_by=count_steps_by) else: self.sample_collector = None @@ -254,6 +259,7 @@ class AsyncSampler(threading.Thread, SamplerInput): obs_filters: Dict[PolicyID, Filter], clip_rewards: bool, rollout_fragment_length: int, + count_steps_by: str = "env_steps", callbacks: "DefaultCallbacks", horizon: int = None, multiple_episodes_in_batch: bool = False, @@ -282,6 +288,8 @@ class AsyncSampler(threading.Thread, SamplerInput): rollout_fragment_length (int): The length of a fragment to collect before building a SampleBatch from the data and resetting the SampleBatchBuilder object. + count_steps_by (str): Either "env_steps" or "agent_steps". + Refers to the unit of `rollout_fragment_length`. callbacks (Callbacks): The Callbacks object to use when episode events happen during rollout. 
horizon (Optional[int]): Hard-reset the Env @@ -336,8 +344,12 @@ class AsyncSampler(threading.Thread, SamplerInput): self._use_trajectory_view_api = _use_trajectory_view_api if _use_trajectory_view_api: self.sample_collector = _SimpleListCollector( - policies, clip_rewards, callbacks, multiple_episodes_in_batch, - rollout_fragment_length) + policies, + clip_rewards, + callbacks, + multiple_episodes_in_batch, + rollout_fragment_length, + count_steps_by=count_steps_by) else: self.sample_collector = None diff --git a/rllib/evaluation/tests/__init__.py b/rllib/evaluation/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/rllib/tests/test_rollout_worker.py b/rllib/evaluation/tests/test_rollout_worker.py similarity index 90% rename from rllib/tests/test_rollout_worker.py rename to rllib/evaluation/tests/test_rollout_worker.py index 12b92ad10..8d45f5be6 100644 --- a/rllib/tests/test_rollout_worker.py +++ b/rllib/evaluation/tests/test_rollout_worker.py @@ -1,5 +1,6 @@ from collections import Counter import gym +from gym.spaces import Box, Discrete import numpy as np import os import random @@ -13,9 +14,12 @@ from ray.rllib.env.vector_env import VectorEnv from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.evaluation.postprocessing import compute_advantages +from ray.rllib.examples.env.mock_env import MockEnv, MockEnv2 +from ray.rllib.examples.env.multi_agent import MultiAgentCartPole from ray.rllib.examples.policy.random_policy import RandomPolicy from ray.rllib.policy.policy import Policy -from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, MultiAgentBatch, \ + SampleBatch from ray.rllib.utils.annotations import override from ray.rllib.utils.test_utils import check, framework_iterator from ray.tune.registry import register_env @@ -71,39 +75,6 @@ class FailOnStepEnv(gym.Env): raise ValueError("kaboom") -class MockEnv(gym.Env): - def __init__(self, episode_length, config=None): - self.episode_length = episode_length - self.config = config - self.i = 0 - self.observation_space = gym.spaces.Discrete(1) - self.action_space = gym.spaces.Discrete(2) - - def reset(self): - self.i = 0 - return self.i - - def step(self, action): - self.i += 1 - return 0, 1, self.i >= self.episode_length, {} - - -class MockEnv2(gym.Env): - def __init__(self, episode_length): - self.episode_length = episode_length - self.i = 0 - self.observation_space = gym.spaces.Discrete(100) - self.action_space = gym.spaces.Discrete(2) - - def reset(self): - self.i = 0 - return self.i - - def step(self, action): - self.i += 1 - return self.i, 100, self.i >= self.episode_length, {} - - class MockVectorEnv(VectorEnv): def __init__(self, episode_length, num_envs): super().__init__( @@ -523,14 +494,57 @@ class TestRolloutWorker(unittest.TestCase): ev.stop() def test_truncate_episodes(self): - ev = RolloutWorker( + ev_env_steps = RolloutWorker( env_creator=lambda _: MockEnv(10), policy_spec=MockPolicy, + policy_config={"_use_trajectory_view_api": True}, rollout_fragment_length=15, batch_mode="truncate_episodes") - batch = ev.sample() + batch = ev_env_steps.sample() self.assertEqual(batch.count, 15) - ev.stop() + self.assertTrue(isinstance(batch, SampleBatch)) + ev_env_steps.stop() + + action_space = Discrete(2) + obs_space = Box(float("-inf"), float("inf"), (4, ), dtype=np.float32) + ev_agent_steps = RolloutWorker( + env_creator=lambda _: 
MultiAgentCartPole({"num_agents": 4}), + policy_spec={ + "pol0": (MockPolicy, obs_space, action_space, {}), + "pol1": (MockPolicy, obs_space, action_space, {}), + }, + policy_config={"_use_trajectory_view_api": True}, + policy_mapping_fn=lambda ag: "pol0" if ag == 0 else "pol1", + rollout_fragment_length=301, + count_steps_by="env_steps", + batch_mode="truncate_episodes", + ) + batch = ev_agent_steps.sample() + self.assertTrue(isinstance(batch, MultiAgentBatch)) + self.assertGreater(batch.agent_steps(), 301) + self.assertEqual(batch.env_steps(), 301) + ev_agent_steps.stop() + + ev_agent_steps = RolloutWorker( + env_creator=lambda _: MultiAgentCartPole({"num_agents": 4}), + policy_spec={ + "pol0": (MockPolicy, obs_space, action_space, {}), + "pol1": (MockPolicy, obs_space, action_space, {}), + }, + policy_config={"_use_trajectory_view_api": True}, + policy_mapping_fn=lambda ag: "pol0" if ag == 0 else "pol1", + rollout_fragment_length=301, + count_steps_by="agent_steps", + batch_mode="truncate_episodes") + batch = ev_agent_steps.sample() + self.assertTrue(isinstance(batch, MultiAgentBatch)) + self.assertLess(batch.env_steps(), 301) + # When counting agent steps, the count may be slightly larger than + # rollout_fragment_length, b/c we have up to N agents stepping in each + # env step and we only check, whether we should build after each env + # step. + self.assertGreaterEqual(batch.agent_steps(), 301) + ev_agent_steps.stop() def test_complete_episodes(self): ev = RolloutWorker( diff --git a/rllib/evaluation/tests/test_trajectory_view_api.py b/rllib/evaluation/tests/test_trajectory_view_api.py index bd2488c47..a50978bfd 100644 --- a/rllib/evaluation/tests/test_trajectory_view_api.py +++ b/rllib/evaluation/tests/test_trajectory_view_api.py @@ -1,13 +1,16 @@ import copy import gym from gym.spaces import Box, Discrete +import numpy as np import time import unittest import ray +from ray import tune import ray.rllib.agents.dqn as dqn import ray.rllib.agents.ppo as ppo from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv +from ray.rllib.examples.env.multi_agent import MultiAgentCartPole from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.examples.policy.episode_env_aware_policy import \ EpisodeEnvAwareLSTMPolicy @@ -295,6 +298,38 @@ class TestTrajectoryViewAPI(unittest.TestCase): pol_batch_wo = result.policy_batches["pol0"] check(pol_batch_w.data, pol_batch_wo.data) + def test_counting_by_agent_steps(self): + """Test whether a PPOTrainer can be built with all frameworks.""" + config = copy.deepcopy(ppo.DEFAULT_CONFIG) + action_space = Discrete(2) + obs_space = Box(float("-inf"), float("inf"), (4, ), dtype=np.float32) + + config["num_workers"] = 2 + config["num_sgd_iter"] = 2 + config["framework"] = "torch" + config["rollout_fragment_length"] = 21 + config["train_batch_size"] = 147 + config["multiagent"] = { + "policies": { + "p0": (None, obs_space, action_space, {}), + "p1": (None, obs_space, action_space, {}), + }, + "policy_mapping_fn": lambda aid: "p{}".format(aid), + "count_steps_by": "agent_steps", + } + tune.register_env( + "ma_cartpole", lambda _: MultiAgentCartPole({"num_agents": 2})) + num_iterations = 2 + trainer = ppo.PPOTrainer(config=config, env="ma_cartpole") + results = None + for i in range(num_iterations): + results = trainer.train() + self.assertGreater(results["timesteps_total"], + num_iterations * config["train_batch_size"]) + self.assertLess(results["timesteps_total"], + (num_iterations + 1) * config["train_batch_size"]) + 
trainer.stop() + def analyze_rnn_batch(batch, max_seq_len): count = batch.count diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py index 17cc14af3..16626fb86 100644 --- a/rllib/evaluation/worker_set.py +++ b/rllib/evaluation/worker_set.py @@ -321,6 +321,7 @@ class WorkerSet: tf_session_creator=(session_creator if config["tf_session_args"] else None), rollout_fragment_length=config["rollout_fragment_length"], + count_steps_by=config["multiagent"]["count_steps_by"], batch_mode=config["batch_mode"], episode_horizon=config["horizon"], preprocessor_pref=config["preprocessor_pref"], diff --git a/rllib/examples/env/mock_env.py b/rllib/examples/env/mock_env.py new file mode 100644 index 000000000..8ddbb9b69 --- /dev/null +++ b/rllib/examples/env/mock_env.py @@ -0,0 +1,46 @@ +import gym + + +class MockEnv(gym.Env): + """Mock environment for testing purposes. + + Observation=0, reward=1.0, episode-len is configurable. + Actions are ignored. + """ + + def __init__(self, episode_length, config=None): + self.episode_length = episode_length + self.config = config + self.i = 0 + self.observation_space = gym.spaces.Discrete(1) + self.action_space = gym.spaces.Discrete(2) + + def reset(self): + self.i = 0 + return 0 + + def step(self, action): + self.i += 1 + return 0, 1.0, self.i >= self.episode_length, {} + + +class MockEnv2(gym.Env): + """Mock environment for testing purposes. + + Observation=ts (discrete space!), reward=100.0, episode-len is + configurable. Actions are ignored. + """ + + def __init__(self, episode_length): + self.episode_length = episode_length + self.i = 0 + self.observation_space = gym.spaces.Discrete(100) + self.action_space = gym.spaces.Discrete(2) + + def reset(self): + self.i = 0 + return self.i + + def step(self, action): + self.i += 1 + return self.i, 100.0, self.i >= self.episode_length, {} diff --git a/rllib/examples/env/multi_agent.py b/rllib/examples/env/multi_agent.py index 5d4ffe863..096dea205 100644 --- a/rllib/examples/env/multi_agent.py +++ b/rllib/examples/env/multi_agent.py @@ -1,8 +1,8 @@ import gym from ray.rllib.env.multi_agent_env import MultiAgentEnv +from ray.rllib.examples.env.mock_env import MockEnv, MockEnv2 from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole -from ray.rllib.tests.test_rollout_worker import MockEnv, MockEnv2 def make_multiagent(env_name_or_creator): diff --git a/rllib/examples/two_trainer_workflow.py b/rllib/examples/two_trainer_workflow.py index 89e05665f..a87a92405 100644 --- a/rllib/examples/two_trainer_workflow.py +++ b/rllib/examples/two_trainer_workflow.py @@ -81,7 +81,8 @@ def custom_training_workflow(workers: WorkerSet, config: dict): # PPO sub-flow. ppo_train_op = r2.for_each(SelectExperiences(["ppo_policy"])) \ - .combine(ConcatBatches(min_batch_size=200)) \ + .combine(ConcatBatches( + min_batch_size=200, count_steps_by="env_steps")) \ .for_each(add_ppo_metrics) \ .for_each(StandardizeFields(["advantages"])) \ .for_each(TrainOneStep( diff --git a/rllib/execution/rollout_ops.py b/rllib/execution/rollout_ops.py index 818254c2d..baaa26357 100644 --- a/rllib/execution/rollout_ops.py +++ b/rllib/execution/rollout_ops.py @@ -141,13 +141,15 @@ class ConcatBatches: Examples: >>> rollouts = ParallelRollouts(...) - >>> rollouts = rollouts.combine(ConcatBatches(min_batch_size=10000)) + >>> rollouts = rollouts.combine(ConcatBatches( + ... 
min_batch_size=10000, count_steps_by="env_steps")) >>> print(next(rollouts).count) 10000 """ - def __init__(self, min_batch_size: int): + def __init__(self, min_batch_size: int, count_steps_by: str = "env_steps"): self.min_batch_size = min_batch_size + self.count_steps_by = count_steps_by self.buffer = [] self.count = 0 self.batch_start_time = None @@ -159,7 +161,15 @@ class ConcatBatches: def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]: _check_sample_batch_type(batch) self.buffer.append(batch) - self.count += batch.count + + if self.count_steps_by == "env_steps": + self.count += batch.count + else: + assert isinstance(batch, MultiAgentBatch), \ + "`count_steps_by=agent_steps` only allowed in multi-agent " \ + "environments!" + self.count += batch.agent_steps() + if self.count >= self.min_batch_size: if self.count > self.min_batch_size * 2: logger.info("Collected more training samples than expected " diff --git a/rllib/execution/tree_agg.py b/rllib/execution/tree_agg.py index 344a22e20..69e06a4b0 100644 --- a/rllib/execution/tree_agg.py +++ b/rllib/execution/tree_agg.py @@ -51,7 +51,9 @@ class Aggregator(ParallelIteratorWorker): .flatten() \ .combine( ConcatBatches( - min_batch_size=config["train_batch_size"])) + min_batch_size=config["train_batch_size"], + count_steps_by=config["multiagent"]["count_steps_by"], + )) for train_batch in it: yield train_batch diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index dc36271a8..a2934fdb9 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -417,16 +417,17 @@ class MultiAgentBatch: Args: policy_batches (Dict[PolicyID, SampleBatch]): Mapping from policy ids to SampleBatches of experiences. - env_steps (int): The number of timesteps in the environment this - batch contains. This will be less than the number of + env_steps (int): The number of environment steps in the environment + this batch contains. This will be less than the number of transitions this batch contains across all policies in total. """ for v in policy_batches.values(): assert isinstance(v, SampleBatch) self.policy_batches = policy_batches - # Called count for uniformity with SampleBatch. Prefer to access this - # via the env_steps() method when possible for clarity. + # Called "count" for uniformity with SampleBatch. + # Prefer to access this via the `env_steps()` method when possible + # for clarity. 
self.count = env_steps @PublicAPI @@ -526,7 +527,8 @@ class MultiAgentBatch: """ if len(policy_batches) == 1 and DEFAULT_POLICY_ID in policy_batches: return policy_batches[DEFAULT_POLICY_ID] - return MultiAgentBatch(policy_batches, env_steps) + return MultiAgentBatch( + policy_batches=policy_batches, env_steps=env_steps) @staticmethod @PublicAPI diff --git a/rllib/tests/test_external_env.py b/rllib/tests/test_external_env.py index 681d719ac..d35e5003d 100644 --- a/rllib/tests/test_external_env.py +++ b/rllib/tests/test_external_env.py @@ -9,8 +9,9 @@ from ray.rllib.agents.dqn import DQNTrainer from ray.rllib.agents.pg import PGTrainer from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.env.external_env import ExternalEnv -from ray.rllib.tests.test_rollout_worker import (BadPolicy, MockPolicy, - MockEnv) +from ray.rllib.evaluation.tests.test_rollout_worker import (BadPolicy, + MockPolicy) +from ray.rllib.examples.env.mock_env import MockEnv from ray.rllib.utils.test_utils import framework_iterator from ray.tune.registry import register_env diff --git a/rllib/tests/test_external_multi_agent_env.py b/rllib/tests/test_external_multi_agent_env.py index fe34f13fe..1f1e56cc0 100644 --- a/rllib/tests/test_external_multi_agent_env.py +++ b/rllib/tests/test_external_multi_agent_env.py @@ -5,8 +5,8 @@ import unittest import ray from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv from ray.rllib.evaluation.rollout_worker import RolloutWorker +from ray.rllib.evaluation.tests.test_rollout_worker import MockPolicy from ray.rllib.examples.env.multi_agent import BasicMultiAgent -from ray.rllib.tests.test_rollout_worker import MockPolicy from ray.rllib.tests.test_external_env import make_simple_serving SimpleMultiServing = make_simple_serving(True, ExternalMultiAgentEnv) diff --git a/rllib/tests/test_multi_agent_env.py b/rllib/tests/test_multi_agent_env.py index 617aca620..a2c3dddff 100644 --- a/rllib/tests/test_multi_agent_env.py +++ b/rllib/tests/test_multi_agent_env.py @@ -12,8 +12,8 @@ from ray.rllib.evaluation.rollout_worker import get_global_worker from ray.rllib.examples.policy.random_policy import RandomPolicy from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \ BasicMultiAgent, EarlyDoneMultiAgent, RoundRobinMultiAgent -from ray.rllib.tests.test_rollout_worker import MockPolicy from ray.rllib.evaluation.rollout_worker import RolloutWorker +from ray.rllib.evaluation.tests.test_rollout_worker import MockPolicy from ray.rllib.env.base_env import _MultiAgentEnvToBaseEnv from ray.rllib.utils.numpy import one_hot from ray.rllib.utils.test_utils import check diff --git a/rllib/tests/test_perf.py b/rllib/tests/test_perf.py index 90148b043..4a4d7bdad 100644 --- a/rllib/tests/test_perf.py +++ b/rllib/tests/test_perf.py @@ -4,7 +4,7 @@ import unittest import ray from ray.rllib.evaluation.rollout_worker import RolloutWorker -from ray.rllib.tests.test_rollout_worker import MockPolicy +from ray.rllib.evaluation.tests.test_rollout_worker import MockPolicy class TestPerf(unittest.TestCase):
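
Taken together, the execution-plan side of this patch reduces to one recurring pattern: thread `config["multiagent"]["count_steps_by"]` into every `ConcatBatches` operator so that the train batch is assembled in the same unit as the rollout fragments. A condensed sketch of that pattern, assuming `workers` and `config` come from the usual trainer setup (it simply mirrors the `default_execution_plan` change above and adds nothing new):

    from ray.rllib.execution.metric_ops import StandardMetricsReporting
    from ray.rllib.execution.rollout_ops import ConcatBatches, ParallelRollouts
    from ray.rllib.execution.train_ops import TrainOneStep


    def execution_plan(workers, config):
        rollouts = ParallelRollouts(workers, mode="bulk_sync")

        # `count_steps_by` decides whether `min_batch_size` is considered
        # reached after `train_batch_size` env-steps or after that many
        # individual agent-steps.
        train_op = rollouts.combine(
            ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"]["count_steps_by"],
            )).for_each(TrainOneStep(workers))

        return StandardMetricsReporting(train_op, workers, config)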