Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[rllib] Ensure stats are consistently reported across all algos (#4445)
This commit is contained in:
parent 2871609296
commit 09b2961750
18 changed files with 70 additions and 28 deletions
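
The commit standardizes where learner statistics live: policy graphs report them under LEARNER_STATS_KEY in the grad_info returned by learn_on_batch() / compute_grads(), and optimizers read them through get_learner_stats() rather than probing an ad-hoc "stats" key. A minimal sketch of the convention, assuming a Ray build that includes this change (the dict values are illustrative, not taken from the diff):

from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY, get_learner_stats

# Single-policy grad_info: learner stats sit under LEARNER_STATS_KEY,
# alongside any extra fetches such as td_error.
grad_info = {"td_error": [0.1, -0.3], LEARNER_STATS_KEY: {"cur_lr": 5e-4}}
assert get_learner_stats(grad_info) == {"cur_lr": 5e-4}

# Optimizer-side pattern replaced throughout this diff:
#   before:  if "stats" in info: self.learner_stats = info["stats"]
#   after:   self.learner_stats = get_learner_stats(info)
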
@@ -8,6 +8,7 @@ import tensorflow as tf
 import gym
 
 import ray
+from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
 from ray.rllib.utils.error import UnsupportedSpaceException
 from ray.rllib.utils.explained_variance import explained_variance
 from ray.rllib.evaluation.policy_graph import PolicyGraph
@@ -110,7 +111,7 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
             max_seq_len=self.config["model"]["max_seq_len"])
 
         self.stats_fetches = {
-            "stats": {
+            LEARNER_STATS_KEY: {
                 "cur_lr": tf.cast(self.cur_lr, tf.float64),
                 "policy_loss": self.loss.pi_loss,
                 "policy_entropy": self.loss.entropy,

@@ -11,6 +11,7 @@ import ray
 import ray.experimental.tf_utils
 from ray.rllib.agents.dqn.dqn_policy_graph import (
     _huber_loss, _minimize_and_clip, _scope_vars, _postprocess_dqn)
+from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
 from ray.rllib.models import ModelCatalog
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.error import UnsupportedSpaceException
@@ -446,7 +447,7 @@ class DDPGPolicyGraph(TFPolicyGraph):
     def extra_compute_grad_fetches(self):
         return {
             "td_error": self.loss.td_error,
-            "stats": self.stats,
+            LEARNER_STATS_KEY: self.stats,
         }
 
     @override(PolicyGraph)

@@ -9,6 +9,7 @@ import tensorflow as tf
 import tensorflow.contrib.layers as layers
 
 import ray
+from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
 from ray.rllib.models import ModelCatalog, Categorical
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.error import UnsupportedSpaceException
@@ -454,7 +455,7 @@ class DQNPolicyGraph(TFPolicyGraph):
     def extra_compute_grad_fetches(self):
         return {
             "td_error": self.loss.td_error,
-            "stats": self.loss.stats,
+            LEARNER_STATS_KEY: self.loss.stats,
         }
 
     @override(PolicyGraph)

@@ -11,6 +11,7 @@ import ray
 import numpy as np
 import tensorflow as tf
 from ray.rllib.agents.impala import vtrace
+from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
 from ray.rllib.evaluation.policy_graph import PolicyGraph
 from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
     LearningRateSchedule
@@ -296,7 +297,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
         self.sess.run(tf.global_variables_initializer())
 
         self.stats_fetches = {
-            "stats": dict({
+            LEARNER_STATS_KEY: dict({
                 "cur_lr": tf.cast(self.cur_lr, tf.float64),
                 "policy_loss": self.loss.pi_loss,
                 "entropy": self.loss.entropy,

@@ -6,6 +6,7 @@ import tensorflow as tf
 
 import ray
 from ray.rllib.models import ModelCatalog
+from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
 from ray.rllib.evaluation.postprocessing import compute_advantages
 from ray.rllib.utils.annotations import override
 from ray.rllib.evaluation.policy_graph import PolicyGraph
@@ -141,7 +142,7 @@ class MARWILPolicyGraph(TFPolicyGraph):
 
     @override(TFPolicyGraph)
     def extra_compute_grad_fetches(self):
-        return self.stats_fetches
+        return {LEARNER_STATS_KEY: self.stats_fetches}
 
     @override(PolicyGraph)
     def postprocess_trajectory(self,

@@ -13,6 +13,7 @@ import gym
 
 import ray
 from ray.rllib.agents.impala import vtrace
+from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
 from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
     LearningRateSchedule
 from ray.rllib.models.catalog import ModelCatalog
@@ -406,7 +407,7 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
         values_batched = make_time_major(
             values, drop_last=self.config["vtrace"])
         self.stats_fetches = {
-            "stats": dict({
+            LEARNER_STATS_KEY: dict({
                 "cur_lr": tf.cast(self.cur_lr, tf.float64),
                 "policy_loss": self.loss.pi_loss,
                 "entropy": self.loss.entropy,

@@ -127,7 +127,7 @@ class PPOAgent(Agent):
         res = self.collect_metrics()
         res.update(
             timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps,
-            info=dict(fetches, **res.get("info", {})))
+            info=res.get("info", {}))
 
         # Warn about bad clipping configs
         if self.config["vf_clip_param"] <= 0:
@@ -138,7 +138,7 @@ class PPOAgent(Agent):
         rew_scale = round(
             abs(res["episode_reward_mean"]) / self.config["vf_clip_param"],
             0)
-        if rew_scale > 100:
+        if rew_scale > 200:
             logger.warning(
                 "The magnitude of your environment rewards are more than "
                 "{}x the scale of `vf_clip_param`. ".format(rew_scale) +

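A quick worked example of the clipping warning above, with illustrative numbers that are not from the source:

# If the mean episode reward is -5000 and vf_clip_param is 10.0:
rew_scale = round(abs(-5000.0) / 10.0, 0)  # -> 500.0
assert rew_scale > 200  # the warning fires under the new, higher threshold
# Under the old `> 100` check, a 150x mismatch would already have warned.
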
@@ -6,6 +6,7 @@ import logging
 import tensorflow as tf
 
 import ray
+from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
 from ray.rllib.evaluation.postprocessing import compute_advantages
 from ray.rllib.evaluation.policy_graph import PolicyGraph
 from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
@@ -332,7 +333,7 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
 
     @override(TFPolicyGraph)
     def extra_compute_grad_fetches(self):
-        return self.stats_fetches
+        return {LEARNER_STATS_KEY: self.stats_fetches}
 
     def update_kl(self, sampled_kl):
         if sampled_kl > 2.0 * self.kl_target:

@@ -13,6 +13,7 @@ from torch.distributions import Categorical
 import ray
 from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer
 from ray.rllib.agents.qmix.model import RNNModel, _get_size
+from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
 from ray.rllib.evaluation.policy_graph import PolicyGraph
 from ray.rllib.models.action_dist import TupleActions
 from ray.rllib.models.catalog import ModelCatalog
@@ -295,7 +296,7 @@ class QMixPolicyGraph(PolicyGraph):
             mask_elems,
             "target_mean": (targets * mask).sum().item() / mask_elems,
         }
-        return {"stats": stats}, {}
+        return {LEARNER_STATS_KEY: stats}, {}
 
     @override(PolicyGraph)
     def get_initial_state(self):

@@ -8,12 +8,37 @@ import collections
 
 import ray
 from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
-from ray.rllib.evaluation.sampler import RolloutMetrics
 from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate
 from ray.rllib.utils.annotations import DeveloperAPI
 
 logger = logging.getLogger(__name__)
 
+# By convention, metrics from optimizing the loss can be reported in the
+# `grad_info` dict returned by learn_on_batch() / compute_grads() via this key.
+LEARNER_STATS_KEY = "learner_stats"
+
+
+@DeveloperAPI
+def get_learner_stats(grad_info):
+    """Return optimization stats reported from the policy graph.
+
+    Example:
+        >>> grad_info = evaluator.learn_on_batch(samples)
+        >>> print(get_stats(grad_info))
+        {"vf_loss": ..., "policy_loss": ...}
+    """
+
+    if LEARNER_STATS_KEY in grad_info:
+        return grad_info[LEARNER_STATS_KEY]
+
+    multiagent_stats = {}
+    for k, v in grad_info.items():
+        if type(v) is dict:
+            if LEARNER_STATS_KEY in v:
+                multiagent_stats[k] = v[LEARNER_STATS_KEY]
+
+    return multiagent_stats
+
+
 @DeveloperAPI
 def collect_metrics(local_evaluator, remote_evaluators=[],

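As a usage sketch of the helper added above (dict contents are illustrative), get_learner_stats() also flattens the multiagent case, where learn_on_batch() returns one grad_info dict per policy:

from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY, get_learner_stats

multiagent_info = {
    "policy_1": {LEARNER_STATS_KEY: {"policy_loss": 0.12}},
    "policy_2": {LEARNER_STATS_KEY: {"policy_loss": 0.07}},
}
# Each per-policy entry is unwrapped to its learner stats dict.
assert get_learner_stats(multiagent_info) == {
    "policy_1": {"policy_loss": 0.12},
    "policy_2": {"policy_loss": 0.07},
}
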
@@ -135,6 +160,8 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
 def _partition(episodes):
     """Divides metrics data into true rollouts vs off-policy estimates."""
 
+    from ray.rllib.evaluation.sampler import RolloutMetrics
+
     rollouts, estimates = [], []
     for e in episodes:
         if isinstance(e, RolloutMetrics):

@@ -10,6 +10,7 @@ import numpy as np
 
 import ray
 import ray.experimental.tf_utils
+from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
 from ray.rllib.evaluation.policy_graph import PolicyGraph
 from ray.rllib.models.lstm import chop_into_sequences
 from ray.rllib.utils.annotations import override, DeveloperAPI
@@ -263,7 +264,7 @@ class TFPolicyGraph(PolicyGraph):
     @DeveloperAPI
     def extra_compute_grad_fetches(self):
         """Extra values to fetch and return from compute_gradients()."""
-        return {}  # e.g, td error
+        return {LEARNER_STATS_KEY: {}}  # e.g, stats, td error, etc.
 
     @DeveloperAPI
     def extra_apply_grad_feed_dict(self):
@@ -424,9 +425,12 @@ class TFPolicyGraph(PolicyGraph):
 
     def _get_grad_and_stats_fetches(self):
         fetches = self.extra_compute_grad_fetches()
+        if LEARNER_STATS_KEY not in fetches:
+            raise ValueError(
+                "Grad fetches should contain 'stats': {...} entry")
         if self._stats_fetches:
-            fetches["stats"] = dict(self._stats_fetches,
-                                    **fetches.get("stats", {}))
+            fetches[LEARNER_STATS_KEY] = dict(self._stats_fetches,
+                                              **fetches[LEARNER_STATS_KEY])
         return fetches
 
     def _get_loss_inputs_dict(self, batch):

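With the validation added to _get_grad_and_stats_fetches() above, a custom TFPolicyGraph must nest its learner stats under LEARNER_STATS_KEY in extra_compute_grad_fetches(). A minimal sketch of a conforming override; the subclass and its loss attributes are assumed for illustration:

from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph


class MyPolicyGraph(TFPolicyGraph):  # hypothetical subclass
    def extra_compute_grad_fetches(self):
        # Learner stats go under LEARNER_STATS_KEY; extra per-sample fetches
        # such as td_error can still sit alongside it at the top level.
        return {
            "td_error": self.loss.td_error,  # assumed loss attributes
            LEARNER_STATS_KEY: {"policy_loss": self.loss.pi_loss},
        }
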
@@ -3,6 +3,7 @@ from __future__ import division
 from __future__ import print_function
 
 import ray
+from ray.rllib.evaluation.metrics import get_learner_stats
 from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.timer import TimerStat
@@ -49,9 +50,7 @@ class AsyncGradientsOptimizer(PolicyOptimizer):
 
                 gradient, info = ray.get(future)
                 e = pending_gradients.pop(future)
-
-                if "stats" in info:
-                    self.learner_stats = info["stats"]
+                self.learner_stats = get_learner_stats(info)
 
                 if gradient is not None:
                     with self.apply_timer:

@@ -16,6 +16,7 @@ import numpy as np
 from six.moves import queue
 
 import ray
+from ray.rllib.evaluation.metrics import get_learner_stats
 from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
     MultiAgentBatch
 from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
@@ -403,8 +404,7 @@ class LearnerThread(threading.Thread):
                     prio_dict[pid] = (
                         replay.policy_batches[pid].data.get("batch_indexes"),
                         info.get("td_error"))
-                    if "stats" in info:
-                        self.stats[pid] = info["stats"]
+                    self.stats[pid] = get_learner_stats(info)
             self.outqueue.put((ra, prio_dict, replay.count))
         self.learner_queue_size.push(self.inqueue.qsize())
         self.weights_updated = True

@@ -15,6 +15,7 @@ import threading
 from six.moves import queue
 
 import ray
+from ray.rllib.evaluation.metrics import get_learner_stats
 from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
 from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
 from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
@@ -286,7 +287,7 @@ class LearnerThread(threading.Thread):
         with self.grad_timer:
             fetches = self.local_evaluator.learn_on_batch(batch)
             self.weights_updated = True
-            self.stats = fetches.get("stats", {})
+            self.stats = get_learner_stats(fetches)
 
         self.outqueue.put(batch.count)
         self.learner_queue_size.push(self.inqueue.qsize())
@@ -373,7 +374,7 @@ class TFMultiGPULearner(LearnerThread):
             with self.grad_timer:
                 fetches = opt.optimize(self.sess, 0)
                 self.weights_updated = True
-                self.stats = fetches.get("stats", {})
+                self.stats = get_learner_stats(fetches)
 
             if released:
                 self.idle_optimizers.put(opt)

@@ -9,6 +9,7 @@ from collections import defaultdict
 import tensorflow as tf
 
 import ray
+from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
 from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
 from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
 from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
@@ -189,7 +190,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
                         batch_fetches = optimizer.optimize(
                             self.sess, permutation[batch_index] *
                             self.per_device_batch_size)
-                        for k, v in batch_fetches.items():
+                        for k, v in batch_fetches[LEARNER_STATS_KEY].items():
                             iter_extra_fetches[k].append(v)
                     logger.debug("{} {}".format(i,
                                                 _averaged(iter_extra_fetches)))
@@ -197,6 +198,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
 
         self.num_steps_sampled += samples.count
         self.num_steps_trained += tuples_per_device * len(self.devices)
+        self.learner_stats = fetches
         return fetches
 
     @override(PolicyOptimizer)
@@ -208,6 +210,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
             "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
             "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                     3),
+            "learner": self.learner_stats,
         })
 

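The multi-GPU optimizer above now collects each minibatch's LEARNER_STATS_KEY entries into iter_extra_fetches and logs their average via _averaged(). That helper's body is not part of this diff; a plausible sketch of such an averaging step, stated as an assumption rather than the actual implementation:

import numpy as np


def _averaged(kv):
    # Assumed behaviour: average the per-minibatch values collected for each
    # stat key (works for scalar stats appended once per minibatch).
    return {k: np.mean(v) for k, v in kv.items() if v}
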
@@ -5,6 +5,7 @@ from __future__ import print_function
 import random
 
 import ray
+from ray.rllib.evaluation.metrics import get_learner_stats
 from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
 from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
     MultiAgentBatch
@@ -97,8 +98,7 @@ class SyncBatchReplayOptimizer(PolicyOptimizer):
         with self.grad_timer:
             info_dict = self.local_evaluator.learn_on_batch(samples)
             for policy_id, info in info_dict.items():
-                if "stats" in info:
-                    self.learner_stats[policy_id] = info["stats"]
+                self.learner_stats[policy_id] = get_learner_stats(info)
             self.grad_timer.push_units_processed(samples.count)
         self.num_steps_trained += samples.count
         return info_dict

@@ -9,6 +9,7 @@ import ray
 from ray.rllib.optimizers.replay_buffer import ReplayBuffer, \
     PrioritizedReplayBuffer
 from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
+from ray.rllib.evaluation.metrics import get_learner_stats
 from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
     MultiAgentBatch
 from ray.rllib.utils.annotations import override
@@ -128,8 +129,7 @@ class SyncReplayOptimizer(PolicyOptimizer):
         with self.grad_timer:
             info_dict = self.local_evaluator.learn_on_batch(samples)
             for policy_id, info in info_dict.items():
-                if "stats" in info:
-                    self.learner_stats[policy_id] = info["stats"]
+                self.learner_stats[policy_id] = get_learner_stats(info)
                 replay_buffer = self.replay_buffers[policy_id]
                 if isinstance(replay_buffer, PrioritizedReplayBuffer):
                     td_error = info["td_error"]

@@ -4,6 +4,7 @@ from __future__ import print_function
 
 import ray
 import logging
+from ray.rllib.evaluation.metrics import get_learner_stats
 from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
 from ray.rllib.evaluation.sample_batch import SampleBatch
 from ray.rllib.utils.annotations import override
@@ -55,8 +56,7 @@ class SyncSamplesOptimizer(PolicyOptimizer):
         with self.grad_timer:
             for i in range(self.num_sgd_iter):
                 fetches = self.local_evaluator.learn_on_batch(samples)
-                if "stats" in fetches:
-                    self.learner_stats = fetches["stats"]
+                self.learner_stats = get_learner_stats(fetches)
                 if self.num_sgd_iter > 1:
                     logger.debug("{} {}".format(i, fetches))
             self.grad_timer.push_units_processed(samples.count)