[RLlib] Issue #13802: Enhance metrics for multiagent->count_steps_by=agent_steps setting. (#14033)

Author: Sven Mika (committed via GitHub)
Date:   2021-03-18 20:27:41 +01:00
Commit: c3a15ecc0f
Parent: 1d033fb552
10 changed files with 66 additions and 48 deletions
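
Note: with the multiagent config switch named in the title, RLlib counts one step per acting agent instead of one per env step; this commit adds matching num_agent_steps_sampled / num_agent_steps_trained counters and an agent_timesteps_total result field. A minimal sketch of the setting (the PPO trainer and env choice are illustrative, not part of this commit):

    import ray.rllib.agents.ppo as ppo

    config = ppo.DEFAULT_CONFIG.copy()
    # "count_steps_by" accepts "env_steps" (default) or "agent_steps".
    config["multiagent"]["count_steps_by"] = "agent_steps"
    trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")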


@@ -22,35 +22,28 @@ class TestIMPALA(unittest.TestCase):
     def test_impala_compilation(self):
         """Test whether an ImpalaTrainer can be built with both frameworks."""
         config = impala.DEFAULT_CONFIG.copy()
+        config["model"]["lstm_use_prev_action"] = True
+        config["model"]["lstm_use_prev_reward"] = True
         num_iterations = 1
+        env = "CartPole-v0"

         for _ in framework_iterator(config):
             local_cfg = config.copy()
-            for env in ["Pendulum-v0", "CartPole-v0"]:
-                print("Env={}".format(env))
-                print("w/o LSTM")
-                # Test w/o LSTM.
-                local_cfg["model"]["use_lstm"] = False
-                local_cfg["num_aggregation_workers"] = 0
-                trainer = impala.ImpalaTrainer(config=local_cfg, env=env)
-                for i in range(num_iterations):
-                    print(trainer.train())
-                check_compute_single_action(trainer)
-                trainer.stop()
-
-                # Test w/ LSTM.
-                print("w/ LSTM")
-                local_cfg["model"]["use_lstm"] = True
-                local_cfg["model"]["lstm_use_prev_action"] = True
-                local_cfg["model"]["lstm_use_prev_reward"] = True
-                local_cfg["num_aggregation_workers"] = 1
+            for lstm in [False, True]:
+                local_cfg["num_aggregation_workers"] = 0 if not lstm else 1
+                local_cfg["model"]["use_lstm"] = lstm
+                print("lstm={} aggregation-worker={}".format(
+                    lstm, local_cfg["num_aggregation_workers"]))
+                # Test with and w/o aggregation workers (this has nothing
+                # to do with LSTMs, though).
                 trainer = impala.ImpalaTrainer(config=local_cfg, env=env)
                 for i in range(num_iterations):
                     print(trainer.train())
                 check_compute_single_action(
                     trainer,
-                    include_state=True,
-                    include_prev_action_reward=True)
+                    include_state=lstm,
+                    include_prev_action_reward=lstm,
+                )
                 trainer.stop()

     def test_impala_lr_schedule(self):


@@ -17,6 +17,8 @@ from ray.rllib.evaluation.postprocessing import compute_advantages
 from ray.rllib.examples.env.mock_env import MockEnv, MockEnv2
 from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
 from ray.rllib.examples.policy.random_policy import RandomPolicy
+from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
+    STEPS_TRAINED_COUNTER
 from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, MultiAgentBatch, \
     SampleBatch

@@ -170,10 +172,10 @@ class TestRolloutWorker(unittest.TestCase):
         policy = agent.get_policy()
         for i in range(3):
             result = agent.train()
-            print("num_steps_trained={}".format(
-                result["info"]["num_steps_trained"]))
-            print("num_steps_sampled={}".format(
-                result["info"]["num_steps_sampled"]))
+            print("{}={}".format(STEPS_TRAINED_COUNTER,
+                                 result["info"][STEPS_TRAINED_COUNTER]))
+            print("{}={}".format(STEPS_SAMPLED_COUNTER,
+                                 result["info"][STEPS_SAMPLED_COUNTER]))
             global_timesteps = policy.global_timestep
             print("global_timesteps={}".format(global_timesteps))
             expected_lr = \


@@ -419,9 +419,9 @@ class TestTrajectoryViewAPI(unittest.TestCase):
         results = None
         for i in range(num_iterations):
             results = trainer.train()
-        self.assertGreater(results["timesteps_total"],
+        self.assertGreater(results["agent_timesteps_total"],
                            num_iterations * config["train_batch_size"])
-        self.assertLess(results["timesteps_total"],
+        self.assertLess(results["agent_timesteps_total"],
                         (num_iterations + 1) * config["train_batch_size"])
         trainer.stop()
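
Note: the switch from timesteps_total to agent_timesteps_total above matches the semantics of count_steps_by="agent_steps": train_batch_size is then measured in agent steps, so the agent-step total is what should land within one batch of num_iterations * train_batch_size. A sketch of the assumed config under test (values illustrative):

    config["multiagent"]["count_steps_by"] = "agent_steps"
    config["train_batch_size"] = 200  # interpreted as agent steps here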


@@ -5,7 +5,9 @@ from ray.util.iter_metrics import MetricsContext

 # Counters for training progress (keys for metrics.counters).
 STEPS_SAMPLED_COUNTER = "num_steps_sampled"
+AGENT_STEPS_SAMPLED_COUNTER = "num_agent_steps_sampled"
 STEPS_TRAINED_COUNTER = "num_steps_trained"
+AGENT_STEPS_TRAINED_COUNTER = "num_agent_steps_trained"

 # Counters to track target network updates.
 LAST_TARGET_UPDATE_TS = "last_target_update_ts"
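
Note: to make the distinction behind the new keys concrete: on a MultiAgentBatch, count tracks env steps while agent_steps() sums the per-policy batch lengths. A small illustrative snippet (policy names and sizes made up; constructor signature assumed from this era of the API):

    from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch

    # Two policies with 3 agent steps each, gathered over 3 env steps
    # (both agents act at every env step).
    batch = MultiAgentBatch(
        policy_batches={
            "pol_a": SampleBatch({"obs": [0, 1, 2]}),
            "pol_b": SampleBatch({"obs": [3, 4, 5]}),
        },
        env_steps=3,
    )
    print(batch.count)          # -> 3 (env steps)
    print(batch.agent_steps())  # -> 6 (summed agent steps)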


@@ -3,8 +3,8 @@ import time

 from ray.util.iter import LocalIterator
 from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes
-from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
-    _get_shared_metrics
+from ray.rllib.execution.common import AGENT_STEPS_SAMPLED_COUNTER, \
+    STEPS_SAMPLED_COUNTER, _get_shared_metrics
 from ray.rllib.evaluation.worker_set import WorkerSet

@@ -103,6 +103,8 @@ class CollectMetrics:
         res.update({
             "num_healthy_workers": len(self.workers.remote_workers()),
             "timesteps_total": metrics.counters[STEPS_SAMPLED_COUNTER],
+            "agent_timesteps_total": metrics.counters.get(
+                AGENT_STEPS_SAMPLED_COUNTER, 0),
         })
         res["timers"] = timers
         res["info"] = info


@@ -7,9 +7,9 @@ from ray.util.iter_metrics import SharedMetrics
 from ray.rllib.evaluation.metrics import get_learner_stats
 from ray.rllib.evaluation.rollout_worker import get_global_worker
 from ray.rllib.evaluation.worker_set import WorkerSet
-from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, LEARNER_INFO, \
-    SAMPLE_TIMER, GRAD_WAIT_TIMER, _check_sample_batch_type, \
-    _get_shared_metrics
+from ray.rllib.execution.common import AGENT_STEPS_SAMPLED_COUNTER, \
+    STEPS_SAMPLED_COUNTER, LEARNER_INFO, SAMPLE_TIMER, GRAD_WAIT_TIMER, \
+    _check_sample_batch_type, _get_shared_metrics
 from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
     MultiAgentBatch
 from ray.rllib.utils.sgd import standardized

@@ -60,6 +60,11 @@ def ParallelRollouts(workers: WorkerSet, *, mode="bulk_sync",
     def report_timesteps(batch):
         metrics = _get_shared_metrics()
         metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
+        if isinstance(batch, MultiAgentBatch):
+            metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += \
+                batch.agent_steps()
+        else:
+            metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += batch.count
         return batch

     if not workers.remote_workers():
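
Note: the same branch reappears in the tree-aggregation path further down. As a standalone rule it amounts to this sketch (agent_steps_in is a hypothetical name, not part of the commit):

    from ray.rllib.policy.sample_batch import MultiAgentBatch

    def agent_steps_in(batch):
        # Multi-agent batches know their per-policy step counts; for a
        # plain (single-agent) SampleBatch, env and agent steps coincide.
        if isinstance(batch, MultiAgentBatch):
            return batch.agent_steps()
        return batch.count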


@@ -9,11 +9,11 @@ from ray.rllib.evaluation.metrics import extract_stats, get_learner_stats, \
     LEARNER_STATS_KEY
 from ray.rllib.evaluation.worker_set import WorkerSet
 from ray.rllib.execution.common import \
-    STEPS_SAMPLED_COUNTER, STEPS_TRAINED_COUNTER, LEARNER_INFO, \
-    APPLY_GRADS_TIMER, COMPUTE_GRADS_TIMER, WORKER_UPDATE_TIMER, \
-    LEARN_ON_BATCH_TIMER, LOAD_BATCH_TIMER, LAST_TARGET_UPDATE_TS, \
-    NUM_TARGET_UPDATES, _get_global_vars, _check_sample_batch_type, \
-    _get_shared_metrics
+    AGENT_STEPS_TRAINED_COUNTER, APPLY_GRADS_TIMER, COMPUTE_GRADS_TIMER, \
+    LAST_TARGET_UPDATE_TS, LEARNER_INFO, LEARN_ON_BATCH_TIMER, \
+    LOAD_BATCH_TIMER, NUM_TARGET_UPDATES, STEPS_SAMPLED_COUNTER, \
+    STEPS_TRAINED_COUNTER, WORKER_UPDATE_TIMER, _check_sample_batch_type, \
+    _get_global_vars, _get_shared_metrics
 from ray.rllib.execution.multi_gpu_impl import LocalSyncParallelOptimizer
 from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
     MultiAgentBatch

@@ -76,6 +76,9 @@ class TrainOneStep:
                     info, "custom_metrics")
             learn_timer.push_units_processed(batch.count)
         metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
+        if isinstance(batch, MultiAgentBatch):
+            metrics.counters[
+                AGENT_STEPS_TRAINED_COUNTER] += batch.agent_steps()
         # Update weights - after learning on the local worker - on all remote
         # workers.
         if self.workers.remote_workers():

@@ -236,6 +239,7 @@ class TrainTFMultiGPU:
             learn_timer.push_units_processed(samples.count)

         metrics.counters[STEPS_TRAINED_COUNTER] += samples.count
+        metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += samples.agent_steps()
         metrics.info[LEARNER_INFO] = fetches
         if self.workers.remote_workers():
             with metrics.timers[WORKER_UPDATE_TIMER]:
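
Note: TrainTFMultiGPU, unlike TrainOneStep, calls samples.agent_steps() without an isinstance check. That is safe because earlier in its __call__ a plain SampleBatch is wrapped into a single-policy MultiAgentBatch, roughly (paraphrased from the surrounding file, not part of this diff):

    # Handle everything as if multi-agent.
    if isinstance(samples, SampleBatch):
        samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples}, samples.count)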


@@ -3,15 +3,16 @@ import platform
 from typing import List, Dict, Any

 import ray
-from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
-    _get_shared_metrics
+from ray.rllib.evaluation.worker_set import WorkerSet
+from ray.rllib.execution.common import AGENT_STEPS_SAMPLED_COUNTER, \
+    STEPS_SAMPLED_COUNTER, _get_shared_metrics
 from ray.rllib.execution.replay_ops import MixInReplay
 from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
+from ray.rllib.policy.sample_batch import MultiAgentBatch
 from ray.rllib.utils.actors import create_colocated
+from ray.rllib.utils.typing import SampleBatchType, ModelWeights
 from ray.util.iter import ParallelIterator, ParallelIteratorWorker, \
     from_actors, LocalIterator
-from ray.rllib.utils.typing import SampleBatchType, ModelWeights
-from ray.rllib.evaluation.worker_set import WorkerSet

 logger = logging.getLogger(__name__)

@@ -100,6 +101,11 @@ def gather_experiences_tree_aggregation(workers: WorkerSet,
     def record_steps_sampled(batch):
         metrics = _get_shared_metrics()
         metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
+        if isinstance(batch, MultiAgentBatch):
+            metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += \
+                batch.agent_steps()
+        else:
+            metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += batch.count
         return batch

     return train_batches.gather_async().for_each(record_steps_sampled)


@@ -2,6 +2,8 @@ import unittest

 import ray
 from ray.rllib.agents.a3c import A2CTrainer
+from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
+    STEPS_TRAINED_COUNTER
 from ray.rllib.utils.test_utils import framework_iterator

@@ -28,8 +30,8 @@ class TestDistributedExecution(unittest.TestCase):
         assert isinstance(result, dict)
         assert "info" in result
         assert "learner" in result["info"]
-        assert "num_steps_sampled" in result["info"]
-        assert "num_steps_trained" in result["info"]
+        assert STEPS_SAMPLED_COUNTER in result["info"]
+        assert STEPS_TRAINED_COUNTER in result["info"]
         assert "timers" in result
         assert "learn_time_ms" in result["timers"]
         assert "learn_throughput" in result["timers"]


@@ -7,6 +7,8 @@ import ray
 from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy
 from ray.rllib.evaluation.worker_set import WorkerSet
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
+from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
+    STEPS_TRAINED_COUNTER
 from ray.rllib.execution.concurrency_ops import Concurrently, Enqueue, Dequeue
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
 from ray.rllib.execution.replay_ops import StoreToReplayBuffer, Replay

@@ -122,11 +124,11 @@ def test_rollouts(ray_start_regular_shared):
     a = ParallelRollouts(workers, mode="bulk_sync")
     assert next(a).count == 200
     counters = a.shared_metrics.get().counters
-    assert counters["num_steps_sampled"] == 200, counters
+    assert counters[STEPS_SAMPLED_COUNTER] == 200, counters

     a = ParallelRollouts(workers, mode="async")
     assert next(a).count == 100
     counters = a.shared_metrics.get().counters
-    assert counters["num_steps_sampled"] == 100, counters
+    assert counters[STEPS_SAMPLED_COUNTER] == 100, counters
     workers.stop()

@@ -135,7 +137,7 @@ def test_rollouts_local(ray_start_regular_shared):
     a = ParallelRollouts(workers, mode="bulk_sync")
     assert next(a).count == 100
     counters = a.shared_metrics.get().counters
-    assert counters["num_steps_sampled"] == 100, counters
+    assert counters[STEPS_SAMPLED_COUNTER] == 100, counters
     workers.stop()

@@ -163,7 +165,7 @@ def test_async_grads(ray_start_regular_shared):
     res1 = next(a)
     assert isinstance(res1, tuple) and len(res1) == 2, res1
     counters = a.shared_metrics.get().counters
-    assert counters["num_steps_sampled"] == 100, counters
+    assert counters[STEPS_SAMPLED_COUNTER] == 100, counters
     workers.stop()

@@ -176,8 +178,8 @@ def test_train_one_step(ray_start_regular_shared):
     assert DEFAULT_POLICY_ID in stats
     assert "learner_stats" in stats[DEFAULT_POLICY_ID]
     counters = a.shared_metrics.get().counters
-    assert counters["num_steps_sampled"] == 100, counters
-    assert counters["num_steps_trained"] == 100, counters
+    assert counters[STEPS_SAMPLED_COUNTER] == 100, counters
+    assert counters[STEPS_TRAINED_COUNTER] == 100, counters
     timers = a.shared_metrics.get().timers
     assert "learn" in timers
     workers.stop()
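
Note: these execution tests run single-agent workers, where the new agent counters simply mirror the env-step ones. In a genuinely multi-agent run the two diverge; a hedged sketch of the expected relation (counter names from execution/common above; the factor depends on how many agents act per env step):

    counters = a.shared_metrics.get().counters
    # With N agents acting at every env step, agent steps ~= N * env steps.
    assert counters[AGENT_STEPS_SAMPLED_COUNTER] >= counters[STEPS_SAMPLED_COUNTER]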