ray/rllib/utils/debug/summary.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

80 lines
2.3 KiB
Python
Raw Normal View History

import numpy as np
import pprint
from typing import Any, Mapping
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils.annotations import DeveloperAPI
_printer = pprint.PrettyPrinter(indent=2, width=60)
@DeveloperAPI
def summarize(obj: Any) -> str:
    """Return a pretty-formatted string for an object.

    This has special handling for pretty-formatting of commonly used data types
    in RLlib, such as SampleBatch, numpy arrays, etc.

    Args:
        obj: The object to format.

    Returns:
        The pretty-printed string representation of the summarized object.
    """
    # First reduce the object to a compact, printable structure, then let the
    # shared pretty-printer render that structure as text.
    return _printer.pformat(_summarize(obj))
def _summarize(obj):
if isinstance(obj, Mapping):
return {k: _summarize(v) for k, v in obj.items()}
elif hasattr(obj, "_asdict"):
return {
"type": obj.__class__.__name__,
"data": _summarize(obj._asdict()),
}
elif isinstance(obj, list):
return [_summarize(x) for x in obj]
elif isinstance(obj, tuple):
return tuple(_summarize(x) for x in obj)
elif isinstance(obj, np.ndarray):
if obj.size == 0:
return _StringValue("np.ndarray({}, dtype={})".format(obj.shape, obj.dtype))
elif obj.dtype == object or obj.dtype.type is np.str_:
return _StringValue(
"np.ndarray({}, dtype={}, head={})".format(
obj.shape, obj.dtype, _summarize(obj[0])
)
)
else:
return _StringValue(
"np.ndarray({}, dtype={}, min={}, max={}, mean={})".format(
obj.shape,
obj.dtype,
round(float(np.min(obj)), 3),
round(float(np.max(obj)), 3),
round(float(np.mean(obj)), 3),
)
)
elif isinstance(obj, MultiAgentBatch):
return {
"type": "MultiAgentBatch",
"policy_batches": _summarize(obj.policy_batches),
"count": obj.count,
}
elif isinstance(obj, SampleBatch):
return {
"type": "SampleBatch",
"data": {k: _summarize(v) for k, v in obj.items()},
}
else:
return obj
class _StringValue:
def __init__(self, value):
self.value = value
def __repr__(self):
return self.value