# ray/rllib/utils/metrics/learner_info.py

from collections import defaultdict
import numpy as np
import tree  # pip install dm_tree
from typing import Dict

from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.typing import PolicyID

# Instant metrics (keys for metrics.info).
LEARNER_INFO = "learner"
# By convention, metrics from optimizing the loss can be reported in the
# `grad_info` dict returned by learn_on_batch() / compute_grads() via this key.
LEARNER_STATS_KEY = "learner_stats"
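
# For illustration only (an assumption based on the convention noted above,
# not a guaranteed schema): a `results` dict returned by
# Policy.learn_on_batch() might look like:
#     {"learner_stats": {"policy_loss": ..., "vf_loss": ...},
#      "td_error": np.ndarray(...)}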


class LearnerInfoBuilder:
    """Builds the final LEARNER_INFO dict from per-tower/per-step results."""

    def __init__(self, num_devices: int = 1):
        self.num_devices = num_devices
        self.results_all_towers = defaultdict(list)
        self.is_finalized = False

    def add_learn_on_batch_results(
        self,
        results: Dict,
        policy_id: PolicyID = DEFAULT_POLICY_ID,
    ) -> None:
        """Adds a Policy.learn_on_(loaded)_batch() result to this builder.

        Args:
            results: The results returned by Policy.learn_on_batch or
                Policy.learn_on_loaded_batch.
            policy_id: The ID of the policy whose learn_on_(loaded)_batch
                method returned `results`.
        """
        assert not self.is_finalized, \
            "LearnerInfo already finalized! Cannot add more results."

        # No towers: Single CPU.
        if "tower_0" not in results:
            self.results_all_towers[policy_id].append(results)
        # Multi-GPU case: Reduce the per-tower results into a single dict.
        else:
            self.results_all_towers[policy_id].append(
                tree.map_structure_with_path(
                    lambda p, *s: all_tower_reduce(p, *s),
                    *(results.pop("tower_{}".format(tower_num))
                      for tower_num in range(self.num_devices))))
            # Merge any remaining (non-tower) top-level keys into the
            # reduced dict.
            for k, v in results.items():
                if k == LEARNER_STATS_KEY:
                    for k1, v1 in results[k].items():
                        self.results_all_towers[policy_id][-1][
                            LEARNER_STATS_KEY][k1] = v1
                else:
                    self.results_all_towers[policy_id][-1][k] = v

    def finalize(self):
        """Finalizes and returns the per-policy info dict.

        No further results may be added after this call.
        """
        self.is_finalized = True

        info = {}
        for policy_id, results_all_towers in self.results_all_towers.items():
            # Reduce mean across all minibatch SGD steps (axis=0 to keep
            # all shapes as-is).
            info[policy_id] = tree.map_structure(
                lambda *s: None if s[0] is None else np.nanmean(s, axis=0),
                *results_all_towers)

        return info
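

# Hedged sketch (not from the source) of the builder's single-device
# behavior, assuming two SGD minibatch steps for the default policy:
#
#     builder = LearnerInfoBuilder(num_devices=1)
#     builder.add_learn_on_batch_results({LEARNER_STATS_KEY: {"loss": 0.5}})
#     builder.add_learn_on_batch_results({LEARNER_STATS_KEY: {"loss": 0.3}})
#     builder.finalize()
#     # -> {"default_policy": {"learner_stats": {"loss": 0.4}}}
#
# finalize() mean-reduces across the two steps, hence loss == 0.4.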


def all_tower_reduce(path, *tower_data):
    """Reduces stats across towers based on their stats-dict paths."""
    # TD-errors: Need to stay per batch item in order to be able to update
    # each item's weight in a prioritized replay buffer.
    if len(path) == 1 and path[0] == "td_error":
        return np.concatenate(tower_data, axis=0)
    # Min stats: Reduce min.
    if path[-1].startswith("min_"):
        return np.nanmin(tower_data)
    # Max stats: Reduce max.
    elif path[-1].startswith("max_"):
        return np.nanmax(tower_data)
    # Everything else: Reduce mean.
    return np.nanmean(tower_data)
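

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the RLlib API): demonstrates the
# multi-GPU path, in which per-tower results are reduced by
# all_tower_reduce(). The dict layout below is an illustrative assumption,
# not a canonical Policy.learn_on_batch() output.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    builder = LearnerInfoBuilder(num_devices=2)
    # Fake two-tower (two-GPU) results for a single SGD step.
    builder.add_learn_on_batch_results({
        "tower_0": {
            LEARNER_STATS_KEY: {"policy_loss": 0.5, "min_q": 1.0},
            "td_error": np.array([0.1, 0.2]),
        },
        "tower_1": {
            LEARNER_STATS_KEY: {"policy_loss": 0.7, "min_q": 0.5},
            "td_error": np.array([0.3, 0.4]),
        },
    })
    info = builder.finalize()
    # policy_loss is mean-reduced (0.6), min_q is min-reduced (0.5), and
    # td_error is concatenated across towers ([0.1, 0.2, 0.3, 0.4]).
    print(info[DEFAULT_POLICY_ID])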