[rllib] Ensure stats are consistently reported across all algos (#4445)

Eric Liang 2019-03-27 15:40:15 -07:00 committed by GitHub
parent 2871609296
commit 09b2961750
18 changed files with 70 additions and 28 deletions

View file

@@ -8,6 +8,7 @@ import tensorflow as tf
import gym
import ray
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.evaluation.policy_graph import PolicyGraph
@@ -110,7 +111,7 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
max_seq_len=self.config["model"]["max_seq_len"])
self.stats_fetches = {
"stats": {
LEARNER_STATS_KEY: {
"cur_lr": tf.cast(self.cur_lr, tf.float64),
"policy_loss": self.loss.pi_loss,
"policy_entropy": self.loss.entropy,

View file

@@ -11,6 +11,7 @@ import ray
import ray.experimental.tf_utils
from ray.rllib.agents.dqn.dqn_policy_graph import (
_huber_loss, _minimize_and_clip, _scope_vars, _postprocess_dqn)
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.annotations import override
from ray.rllib.utils.error import UnsupportedSpaceException
@@ -446,7 +447,7 @@ class DDPGPolicyGraph(TFPolicyGraph):
def extra_compute_grad_fetches(self):
return {
"td_error": self.loss.td_error,
"stats": self.stats,
LEARNER_STATS_KEY: self.stats,
}
@override(PolicyGraph)

View file

@@ -9,6 +9,7 @@ import tensorflow as tf
import tensorflow.contrib.layers as layers
import ray
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.models import ModelCatalog, Categorical
from ray.rllib.utils.annotations import override
from ray.rllib.utils.error import UnsupportedSpaceException
@@ -454,7 +455,7 @@ class DQNPolicyGraph(TFPolicyGraph):
def extra_compute_grad_fetches(self):
return {
"td_error": self.loss.td_error,
"stats": self.loss.stats,
LEARNER_STATS_KEY: self.loss.stats,
}
@override(PolicyGraph)
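
A note on the DQN/DDPG hunks above: learner metrics now sit under LEARNER_STATS_KEY ("learner_stats") in the grad fetches, while per-sample extras such as td_error stay at the top level. A minimal sketch of the grad_info shape a caller of learn_on_batch() sees after this change (values are illustrative):

    # Illustrative grad_info returned by learn_on_batch() for DQN/DDPG.
    grad_info = {
        "td_error": [0.12, -0.40, 0.05],   # consumed by prioritized replay
        "learner_stats": {                  # i.e. LEARNER_STATS_KEY
            "cur_lr": 5e-4,
            "mean_q": 1.23,
        },
    }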

View file

@@ -11,6 +11,7 @@ import ray
import numpy as np
import tensorflow as tf
from ray.rllib.agents.impala import vtrace
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
LearningRateSchedule
@@ -296,7 +297,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
self.sess.run(tf.global_variables_initializer())
self.stats_fetches = {
"stats": dict({
LEARNER_STATS_KEY: dict({
"cur_lr": tf.cast(self.cur_lr, tf.float64),
"policy_loss": self.loss.pi_loss,
"entropy": self.loss.entropy,

View file

@@ -6,6 +6,7 @@ import tensorflow as tf
import ray
from ray.rllib.models import ModelCatalog
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.utils.annotations import override
from ray.rllib.evaluation.policy_graph import PolicyGraph
@@ -141,7 +142,7 @@ class MARWILPolicyGraph(TFPolicyGraph):
@override(TFPolicyGraph)
def extra_compute_grad_fetches(self):
return self.stats_fetches
return {LEARNER_STATS_KEY: self.stats_fetches}
@override(PolicyGraph)
def postprocess_trajectory(self,

View file

@@ -13,6 +13,7 @@ import gym
import ray
from ray.rllib.agents.impala import vtrace
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
LearningRateSchedule
from ray.rllib.models.catalog import ModelCatalog
@@ -406,7 +407,7 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
values_batched = make_time_major(
values, drop_last=self.config["vtrace"])
self.stats_fetches = {
"stats": dict({
LEARNER_STATS_KEY: dict({
"cur_lr": tf.cast(self.cur_lr, tf.float64),
"policy_loss": self.loss.pi_loss,
"entropy": self.loss.entropy,

View file

@@ -127,7 +127,7 @@ class PPOAgent(Agent):
res = self.collect_metrics()
res.update(
timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps,
info=dict(fetches, **res.get("info", {})))
info=res.get("info", {}))
# Warn about bad clipping configs
if self.config["vf_clip_param"] <= 0:
@@ -138,7 +138,7 @@ class PPOAgent(Agent):
rew_scale = round(
abs(res["episode_reward_mean"]) / self.config["vf_clip_param"],
0)
if rew_scale > 100:
if rew_scale > 200:
logger.warning(
"The magnitude of your environment rewards are more than "
"{}x the scale of `vf_clip_param`. ".format(rew_scale) +

View file

@@ -6,6 +6,7 @@ import logging
import tensorflow as tf
import ray
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
@@ -332,7 +333,7 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
@override(TFPolicyGraph)
def extra_compute_grad_fetches(self):
return self.stats_fetches
return {LEARNER_STATS_KEY: self.stats_fetches}
def update_kl(self, sampled_kl):
if sampled_kl > 2.0 * self.kl_target:

View file

@@ -13,6 +13,7 @@ from torch.distributions import Categorical
import ray
from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer
from ray.rllib.agents.qmix.model import RNNModel, _get_size
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.models.action_dist import TupleActions
from ray.rllib.models.catalog import ModelCatalog
@@ -295,7 +296,7 @@ class QMixPolicyGraph(PolicyGraph):
mask_elems,
"target_mean": (targets * mask).sum().item() / mask_elems,
}
return {"stats": stats}, {}
return {LEARNER_STATS_KEY: stats}, {}
@override(PolicyGraph)
def get_initial_state(self):

View file

@@ -8,12 +8,37 @@ import collections
import ray
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.evaluation.sampler import RolloutMetrics
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate
from ray.rllib.utils.annotations import DeveloperAPI
logger = logging.getLogger(__name__)
# By convention, metrics from optimizing the loss can be reported in the
# `grad_info` dict returned by learn_on_batch() / compute_gradients() via this key.
LEARNER_STATS_KEY = "learner_stats"
@DeveloperAPI
def get_learner_stats(grad_info):
"""Return optimization stats reported from the policy graph.
Example:
>>> grad_info = evaluator.learn_on_batch(samples)
>>> print(get_learner_stats(grad_info))
{"vf_loss": ..., "policy_loss": ...}
"""
if LEARNER_STATS_KEY in grad_info:
return grad_info[LEARNER_STATS_KEY]
multiagent_stats = {}
for k, v in grad_info.items():
if type(v) is dict:
if LEARNER_STATS_KEY in v:
multiagent_stats[k] = v[LEARNER_STATS_KEY]
return multiagent_stats
@DeveloperAPI
def collect_metrics(local_evaluator, remote_evaluators=[],
@@ -135,6 +160,8 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
def _partition(episodes):
"""Divides metrics data into true rollouts vs off-policy estimates."""
from ray.rllib.evaluation.sampler import RolloutMetrics
rollouts, estimates = [], []
for e in episodes:
if isinstance(e, RolloutMetrics):
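
Usage sketch for the get_learner_stats() helper added above: it unwraps LEARNER_STATS_KEY from a single-agent grad_info dict, or from each per-policy entry of a multi-agent dict (policy ids and values below are illustrative):

    from ray.rllib.evaluation.metrics import get_learner_stats, LEARNER_STATS_KEY

    # Single-agent grad_info: stats sit directly under LEARNER_STATS_KEY.
    single = {"td_error": [0.1, 0.2], LEARNER_STATS_KEY: {"cur_lr": 5e-4}}
    assert get_learner_stats(single) == {"cur_lr": 5e-4}

    # Multi-agent grad_info: one nested dict per policy id.
    multi = {
        "pol_a": {LEARNER_STATS_KEY: {"policy_loss": 0.3}},
        "pol_b": {LEARNER_STATS_KEY: {"policy_loss": 0.7}},
    }
    assert get_learner_stats(multi) == {
        "pol_a": {"policy_loss": 0.3},
        "pol_b": {"policy_loss": 0.7},
    }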

View file

@@ -10,6 +10,7 @@ import numpy as np
import ray
import ray.experimental.tf_utils
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.models.lstm import chop_into_sequences
from ray.rllib.utils.annotations import override, DeveloperAPI
@@ -263,7 +264,7 @@ class TFPolicyGraph(PolicyGraph):
@DeveloperAPI
def extra_compute_grad_fetches(self):
"""Extra values to fetch and return from compute_gradients()."""
return {} # e.g, td error
return {LEARNER_STATS_KEY: {}}  # e.g., stats, td error, etc.
@DeveloperAPI
def extra_apply_grad_feed_dict(self):
@@ -424,9 +425,12 @@ class TFPolicyGraph(PolicyGraph):
def _get_grad_and_stats_fetches(self):
fetches = self.extra_compute_grad_fetches()
if LEARNER_STATS_KEY not in fetches:
raise ValueError(
"Grad fetches should contain 'learner_stats': {...} entry")
if self._stats_fetches:
fetches["stats"] = dict(self._stats_fetches,
**fetches.get("stats", {}))
fetches[LEARNER_STATS_KEY] = dict(self._stats_fetches,
**fetches[LEARNER_STATS_KEY])
return fetches
def _get_loss_inputs_dict(self, batch):
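
With the validation added above, TFPolicyGraph subclasses must return a LEARNER_STATS_KEY entry from extra_compute_grad_fetches(); the base class then merges its registered stats fetches into that dict. A minimal sketch of a compliant override (the subclass and tensor name are illustrative):

    from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
    from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph

    class MyPolicyGraph(TFPolicyGraph):  # illustrative subclass
        def extra_compute_grad_fetches(self):
            # Learner metrics go under LEARNER_STATS_KEY; anything else,
            # e.g. per-sample td_error, can remain at the top level.
            return {
                LEARNER_STATS_KEY: {"my_metric": self.my_metric_tensor},
            }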

View file

@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
import ray
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils.annotations import override
from ray.rllib.utils.timer import TimerStat
@@ -49,9 +50,7 @@ class AsyncGradientsOptimizer(PolicyOptimizer):
gradient, info = ray.get(future)
e = pending_gradients.pop(future)
if "stats" in info:
self.learner_stats = info["stats"]
self.learner_stats = get_learner_stats(info)
if gradient is not None:
with self.apply_timer:

View file

@@ -16,6 +16,7 @@ import numpy as np
from six.moves import queue
import ray
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
MultiAgentBatch
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
@@ -403,8 +404,7 @@ class LearnerThread(threading.Thread):
prio_dict[pid] = (
replay.policy_batches[pid].data.get("batch_indexes"),
info.get("td_error"))
if "stats" in info:
self.stats[pid] = info["stats"]
self.stats[pid] = get_learner_stats(info)
self.outqueue.put((ra, prio_dict, replay.count))
self.learner_queue_size.push(self.inqueue.qsize())
self.weights_updated = True

View file

@@ -15,6 +15,7 @@ import threading
from six.moves import queue
import ray
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
@@ -286,7 +287,7 @@ class LearnerThread(threading.Thread):
with self.grad_timer:
fetches = self.local_evaluator.learn_on_batch(batch)
self.weights_updated = True
self.stats = fetches.get("stats", {})
self.stats = get_learner_stats(fetches)
self.outqueue.put(batch.count)
self.learner_queue_size.push(self.inqueue.qsize())
@@ -373,7 +374,7 @@ class TFMultiGPULearner(LearnerThread):
with self.grad_timer:
fetches = opt.optimize(self.sess, 0)
self.weights_updated = True
self.stats = fetches.get("stats", {})
self.stats = get_learner_stats(fetches)
if released:
self.idle_optimizers.put(opt)

View file

@@ -9,6 +9,7 @@ from collections import defaultdict
import tensorflow as tf
import ray
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
@@ -189,7 +190,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
batch_fetches = optimizer.optimize(
self.sess, permutation[batch_index] *
self.per_device_batch_size)
for k, v in batch_fetches.items():
for k, v in batch_fetches[LEARNER_STATS_KEY].items():
iter_extra_fetches[k].append(v)
logger.debug("{} {}".format(i,
_averaged(iter_extra_fetches)))
@@ -197,6 +198,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
self.num_steps_sampled += samples.count
self.num_steps_trained += tuples_per_device * len(self.devices)
self.learner_stats = fetches
return fetches
@override(PolicyOptimizer)
@@ -208,6 +210,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
"grad_time_ms": round(1000 * self.grad_timer.mean, 3),
"update_time_ms": round(1000 * self.update_weights_timer.mean,
3),
"learner": self.learner_stats,
})
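
The optimizer above now pulls per-minibatch metrics from batch_fetches[LEARNER_STATS_KEY] and reports the last fetches under a "learner" key in stats(). A simplified sketch of the accumulate-then-average pattern (the dict values are illustrative, and the averaging paraphrases the optimizer's _averaged helper):

    import numpy as np
    from collections import defaultdict

    # Two illustrative minibatch fetches, each keyed by LEARNER_STATS_KEY.
    minibatch_fetches = [
        {"learner_stats": {"policy_loss": 0.30, "vf_loss": 1.2}},
        {"learner_stats": {"policy_loss": 0.28, "vf_loss": 1.1}},
    ]
    iter_extra_fetches = defaultdict(list)
    for batch_fetches in minibatch_fetches:
        for k, v in batch_fetches["learner_stats"].items():
            iter_extra_fetches[k].append(v)
    averaged = {k: np.mean(v) for k, v in iter_extra_fetches.items()}
    # averaged -> {"policy_loss": 0.29, "vf_loss": ~1.15}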

View file

@@ -5,6 +5,7 @@ from __future__ import print_function
import random
import ray
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
MultiAgentBatch
@@ -97,8 +98,7 @@ class SyncBatchReplayOptimizer(PolicyOptimizer):
with self.grad_timer:
info_dict = self.local_evaluator.learn_on_batch(samples)
for policy_id, info in info_dict.items():
if "stats" in info:
self.learner_stats[policy_id] = info["stats"]
self.learner_stats[policy_id] = get_learner_stats(info)
self.grad_timer.push_units_processed(samples.count)
self.num_steps_trained += samples.count
return info_dict

View file

@@ -9,6 +9,7 @@ import ray
from ray.rllib.optimizers.replay_buffer import ReplayBuffer, \
PrioritizedReplayBuffer
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
MultiAgentBatch
from ray.rllib.utils.annotations import override
@@ -128,8 +129,7 @@ class SyncReplayOptimizer(PolicyOptimizer):
with self.grad_timer:
info_dict = self.local_evaluator.learn_on_batch(samples)
for policy_id, info in info_dict.items():
if "stats" in info:
self.learner_stats[policy_id] = info["stats"]
self.learner_stats[policy_id] = get_learner_stats(info)
replay_buffer = self.replay_buffers[policy_id]
if isinstance(replay_buffer, PrioritizedReplayBuffer):
td_error = info["td_error"]

View file

@@ -4,6 +4,7 @@ from __future__ import print_function
import ray
import logging
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
@@ -55,8 +56,7 @@ class SyncSamplesOptimizer(PolicyOptimizer):
with self.grad_timer:
for i in range(self.num_sgd_iter):
fetches = self.local_evaluator.learn_on_batch(samples)
if "stats" in fetches:
self.learner_stats = fetches["stats"]
self.learner_stats = get_learner_stats(fetches)
if self.num_sgd_iter > 1:
logger.debug("{} {}".format(i, fetches))
self.grad_timer.push_units_processed(samples.count)
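
Taken together, every optimizer above now records learner metrics through get_learner_stats(), so they surface consistently in an agent's train() results. A hedged sketch of reading them back (the exact nesting under "info" can vary by algorithm and optimizer):

    import ray
    from ray.rllib.agents.ppo import PPOAgent

    ray.init()
    agent = PPOAgent(env="CartPole-v0")
    result = agent.train()
    # Learner stats gathered by the optimizer, e.g. cur_lr, policy_loss, vf_loss.
    print(result["info"].get("learner", {}))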