From c3a15ecc0ff26ae057415b04b48b9611b086ed25 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Thu, 18 Mar 2021 20:27:41 +0100 Subject: [PATCH] [RLlib] Issue #13802: Enhance metrics for `multiagent->count_steps_by=agent_steps` setting. (#14033) --- rllib/agents/impala/tests/test_impala.py | 33 ++++++++----------- rllib/evaluation/tests/test_rollout_worker.py | 10 +++--- .../tests/test_trajectory_view_api.py | 4 +-- rllib/execution/common.py | 2 ++ rllib/execution/metric_ops.py | 6 ++-- rllib/execution/rollout_ops.py | 11 +++++-- rllib/execution/train_ops.py | 14 +++++--- rllib/execution/tree_agg.py | 14 +++++--- rllib/tests/test_exec_api.py | 6 ++-- rllib/tests/test_execution.py | 14 ++++---- 10 files changed, 66 insertions(+), 48 deletions(-) diff --git a/rllib/agents/impala/tests/test_impala.py b/rllib/agents/impala/tests/test_impala.py index 7d724b204..adf85f1a1 100644 --- a/rllib/agents/impala/tests/test_impala.py +++ b/rllib/agents/impala/tests/test_impala.py @@ -22,35 +22,28 @@ class TestIMPALA(unittest.TestCase): def test_impala_compilation(self): """Test whether an ImpalaTrainer can be built with both frameworks.""" config = impala.DEFAULT_CONFIG.copy() + config["model"]["lstm_use_prev_action"] = True + config["model"]["lstm_use_prev_reward"] = True num_iterations = 1 + env = "CartPole-v0" for _ in framework_iterator(config): local_cfg = config.copy() - for env in ["Pendulum-v0", "CartPole-v0"]: - print("Env={}".format(env)) - print("w/o LSTM") - # Test w/o LSTM. - local_cfg["model"]["use_lstm"] = False - local_cfg["num_aggregation_workers"] = 0 - trainer = impala.ImpalaTrainer(config=local_cfg, env=env) - for i in range(num_iterations): - print(trainer.train()) - check_compute_single_action(trainer) - trainer.stop() - - # Test w/ LSTM. - print("w/ LSTM") - local_cfg["model"]["use_lstm"] = True - local_cfg["model"]["lstm_use_prev_action"] = True - local_cfg["model"]["lstm_use_prev_reward"] = True - local_cfg["num_aggregation_workers"] = 1 + for lstm in [False, True]: + local_cfg["num_aggregation_workers"] = 0 if not lstm else 1 + local_cfg["model"]["use_lstm"] = lstm + print("lstm={} aggregation-worker={}".format( + lstm, local_cfg["num_aggregation_workers"])) + # Test with and w/o aggregation workers (this has nothing + # to do with LSTMs, though). 
trainer = impala.ImpalaTrainer(config=local_cfg, env=env) for i in range(num_iterations): print(trainer.train()) check_compute_single_action( trainer, - include_state=True, - include_prev_action_reward=True) + include_state=lstm, + include_prev_action_reward=lstm, + ) trainer.stop() def test_impala_lr_schedule(self): diff --git a/rllib/evaluation/tests/test_rollout_worker.py b/rllib/evaluation/tests/test_rollout_worker.py index 8d45f5be6..088f23e28 100644 --- a/rllib/evaluation/tests/test_rollout_worker.py +++ b/rllib/evaluation/tests/test_rollout_worker.py @@ -17,6 +17,8 @@ from ray.rllib.evaluation.postprocessing import compute_advantages from ray.rllib.examples.env.mock_env import MockEnv, MockEnv2 from ray.rllib.examples.env.multi_agent import MultiAgentCartPole from ray.rllib.examples.policy.random_policy import RandomPolicy +from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \ + STEPS_TRAINED_COUNTER from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, MultiAgentBatch, \ SampleBatch @@ -170,10 +172,10 @@ class TestRolloutWorker(unittest.TestCase): policy = agent.get_policy() for i in range(3): result = agent.train() - print("num_steps_trained={}".format( - result["info"]["num_steps_trained"])) - print("num_steps_sampled={}".format( - result["info"]["num_steps_sampled"])) + print("{}={}".format(STEPS_TRAINED_COUNTER, + result["info"][STEPS_TRAINED_COUNTER])) + print("{}={}".format(STEPS_SAMPLED_COUNTER, + result["info"][STEPS_SAMPLED_COUNTER])) global_timesteps = policy.global_timestep print("global_timesteps={}".format(global_timesteps)) expected_lr = \ diff --git a/rllib/evaluation/tests/test_trajectory_view_api.py b/rllib/evaluation/tests/test_trajectory_view_api.py index 1c56ef2b9..19e8de945 100644 --- a/rllib/evaluation/tests/test_trajectory_view_api.py +++ b/rllib/evaluation/tests/test_trajectory_view_api.py @@ -419,9 +419,9 @@ class TestTrajectoryViewAPI(unittest.TestCase): results = None for i in range(num_iterations): results = trainer.train() - self.assertGreater(results["timesteps_total"], + self.assertGreater(results["agent_timesteps_total"], num_iterations * config["train_batch_size"]) - self.assertLess(results["timesteps_total"], + self.assertLess(results["agent_timesteps_total"], (num_iterations + 1) * config["train_batch_size"]) trainer.stop() diff --git a/rllib/execution/common.py b/rllib/execution/common.py index b12e557d5..3349541da 100644 --- a/rllib/execution/common.py +++ b/rllib/execution/common.py @@ -5,7 +5,9 @@ from ray.util.iter_metrics import MetricsContext # Counters for training progress (keys for metrics.counters). STEPS_SAMPLED_COUNTER = "num_steps_sampled" +AGENT_STEPS_SAMPLED_COUNTER = "num_agent_steps_sampled" STEPS_TRAINED_COUNTER = "num_steps_trained" +AGENT_STEPS_TRAINED_COUNTER = "num_agent_steps_trained" # Counters to track target network updates. 
LAST_TARGET_UPDATE_TS = "last_target_update_ts" diff --git a/rllib/execution/metric_ops.py b/rllib/execution/metric_ops.py index 06857f674..7148d8350 100644 --- a/rllib/execution/metric_ops.py +++ b/rllib/execution/metric_ops.py @@ -3,8 +3,8 @@ import time from ray.util.iter import LocalIterator from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes -from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \ - _get_shared_metrics +from ray.rllib.execution.common import AGENT_STEPS_SAMPLED_COUNTER, \ + STEPS_SAMPLED_COUNTER, _get_shared_metrics from ray.rllib.evaluation.worker_set import WorkerSet @@ -103,6 +103,8 @@ class CollectMetrics: res.update({ "num_healthy_workers": len(self.workers.remote_workers()), "timesteps_total": metrics.counters[STEPS_SAMPLED_COUNTER], + "agent_timesteps_total": metrics.counters.get( + AGENT_STEPS_SAMPLED_COUNTER, 0), }) res["timers"] = timers res["info"] = info diff --git a/rllib/execution/rollout_ops.py b/rllib/execution/rollout_ops.py index baaa26357..46d6022ae 100644 --- a/rllib/execution/rollout_ops.py +++ b/rllib/execution/rollout_ops.py @@ -7,9 +7,9 @@ from ray.util.iter_metrics import SharedMetrics from ray.rllib.evaluation.metrics import get_learner_stats from ray.rllib.evaluation.rollout_worker import get_global_worker from ray.rllib.evaluation.worker_set import WorkerSet -from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, LEARNER_INFO, \ - SAMPLE_TIMER, GRAD_WAIT_TIMER, _check_sample_batch_type, \ - _get_shared_metrics +from ray.rllib.execution.common import AGENT_STEPS_SAMPLED_COUNTER, \ + STEPS_SAMPLED_COUNTER, LEARNER_INFO, SAMPLE_TIMER, GRAD_WAIT_TIMER, \ + _check_sample_batch_type, _get_shared_metrics from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch from ray.rllib.utils.sgd import standardized @@ -60,6 +60,11 @@ def ParallelRollouts(workers: WorkerSet, *, mode="bulk_sync", def report_timesteps(batch): metrics = _get_shared_metrics() metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count + if isinstance(batch, MultiAgentBatch): + metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += \ + batch.agent_steps() + else: + metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += batch.count return batch if not workers.remote_workers(): diff --git a/rllib/execution/train_ops.py b/rllib/execution/train_ops.py index 7c5d005de..9b3314e62 100644 --- a/rllib/execution/train_ops.py +++ b/rllib/execution/train_ops.py @@ -9,11 +9,11 @@ from ray.rllib.evaluation.metrics import extract_stats, get_learner_stats, \ LEARNER_STATS_KEY from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.execution.common import \ - STEPS_SAMPLED_COUNTER, STEPS_TRAINED_COUNTER, LEARNER_INFO, \ - APPLY_GRADS_TIMER, COMPUTE_GRADS_TIMER, WORKER_UPDATE_TIMER, \ - LEARN_ON_BATCH_TIMER, LOAD_BATCH_TIMER, LAST_TARGET_UPDATE_TS, \ - NUM_TARGET_UPDATES, _get_global_vars, _check_sample_batch_type, \ - _get_shared_metrics + AGENT_STEPS_TRAINED_COUNTER, APPLY_GRADS_TIMER, COMPUTE_GRADS_TIMER, \ + LAST_TARGET_UPDATE_TS, LEARNER_INFO, LEARN_ON_BATCH_TIMER, \ + LOAD_BATCH_TIMER, NUM_TARGET_UPDATES, STEPS_SAMPLED_COUNTER, \ + STEPS_TRAINED_COUNTER, WORKER_UPDATE_TIMER, _check_sample_batch_type, \ + _get_global_vars, _get_shared_metrics from ray.rllib.execution.multi_gpu_impl import LocalSyncParallelOptimizer from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch @@ -76,6 +76,9 @@ class TrainOneStep: info, "custom_metrics") learn_timer.push_units_processed(batch.count) 
metrics.counters[STEPS_TRAINED_COUNTER] += batch.count + if isinstance(batch, MultiAgentBatch): + metrics.counters[ + AGENT_STEPS_TRAINED_COUNTER] += batch.agent_steps() # Update weights - after learning on the local worker - on all remote # workers. if self.workers.remote_workers(): @@ -236,6 +239,7 @@ class TrainTFMultiGPU: learn_timer.push_units_processed(samples.count) metrics.counters[STEPS_TRAINED_COUNTER] += samples.count + metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += samples.agent_steps() metrics.info[LEARNER_INFO] = fetches if self.workers.remote_workers(): with metrics.timers[WORKER_UPDATE_TIMER]: diff --git a/rllib/execution/tree_agg.py b/rllib/execution/tree_agg.py index b04bee783..92add2efc 100644 --- a/rllib/execution/tree_agg.py +++ b/rllib/execution/tree_agg.py @@ -3,15 +3,16 @@ import platform from typing import List, Dict, Any import ray -from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \ - _get_shared_metrics +from ray.rllib.evaluation.worker_set import WorkerSet +from ray.rllib.execution.common import AGENT_STEPS_SAMPLED_COUNTER, \ + STEPS_SAMPLED_COUNTER, _get_shared_metrics from ray.rllib.execution.replay_ops import MixInReplay from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches +from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.utils.actors import create_colocated +from ray.rllib.utils.typing import SampleBatchType, ModelWeights from ray.util.iter import ParallelIterator, ParallelIteratorWorker, \ from_actors, LocalIterator -from ray.rllib.utils.typing import SampleBatchType, ModelWeights -from ray.rllib.evaluation.worker_set import WorkerSet logger = logging.getLogger(__name__) @@ -100,6 +101,11 @@ def gather_experiences_tree_aggregation(workers: WorkerSet, def record_steps_sampled(batch): metrics = _get_shared_metrics() metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count + if isinstance(batch, MultiAgentBatch): + metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += \ + batch.agent_steps() + else: + metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += batch.count return batch return train_batches.gather_async().for_each(record_steps_sampled) diff --git a/rllib/tests/test_exec_api.py b/rllib/tests/test_exec_api.py index df9a5015e..b415c4faa 100644 --- a/rllib/tests/test_exec_api.py +++ b/rllib/tests/test_exec_api.py @@ -2,6 +2,8 @@ import unittest import ray from ray.rllib.agents.a3c import A2CTrainer +from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \ + STEPS_TRAINED_COUNTER from ray.rllib.utils.test_utils import framework_iterator @@ -28,8 +30,8 @@ class TestDistributedExecution(unittest.TestCase): assert isinstance(result, dict) assert "info" in result assert "learner" in result["info"] - assert "num_steps_sampled" in result["info"] - assert "num_steps_trained" in result["info"] + assert STEPS_SAMPLED_COUNTER in result["info"] + assert STEPS_TRAINED_COUNTER in result["info"] assert "timers" in result assert "learn_time_ms" in result["timers"] assert "learn_throughput" in result["timers"] diff --git a/rllib/tests/test_execution.py b/rllib/tests/test_execution.py index d97d12d40..d99fd27ae 100644 --- a/rllib/tests/test_execution.py +++ b/rllib/tests/test_execution.py @@ -7,6 +7,8 @@ import ray from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.evaluation.rollout_worker import RolloutWorker +from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \ + STEPS_TRAINED_COUNTER from ray.rllib.execution.concurrency_ops 
import Concurrently, Enqueue, Dequeue from ray.rllib.execution.metric_ops import StandardMetricsReporting from ray.rllib.execution.replay_ops import StoreToReplayBuffer, Replay @@ -122,11 +124,11 @@ def test_rollouts(ray_start_regular_shared): a = ParallelRollouts(workers, mode="bulk_sync") assert next(a).count == 200 counters = a.shared_metrics.get().counters - assert counters["num_steps_sampled"] == 200, counters + assert counters[STEPS_SAMPLED_COUNTER] == 200, counters a = ParallelRollouts(workers, mode="async") assert next(a).count == 100 counters = a.shared_metrics.get().counters - assert counters["num_steps_sampled"] == 100, counters + assert counters[STEPS_SAMPLED_COUNTER] == 100, counters workers.stop() @@ -135,7 +137,7 @@ def test_rollouts_local(ray_start_regular_shared): a = ParallelRollouts(workers, mode="bulk_sync") assert next(a).count == 100 counters = a.shared_metrics.get().counters - assert counters["num_steps_sampled"] == 100, counters + assert counters[STEPS_SAMPLED_COUNTER] == 100, counters workers.stop() @@ -163,7 +165,7 @@ def test_async_grads(ray_start_regular_shared): res1 = next(a) assert isinstance(res1, tuple) and len(res1) == 2, res1 counters = a.shared_metrics.get().counters - assert counters["num_steps_sampled"] == 100, counters + assert counters[STEPS_SAMPLED_COUNTER] == 100, counters workers.stop() @@ -176,8 +178,8 @@ def test_train_one_step(ray_start_regular_shared): assert DEFAULT_POLICY_ID in stats assert "learner_stats" in stats[DEFAULT_POLICY_ID] counters = a.shared_metrics.get().counters - assert counters["num_steps_sampled"] == 100, counters - assert counters["num_steps_trained"] == 100, counters + assert counters[STEPS_SAMPLED_COUNTER] == 100, counters + assert counters[STEPS_TRAINED_COUNTER] == 100, counters timers = a.shared_metrics.get().timers assert "learn" in timers workers.stop()
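
Usage sketch (not part of the patch itself): a minimal illustration of how the new agent-step metrics could be inspected once this change is applied, assuming a multi-agent run with `multiagent->count_steps_by=agent_steps` as referenced in the subject line. The PGTrainer / MultiAgentCartPole combination and the exact config values below are only illustrative choices, not something this patch prescribes.

    import ray
    from ray.rllib.agents.pg import PGTrainer
    from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
    from ray.rllib.execution.common import AGENT_STEPS_SAMPLED_COUNTER, \
        AGENT_STEPS_TRAINED_COUNTER, STEPS_SAMPLED_COUNTER, \
        STEPS_TRAINED_COUNTER

    ray.init()
    config = {
        "env_config": {"num_agents": 2},
        # Count the train batch in agent steps rather than env steps.
        "multiagent": {"count_steps_by": "agent_steps"},
        "framework": "tf",
    }
    trainer = PGTrainer(config=config, env=MultiAgentCartPole)
    result = trainer.train()

    # Env-step based counters (behavior unchanged by this patch).
    print(result["info"][STEPS_SAMPLED_COUNTER])
    print(result["info"][STEPS_TRAINED_COUNTER])

    # New top-level result key added by this patch in CollectMetrics.
    print(result["agent_timesteps_total"])

    # New agent-step counters; assumption: they appear next to the existing
    # counters under result["info"], since CollectMetrics copies all of
    # metrics.counters into the result dict.
    print(result["info"].get(AGENT_STEPS_SAMPLED_COUNTER))
    print(result["info"].get(AGENT_STEPS_TRAINED_COUNTER))

    trainer.stop()

For single-agent batches (plain SampleBatch), the new agent-step counters simply track `batch.count`, so they should equal the env-step counters in that case.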