Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[rllib] Feature/histograms in tensorboard (#6942)
* Added histogram functionality to custom metrics infrastructure (another tab in tensorboard)
* updated example to include histogram metric
* added histograms to TBXLogger
* add episode rewards
* lint

Co-authored-by: Eric Liang <ekhliang@gmail.com>
Parent: df518849ed
Commit: dc7a555260
6 changed files with 44 additions and 16 deletions
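For context, a rough usage sketch of the feature this commit adds, pieced together from the callbacks example updated below. Any list stored in episode.hist_data is carried through hist_stats and shows up in TensorBoard's Histograms/Distributions tabs. The trainer, environment, and stopping condition here are illustrative assumptions, not part of the diff.

    # Sketch only; assumes ray[rllib] at this commit and the dict-of-functions
    # callbacks API used by the example file touched in this change.
    from ray import tune

    def on_episode_start(info):
        episode = info["episode"]
        episode.user_data["pole_angles"] = []
        episode.hist_data["pole_angles"] = []

    def on_episode_step(info):
        episode = info["episode"]
        # CartPole observation index 2 is the pole angle.
        episode.user_data["pole_angles"].append(
            abs(episode.last_observation_for()[2]))

    def on_episode_end(info):
        episode = info["episode"]
        # Report the whole per-step series, not just a scalar summary.
        episode.hist_data["pole_angles"] = episode.user_data["pole_angles"]

    tune.run(
        "PG",  # illustrative trainer choice
        stop={"training_iteration": 5},
        config={
            "env": "CartPole-v0",
            "callbacks": {
                "on_episode_start": on_episode_start,
                "on_episode_step": on_episode_step,
                "on_episode_end": on_episode_end,
            },
        })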
@@ -221,12 +221,19 @@ class TF2Logger(Logger):
 def to_tf_values(result, path):
+    from tensorboardX.summary import make_histogram
     flat_result = flatten_dict(result, delimiter="/")
-    values = [
-        tf.Summary.Value(tag="/".join(path + [attr]), simple_value=value)
-        for attr, value in flat_result.items()
-        if type(value) in VALID_SUMMARY_TYPES
-    ]
+    values = []
+    for attr, value in flat_result.items():
+        if type(value) in VALID_SUMMARY_TYPES:
+            values.append(
+                tf.Summary.Value(
+                    tag="/".join(path + [attr]), simple_value=value))
+        elif type(value) is list and len(value) > 0:
+            values.append(
+                tf.Summary.Value(
+                    tag="/".join(path + [attr]),
+                    histo=make_histogram(values=np.array(value), bins=10)))
     return values
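For reference, a minimal sketch of the tensorboardX helper that to_tf_values() now uses for list-valued metrics (assumes tensorboardX and numpy are installed; the sample values are made up):

    import numpy as np
    from tensorboardX.summary import make_histogram

    # make_histogram() builds the HistogramProto that the logger attaches
    # via tf.Summary.Value(histo=...) above.
    proto = make_histogram(values=np.array([0.01, -0.02, 0.05, 0.0]), bins=10)
    print(proto.num, proto.min, proto.max)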
@@ -342,14 +349,18 @@ class TBXLogger(Logger):
         flat_result = flatten_dict(tmp, delimiter="/")
         path = ["ray", "tune"]
-        valid_result = {
-            "/".join(path + [attr]): value
-            for attr, value in flat_result.items()
-            if type(value) in VALID_SUMMARY_TYPES
-        }
+        valid_result = {}
+        for attr, value in flat_result.items():
+            full_attr = "/".join(path + [attr])
+            if type(value) in VALID_SUMMARY_TYPES:
+                valid_result[full_attr] = value
+                self._file_writer.add_scalar(
+                    full_attr, value, global_step=step)
+            elif type(value) is list and len(value) > 0:
+                valid_result[full_attr] = value
+                self._file_writer.add_histogram(
+                    full_attr, value, global_step=step)
-        for attr, value in valid_result.items():
-            self._file_writer.add_scalar(attr, value, global_step=step)
         self.last_result = valid_result
         self._file_writer.flush()
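The same scalar-versus-histogram routing can be exercised directly against tensorboardX, outside of Tune; a minimal sketch, assuming tensorboardX and numpy are installed (the tags, values, and log directory are made up for illustration):

    import numpy as np
    from tensorboardX import SummaryWriter

    writer = SummaryWriter("/tmp/tbx_hist_demo")  # illustrative logdir
    step = 0
    result = {
        "episode_reward_mean": 123.4,                         # scalar -> add_scalar
        "hist_stats/pole_angles": [0.01, -0.02, 0.05, 0.0],   # list -> add_histogram
    }
    for attr, value in result.items():
        full_attr = "/".join(["ray", "tune", attr])
        if isinstance(value, (int, float)):
            writer.add_scalar(full_attr, value, global_step=step)
        elif isinstance(value, list) and len(value) > 0:
            writer.add_histogram(full_attr, np.array(value), global_step=step)
    writer.flush()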
@@ -501,6 +512,7 @@ class _SafeFallbackEncoder(json.JSONEncoder):
 def pretty_print(result):
     result = result.copy()
     result.update(config=None)  # drop config from pretty print
+    result.update(hist_stats=None)  # drop hist_stats from pretty print
     out = {}
     for k, v in result.items():
         if v is not None:
@@ -47,6 +47,7 @@ class MultiAgentEpisode:
         self.agent_rewards = defaultdict(float)
         self.custom_metrics = {}
         self.user_data = {}
+        self.hist_data = {}
         self._policies = policies
         self._policy_mapping_fn = policy_mapping_fn
         self._next_agent_index = 0
@@ -97,6 +97,7 @@ def summarize_episodes(episodes, new_episodes):
     policy_rewards = collections.defaultdict(list)
     custom_metrics = collections.defaultdict(list)
     perf_stats = collections.defaultdict(list)
+    hist_stats = collections.defaultdict(list)
     for episode in episodes:
         episode_lengths.append(episode.episode_length)
         episode_rewards.append(episode.episode_reward)
@@ -107,6 +108,8 @@ def summarize_episodes(episodes, new_episodes):
         for (_, policy_id), reward in episode.agent_rewards.items():
             if policy_id != DEFAULT_POLICY_ID:
                 policy_rewards[policy_id].append(reward)
+        for k, v in episode.hist_data.items():
+            hist_stats[k] += v
     if episode_rewards:
         min_reward = min(episode_rewards)
         max_reward = max(episode_rewards)
@@ -116,6 +119,10 @@ def summarize_episodes(episodes, new_episodes):
     avg_reward = np.mean(episode_rewards)
     avg_length = np.mean(episode_lengths)

+    # Show as histogram distributions.
+    hist_stats["episode_reward"] = episode_rewards
+    hist_stats["episode_lengths"] = episode_lengths
+
     policy_reward_min = {}
     policy_reward_mean = {}
     policy_reward_max = {}
@@ -124,9 +131,12 @@ def summarize_episodes(episodes, new_episodes):
         policy_reward_mean[policy_id] = np.mean(rewards)
         policy_reward_max[policy_id] = np.max(rewards)

+        # Show as histogram distributions.
+        hist_stats["policy_{}_reward".format(policy_id)] = rewards
+
     for k, v_list in custom_metrics.copy().items():
-        custom_metrics[k + "_mean"] = np.mean(v_list)
         filt = [v for v in v_list if not np.isnan(v)]
+        custom_metrics[k + "_mean"] = np.mean(filt)
         if filt:
             custom_metrics[k + "_min"] = np.min(filt)
             custom_metrics[k + "_max"] = np.max(filt)
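The custom-metrics change above also moves the mean so it is computed after NaN filtering; a quick check of why that matters (plain numpy, sample values made up):

    import numpy as np

    v_list = [1.0, 2.0, float("nan")]
    print(np.mean(v_list))   # nan -- one NaN poisons the old mean
    filt = [v for v in v_list if not np.isnan(v)]
    print(np.mean(filt))     # 1.5 -- mean over the filtered values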
@@ -158,6 +168,7 @@ def summarize_episodes(episodes, new_episodes):
         policy_reward_max=policy_reward_max,
         policy_reward_mean=policy_reward_mean,
         custom_metrics=dict(custom_metrics),
+        hist_stats=dict(hist_stats),
         sampler_perf=dict(perf_stats),
         off_policy_estimator=dict(estimators))
@@ -3,5 +3,6 @@ import collections
 # Define this in its own file, see #5125
 RolloutMetrics = collections.namedtuple("RolloutMetrics", [
     "episode_length", "episode_reward", "agent_rewards", "custom_metrics",
-    "perf_stats"
+    "perf_stats", "hist_data"
 ])
+RolloutMetrics.__new__.__defaults__ = (0, 0, {}, {}, {}, {})
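With the extra "hist_data" field and the new __new__.__defaults__ tuple, RolloutMetrics can be constructed from only the leading positional arguments; a small sketch of the behavior the sampler change below relies on:

    import collections

    RolloutMetrics = collections.namedtuple("RolloutMetrics", [
        "episode_length", "episode_reward", "agent_rewards", "custom_metrics",
        "perf_stats", "hist_data"
    ])
    RolloutMetrics.__new__.__defaults__ = (0, 0, {}, {}, {}, {})

    m = RolloutMetrics(200, 195.0)  # as in _fetch_atari_metrics below
    print(m.hist_data)              # {} -- trailing fields fall back to defaults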
@@ -390,7 +390,8 @@ def _process_observations(base_env, policies, batch_builder_pool,
             outputs.append(
                 RolloutMetrics(episode.length, episode.total_reward,
                                dict(episode.agent_rewards),
-                               episode.custom_metrics, {}))
+                               episode.custom_metrics, {},
+                               episode.hist_data))
         else:
             hit_horizon = False
             all_done = False
@@ -620,7 +621,7 @@ def _fetch_atari_metrics(base_env):
         if not monitor:
             return None
         for eps_rew, eps_len in monitor.next_episode_results():
-            atari_out.append(RolloutMetrics(eps_len, eps_rew, {}, {}, {}))
+            atari_out.append(RolloutMetrics(eps_len, eps_rew))
     return atari_out
@@ -15,6 +15,7 @@ def on_episode_start(info):
     episode = info["episode"]
     print("episode {} started".format(episode.episode_id))
     episode.user_data["pole_angles"] = []
+    episode.hist_data["pole_angles"] = []


 def on_episode_step(info):
@@ -31,6 +32,7 @@ def on_episode_end(info):
     print("episode {} ended with length {} and pole angles {}".format(
         episode.episode_id, episode.length, pole_angle))
     episode.custom_metrics["pole_angle"] = pole_angle
+    episode.hist_data["pole_angles"] = episode.user_data["pole_angles"]


 def on_sample_end(info):