diff --git a/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py b/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py
index ca2f67ac6..d37600ae9 100644
--- a/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py
+++ b/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py
@@ -8,6 +8,7 @@ import tensorflow as tf
 import gym
 
 import ray
+from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
 from ray.rllib.utils.error import UnsupportedSpaceException
 from ray.rllib.utils.explained_variance import explained_variance
 from ray.rllib.evaluation.policy_graph import PolicyGraph
@@ -110,7 +111,7 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
             max_seq_len=self.config["model"]["max_seq_len"])
 
         self.stats_fetches = {
-            "stats": {
+            LEARNER_STATS_KEY: {
                 "cur_lr": tf.cast(self.cur_lr, tf.float64),
                 "policy_loss": self.loss.pi_loss,
                 "policy_entropy": self.loss.entropy,
diff --git a/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py b/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py
index dbc9bbc11..399aabc49 100644
--- a/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py
+++ b/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py
@@ -11,6 +11,7 @@ import ray
 import ray.experimental.tf_utils
 from ray.rllib.agents.dqn.dqn_policy_graph import (
     _huber_loss, _minimize_and_clip, _scope_vars, _postprocess_dqn)
+from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
 from ray.rllib.models import ModelCatalog
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.error import UnsupportedSpaceException
@@ -446,7 +447,7 @@ class DDPGPolicyGraph(TFPolicyGraph):
     def extra_compute_grad_fetches(self):
         return {
             "td_error": self.loss.td_error,
-            "stats": self.stats,
+            LEARNER_STATS_KEY: self.stats,
         }
 
     @override(PolicyGraph)
diff --git a/python/ray/rllib/agents/dqn/dqn_policy_graph.py b/python/ray/rllib/agents/dqn/dqn_policy_graph.py
index f8e2f390b..318e0758b 100644
--- a/python/ray/rllib/agents/dqn/dqn_policy_graph.py
+++ b/python/ray/rllib/agents/dqn/dqn_policy_graph.py
@@ -9,6 +9,7 @@ import tensorflow as tf
 import tensorflow.contrib.layers as layers
 
 import ray
+from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
 from ray.rllib.models import ModelCatalog, Categorical
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.error import UnsupportedSpaceException
@@ -454,7 +455,7 @@ class DQNPolicyGraph(TFPolicyGraph):
     def extra_compute_grad_fetches(self):
         return {
             "td_error": self.loss.td_error,
-            "stats": self.loss.stats,
+            LEARNER_STATS_KEY: self.loss.stats,
         }
 
     @override(PolicyGraph)
diff --git a/python/ray/rllib/agents/impala/vtrace_policy_graph.py b/python/ray/rllib/agents/impala/vtrace_policy_graph.py
index a0e4a4c41..74064586c 100644
--- a/python/ray/rllib/agents/impala/vtrace_policy_graph.py
+++ b/python/ray/rllib/agents/impala/vtrace_policy_graph.py
@@ -11,6 +11,7 @@ import ray
 import numpy as np
 import tensorflow as tf
 from ray.rllib.agents.impala import vtrace
+from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
 from ray.rllib.evaluation.policy_graph import PolicyGraph
 from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
     LearningRateSchedule
@@ -296,7 +297,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
         self.sess.run(tf.global_variables_initializer())
 
         self.stats_fetches = {
-            "stats": dict({
+            LEARNER_STATS_KEY: dict({
                 "cur_lr": tf.cast(self.cur_lr, tf.float64),
                 "policy_loss": self.loss.pi_loss,
                 "entropy": self.loss.entropy,
diff --git a/python/ray/rllib/agents/marwil/marwil_policy_graph.py b/python/ray/rllib/agents/marwil/marwil_policy_graph.py
index 7b66350d8..d1b159605 100644
--- a/python/ray/rllib/agents/marwil/marwil_policy_graph.py
+++ b/python/ray/rllib/agents/marwil/marwil_policy_graph.py
@@ -6,6 +6,7 @@ import tensorflow as tf
 
 import ray
 from ray.rllib.models import ModelCatalog
+from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
 from ray.rllib.evaluation.postprocessing import compute_advantages
 from ray.rllib.utils.annotations import override
 from ray.rllib.evaluation.policy_graph import PolicyGraph
@@ -141,7 +142,7 @@ class MARWILPolicyGraph(TFPolicyGraph):
 
     @override(TFPolicyGraph)
     def extra_compute_grad_fetches(self):
-        return self.stats_fetches
+        return {LEARNER_STATS_KEY: self.stats_fetches}
 
     @override(PolicyGraph)
     def postprocess_trajectory(self,
diff --git a/python/ray/rllib/agents/ppo/appo_policy_graph.py b/python/ray/rllib/agents/ppo/appo_policy_graph.py
index c279d10bd..ae2aa9348 100644
--- a/python/ray/rllib/agents/ppo/appo_policy_graph.py
+++ b/python/ray/rllib/agents/ppo/appo_policy_graph.py
@@ -13,6 +13,7 @@ import gym
 
 import ray
 from ray.rllib.agents.impala import vtrace
+from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
 from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
     LearningRateSchedule
 from ray.rllib.models.catalog import ModelCatalog
@@ -406,7 +407,7 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
         values_batched = make_time_major(
             values, drop_last=self.config["vtrace"])
         self.stats_fetches = {
-            "stats": dict({
+            LEARNER_STATS_KEY: dict({
                 "cur_lr": tf.cast(self.cur_lr, tf.float64),
                 "policy_loss": self.loss.pi_loss,
                 "entropy": self.loss.entropy,
diff --git a/python/ray/rllib/agents/ppo/ppo.py b/python/ray/rllib/agents/ppo/ppo.py
index 5f1e59977..b203179bc 100644
--- a/python/ray/rllib/agents/ppo/ppo.py
+++ b/python/ray/rllib/agents/ppo/ppo.py
@@ -127,7 +127,7 @@ class PPOAgent(Agent):
         res = self.collect_metrics()
         res.update(
             timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps,
-            info=dict(fetches, **res.get("info", {})))
+            info=res.get("info", {}))
 
         # Warn about bad clipping configs
         if self.config["vf_clip_param"] <= 0:
@@ -138,7 +138,7 @@ class PPOAgent(Agent):
             rew_scale = round(
                 abs(res["episode_reward_mean"]) /
                 self.config["vf_clip_param"], 0)
-        if rew_scale > 100:
+        if rew_scale > 200:
             logger.warning(
                 "The magnitude of your environment rewards are more than "
                 "{}x the scale of `vf_clip_param`. ".format(rew_scale) +
diff --git a/python/ray/rllib/agents/ppo/ppo_policy_graph.py b/python/ray/rllib/agents/ppo/ppo_policy_graph.py
index 96f969a83..e983c5d0c 100644
--- a/python/ray/rllib/agents/ppo/ppo_policy_graph.py
+++ b/python/ray/rllib/agents/ppo/ppo_policy_graph.py
@@ -6,6 +6,7 @@ import logging
 import tensorflow as tf
 
 import ray
+from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
 from ray.rllib.evaluation.postprocessing import compute_advantages
 from ray.rllib.evaluation.policy_graph import PolicyGraph
 from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
@@ -332,7 +333,7 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
 
     @override(TFPolicyGraph)
     def extra_compute_grad_fetches(self):
-        return self.stats_fetches
+        return {LEARNER_STATS_KEY: self.stats_fetches}
 
     def update_kl(self, sampled_kl):
         if sampled_kl > 2.0 * self.kl_target:
diff --git a/python/ray/rllib/agents/qmix/qmix_policy_graph.py b/python/ray/rllib/agents/qmix/qmix_policy_graph.py
index 4c8100175..d5fa10f49 100644
--- a/python/ray/rllib/agents/qmix/qmix_policy_graph.py
+++ b/python/ray/rllib/agents/qmix/qmix_policy_graph.py
@@ -13,6 +13,7 @@ from torch.distributions import Categorical
 import ray
 from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer
 from ray.rllib.agents.qmix.model import RNNModel, _get_size
+from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
 from ray.rllib.evaluation.policy_graph import PolicyGraph
 from ray.rllib.models.action_dist import TupleActions
 from ray.rllib.models.catalog import ModelCatalog
@@ -295,7 +296,7 @@ class QMixPolicyGraph(PolicyGraph):
             mask_elems,
             "target_mean": (targets * mask).sum().item() / mask_elems,
         }
-        return {"stats": stats}, {}
+        return {LEARNER_STATS_KEY: stats}, {}
 
     @override(PolicyGraph)
     def get_initial_state(self):
diff --git a/python/ray/rllib/evaluation/metrics.py b/python/ray/rllib/evaluation/metrics.py
index 5a9d3749c..a8fa64b1c 100644
--- a/python/ray/rllib/evaluation/metrics.py
+++ b/python/ray/rllib/evaluation/metrics.py
@@ -8,12 +8,37 @@ import collections
 
 import ray
 from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
-from ray.rllib.evaluation.sampler import RolloutMetrics
 from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate
 from ray.rllib.utils.annotations import DeveloperAPI
 
 logger = logging.getLogger(__name__)
 
+# By convention, metrics from optimizing the loss can be reported in the
+# `grad_info` dict returned by learn_on_batch() / compute_grads() via this key.
+LEARNER_STATS_KEY = "learner_stats"
+
+
+@DeveloperAPI
+def get_learner_stats(grad_info):
+    """Return optimization stats reported from the policy graph.
+
+    Example:
+        >>> grad_info = evaluator.learn_on_batch(samples)
+        >>> print(get_learner_stats(grad_info))
+        {"vf_loss": ..., "policy_loss": ...}
+    """
+
+    if LEARNER_STATS_KEY in grad_info:
+        return grad_info[LEARNER_STATS_KEY]
+
+    multiagent_stats = {}
+    for k, v in grad_info.items():
+        if type(v) is dict:
+            if LEARNER_STATS_KEY in v:
+                multiagent_stats[k] = v[LEARNER_STATS_KEY]
+
+    return multiagent_stats
+
 
 @DeveloperAPI
 def collect_metrics(local_evaluator, remote_evaluators=[],
@@ -135,6 +160,8 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
 
 def _partition(episodes):
     """Divides metrics data into true rollouts vs off-policy estimates."""
+    from ray.rllib.evaluation.sampler import RolloutMetrics
+
     rollouts, estimates = [], []
     for e in episodes:
         if isinstance(e, RolloutMetrics):
diff --git a/python/ray/rllib/evaluation/tf_policy_graph.py b/python/ray/rllib/evaluation/tf_policy_graph.py
index 8febf7738..6d99978f6 100644
--- a/python/ray/rllib/evaluation/tf_policy_graph.py
+++ b/python/ray/rllib/evaluation/tf_policy_graph.py
@@ -10,6 +10,7 @@ import numpy as np
 
 import ray
 import ray.experimental.tf_utils
+from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
 from ray.rllib.evaluation.policy_graph import PolicyGraph
 from ray.rllib.models.lstm import chop_into_sequences
 from ray.rllib.utils.annotations import override, DeveloperAPI
@@ -263,7 +264,7 @@ class TFPolicyGraph(PolicyGraph):
     @DeveloperAPI
     def extra_compute_grad_fetches(self):
         """Extra values to fetch and return from compute_gradients()."""
-        return {}  # e.g, td error
+        return {LEARNER_STATS_KEY: {}}  # e.g., stats, td error, etc.
 
     @DeveloperAPI
     def extra_apply_grad_feed_dict(self):
@@ -424,9 +425,12 @@ class TFPolicyGraph(PolicyGraph):
 
     def _get_grad_and_stats_fetches(self):
         fetches = self.extra_compute_grad_fetches()
+        if LEARNER_STATS_KEY not in fetches:
+            raise ValueError(
+                "Grad fetches should contain 'learner_stats': {...} entry")
         if self._stats_fetches:
-            fetches["stats"] = dict(self._stats_fetches,
-                                    **fetches.get("stats", {}))
+            fetches[LEARNER_STATS_KEY] = dict(self._stats_fetches,
+                                              **fetches[LEARNER_STATS_KEY])
         return fetches
 
     def _get_loss_inputs_dict(self, batch):
diff --git a/python/ray/rllib/optimizers/async_gradients_optimizer.py b/python/ray/rllib/optimizers/async_gradients_optimizer.py
index b1e5ebe84..a9db3e00b 100644
--- a/python/ray/rllib/optimizers/async_gradients_optimizer.py
+++ b/python/ray/rllib/optimizers/async_gradients_optimizer.py
@@ -3,6 +3,7 @@ from __future__ import division
 from __future__ import print_function
 
 import ray
+from ray.rllib.evaluation.metrics import get_learner_stats
 from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.timer import TimerStat
@@ -49,9 +50,7 @@ class AsyncGradientsOptimizer(PolicyOptimizer):
 
                 gradient, info = ray.get(future)
                 e = pending_gradients.pop(future)
-
-                if "stats" in info:
-                    self.learner_stats = info["stats"]
+                self.learner_stats = get_learner_stats(info)
 
             if gradient is not None:
                 with self.apply_timer:
diff --git a/python/ray/rllib/optimizers/async_replay_optimizer.py b/python/ray/rllib/optimizers/async_replay_optimizer.py
index 72612ed2a..6438ef928 100644
--- a/python/ray/rllib/optimizers/async_replay_optimizer.py
+++ b/python/ray/rllib/optimizers/async_replay_optimizer.py
@@ -16,6 +16,7 @@ import numpy as np
 from six.moves import queue
 
 import ray
+from ray.rllib.evaluation.metrics import get_learner_stats
 from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
     MultiAgentBatch
 from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
@@ -403,8 +404,7 @@ class LearnerThread(threading.Thread):
                     prio_dict[pid] = (
                         replay.policy_batches[pid].data.get("batch_indexes"),
                         info.get("td_error"))
-                    if "stats" in info:
-                        self.stats[pid] = info["stats"]
+                    self.stats[pid] = get_learner_stats(info)
         self.outqueue.put((ra, prio_dict, replay.count))
         self.learner_queue_size.push(self.inqueue.qsize())
         self.weights_updated = True
diff --git a/python/ray/rllib/optimizers/async_samples_optimizer.py b/python/ray/rllib/optimizers/async_samples_optimizer.py
index 22f33545b..0827d4598 100644
--- a/python/ray/rllib/optimizers/async_samples_optimizer.py
+++ b/python/ray/rllib/optimizers/async_samples_optimizer.py
@@ -15,6 +15,7 @@ import threading
 from six.moves import queue
 
 import ray
+from ray.rllib.evaluation.metrics import get_learner_stats
 from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
 from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
 from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
@@ -286,7 +287,7 @@ class LearnerThread(threading.Thread):
         with self.grad_timer:
             fetches = self.local_evaluator.learn_on_batch(batch)
             self.weights_updated = True
-            self.stats = fetches.get("stats", {})
+            self.stats = get_learner_stats(fetches)
 
         self.outqueue.put(batch.count)
         self.learner_queue_size.push(self.inqueue.qsize())
@@ -373,7 +374,7 @@ class TFMultiGPULearner(LearnerThread):
             with self.grad_timer:
                 fetches = opt.optimize(self.sess, 0)
                 self.weights_updated = True
-                self.stats = fetches.get("stats", {})
+                self.stats = get_learner_stats(fetches)
 
             if released:
                 self.idle_optimizers.put(opt)
diff --git a/python/ray/rllib/optimizers/multi_gpu_optimizer.py b/python/ray/rllib/optimizers/multi_gpu_optimizer.py
index 0defc8fe8..be75dbb67 100644
--- a/python/ray/rllib/optimizers/multi_gpu_optimizer.py
+++ b/python/ray/rllib/optimizers/multi_gpu_optimizer.py
@@ -9,6 +9,7 @@ from collections import defaultdict
 import tensorflow as tf
 
 import ray
+from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
 from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
 from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
 from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
@@ -189,7 +190,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
                     batch_fetches = optimizer.optimize(
                         self.sess, permutation[batch_index] *
                         self.per_device_batch_size)
-                    for k, v in batch_fetches.items():
+                    for k, v in batch_fetches[LEARNER_STATS_KEY].items():
                         iter_extra_fetches[k].append(v)
                 logger.debug("{} {}".format(i, _averaged(iter_extra_fetches)))
 
@@ -197,6 +198,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
 
         self.num_steps_sampled += samples.count
         self.num_steps_trained += tuples_per_device * len(self.devices)
+        self.learner_stats = fetches
         return fetches
 
     @override(PolicyOptimizer)
@@ -208,6 +210,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
                 "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                 "update_time_ms": round(1000 *
                                         self.update_weights_timer.mean, 3),
+                "learner": self.learner_stats,
             })
diff --git a/python/ray/rllib/optimizers/sync_batch_replay_optimizer.py b/python/ray/rllib/optimizers/sync_batch_replay_optimizer.py
index 0e58a7e4f..846ec5216 100644
--- a/python/ray/rllib/optimizers/sync_batch_replay_optimizer.py
+++ b/python/ray/rllib/optimizers/sync_batch_replay_optimizer.py
@@ -5,6 +5,7 @@ from __future__ import print_function
 import random
 
 import ray
+from ray.rllib.evaluation.metrics import get_learner_stats
 from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
 from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
     MultiAgentBatch
@@ -97,8 +98,7 @@ class SyncBatchReplayOptimizer(PolicyOptimizer):
         with self.grad_timer:
             info_dict = self.local_evaluator.learn_on_batch(samples)
             for policy_id, info in info_dict.items():
-                if "stats" in info:
-                    self.learner_stats[policy_id] = info["stats"]
+                self.learner_stats[policy_id] = get_learner_stats(info)
             self.grad_timer.push_units_processed(samples.count)
         self.num_steps_trained += samples.count
         return info_dict
diff --git a/python/ray/rllib/optimizers/sync_replay_optimizer.py b/python/ray/rllib/optimizers/sync_replay_optimizer.py
index e9b4304a9..21f00d7df 100644
--- a/python/ray/rllib/optimizers/sync_replay_optimizer.py
+++ b/python/ray/rllib/optimizers/sync_replay_optimizer.py
@@ -9,6 +9,7 @@ import ray
 from ray.rllib.optimizers.replay_buffer import ReplayBuffer, \
     PrioritizedReplayBuffer
 from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
+from ray.rllib.evaluation.metrics import get_learner_stats
 from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
     MultiAgentBatch
 from ray.rllib.utils.annotations import override
@@ -128,8 +129,7 @@ class SyncReplayOptimizer(PolicyOptimizer):
         with self.grad_timer:
             info_dict = self.local_evaluator.learn_on_batch(samples)
             for policy_id, info in info_dict.items():
-                if "stats" in info:
-                    self.learner_stats[policy_id] = info["stats"]
+                self.learner_stats[policy_id] = get_learner_stats(info)
                 replay_buffer = self.replay_buffers[policy_id]
                 if isinstance(replay_buffer, PrioritizedReplayBuffer):
                     td_error = info["td_error"]
diff --git a/python/ray/rllib/optimizers/sync_samples_optimizer.py b/python/ray/rllib/optimizers/sync_samples_optimizer.py
index 2cefa5531..3321a04ae 100644
--- a/python/ray/rllib/optimizers/sync_samples_optimizer.py
+++ b/python/ray/rllib/optimizers/sync_samples_optimizer.py
@@ -4,6 +4,7 @@ from __future__ import print_function
 
 import ray
 import logging
+from ray.rllib.evaluation.metrics import get_learner_stats
 from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
 from ray.rllib.evaluation.sample_batch import SampleBatch
 from ray.rllib.utils.annotations import override
@@ -55,8 +56,7 @@ class SyncSamplesOptimizer(PolicyOptimizer):
         with self.grad_timer:
             for i in range(self.num_sgd_iter):
                 fetches = self.local_evaluator.learn_on_batch(samples)
-                if "stats" in fetches:
-                    self.learner_stats = fetches["stats"]
+                self.learner_stats = get_learner_stats(fetches)
                 if self.num_sgd_iter > 1:
                     logger.debug("{} {}".format(i, fetches))
             self.grad_timer.push_units_processed(samples.count)
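
Note (not part of the patch): a minimal sketch of how the new get_learner_stats() helper behaves under the "learner_stats" convention, for both the single-policy and the multi-agent grad_info shapes returned by learn_on_batch(). The policy IDs below are made-up example keys.

    from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY, get_learner_stats

    # Single-policy case: the stats are nested directly under LEARNER_STATS_KEY,
    # alongside other fetches such as td_error.
    single = {"td_error": [0.1, 0.2], LEARNER_STATS_KEY: {"policy_loss": 0.5}}
    assert get_learner_stats(single) == {"policy_loss": 0.5}

    # Multi-agent case: learn_on_batch() returns {policy_id: grad_info}, and the
    # helper collects the nested stats into {policy_id: stats}.
    multi = {
        "policy_a": {LEARNER_STATS_KEY: {"policy_loss": 0.5}},
        "policy_b": {LEARNER_STATS_KEY: {"policy_loss": 0.7}},
    }
    assert get_learner_stats(multi) == {
        "policy_a": {"policy_loss": 0.5},
        "policy_b": {"policy_loss": 0.7},
    }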