[rllib] Feature/histograms in tensorboard (#6942)

* Added histogram functionality to custom metrics infrastructure (another tab in tensorboard)

* updated example to include histogram metric

* added histograms to TBXLogger

* add episode rewards

* lint

Co-authored-by: Eric Liang <ekhliang@gmail.com>
roireshef 2020-01-31 08:02:53 +02:00 committed by GitHub
parent df518849ed
commit dc7a555260
6 changed files with 44 additions and 16 deletions
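
For orientation, the change threads list-valued metrics through a new hist_data / hist_stats path so that anything collected as a list becomes a TensorBoard histogram rather than a scalar curve. A rough illustration of the resulting shape of a training result (keys mirror the diff below; the numbers are made up):

import numpy as np  # noqa: illustration only

# Illustrative only: scalar entries keep going to add_scalar, while the
# list-valued entries under "hist_stats" are what the new logger branches
# below turn into TensorBoard histograms via add_histogram.
result = {
    "episode_reward_mean": 172.3,                    # scalar -> SCALARS tab
    "hist_stats": {
        "episode_reward": [120.0, 180.0, 217.0],     # list -> HISTOGRAMS tab
        "episode_lengths": [200, 180, 160],
        "pole_angles": [0.012, -0.034, 0.021],
    },
}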

View file

@@ -221,12 +221,19 @@ class TF2Logger(Logger):
 def to_tf_values(result, path):
+    from tensorboardX.summary import make_histogram
     flat_result = flatten_dict(result, delimiter="/")
-    values = [
-        tf.Summary.Value(tag="/".join(path + [attr]), simple_value=value)
-        for attr, value in flat_result.items()
-        if type(value) in VALID_SUMMARY_TYPES
-    ]
+    values = []
+    for attr, value in flat_result.items():
+        if type(value) in VALID_SUMMARY_TYPES:
+            values.append(
+                tf.Summary.Value(
+                    tag="/".join(path + [attr]), simple_value=value))
+        elif type(value) is list and len(value) > 0:
+            values.append(
+                tf.Summary.Value(
+                    tag="/".join(path + [attr]),
+                    histo=make_histogram(values=np.array(value), bins=10)))
     return values
@@ -342,14 +349,18 @@ class TBXLogger(Logger):
         flat_result = flatten_dict(tmp, delimiter="/")
         path = ["ray", "tune"]
-        valid_result = {
-            "/".join(path + [attr]): value
-            for attr, value in flat_result.items()
-            if type(value) in VALID_SUMMARY_TYPES
-        }
+        valid_result = {}
+        for attr, value in flat_result.items():
+            full_attr = "/".join(path + [attr])
+            if type(value) in VALID_SUMMARY_TYPES:
+                valid_result[full_attr] = value
+                self._file_writer.add_scalar(
+                    full_attr, value, global_step=step)
+            elif type(value) is list and len(value) > 0:
+                valid_result[full_attr] = value
+                self._file_writer.add_histogram(
+                    full_attr, value, global_step=step)
 
-        for attr, value in valid_result.items():
-            self._file_writer.add_scalar(attr, value, global_step=step)
 
         self.last_result = valid_result
         self._file_writer.flush()
@@ -501,6 +512,7 @@ class _SafeFallbackEncoder(json.JSONEncoder):
 def pretty_print(result):
     result = result.copy()
     result.update(config=None)  # drop config from pretty print
+    result.update(hist_stats=None)  # drop hist_stats from pretty print
     out = {}
     for k, v in result.items():
         if v is not None:
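
The TBXLogger branch above leans on tensorboardX's SummaryWriter.add_histogram. A minimal standalone sketch of that call outside of Tune (the logdir and tag are arbitrary placeholders; TBXLogger uses the trial's own logdir and the ray/tune prefix):

import numpy as np
from tensorboardX import SummaryWriter

# Placeholder logdir for illustration.
writer = SummaryWriter("/tmp/tbx_histogram_demo")
rewards = np.random.normal(loc=100.0, scale=15.0, size=200)

# A list/array value becomes a distribution panel in TensorBoard's
# HISTOGRAMS and DISTRIBUTIONS tabs instead of a scalar curve.
writer.add_histogram("ray/tune/hist_stats/episode_reward", rewards, global_step=1)
writer.close()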

View file

@@ -47,6 +47,7 @@ class MultiAgentEpisode:
         self.agent_rewards = defaultdict(float)
         self.custom_metrics = {}
         self.user_data = {}
+        self.hist_data = {}
         self._policies = policies
         self._policy_mapping_fn = policy_mapping_fn
         self._next_agent_index = 0

View file

@@ -97,6 +97,7 @@ def summarize_episodes(episodes, new_episodes):
     policy_rewards = collections.defaultdict(list)
     custom_metrics = collections.defaultdict(list)
     perf_stats = collections.defaultdict(list)
+    hist_stats = collections.defaultdict(list)
     for episode in episodes:
         episode_lengths.append(episode.episode_length)
         episode_rewards.append(episode.episode_reward)
@@ -107,6 +108,8 @@ def summarize_episodes(episodes, new_episodes):
         for (_, policy_id), reward in episode.agent_rewards.items():
             if policy_id != DEFAULT_POLICY_ID:
                 policy_rewards[policy_id].append(reward)
+        for k, v in episode.hist_data.items():
+            hist_stats[k] += v
     if episode_rewards:
         min_reward = min(episode_rewards)
         max_reward = max(episode_rewards)
@@ -116,6 +119,10 @@ def summarize_episodes(episodes, new_episodes):
     avg_reward = np.mean(episode_rewards)
     avg_length = np.mean(episode_lengths)
 
+    # Show as histogram distributions.
+    hist_stats["episode_reward"] = episode_rewards
+    hist_stats["episode_lengths"] = episode_lengths
+
     policy_reward_min = {}
     policy_reward_mean = {}
     policy_reward_max = {}
@@ -124,9 +131,12 @@ def summarize_episodes(episodes, new_episodes):
         policy_reward_mean[policy_id] = np.mean(rewards)
         policy_reward_max[policy_id] = np.max(rewards)
+        # Show as histogram distributions.
+        hist_stats["policy_{}_reward".format(policy_id)] = rewards
 
     for k, v_list in custom_metrics.copy().items():
-        custom_metrics[k + "_mean"] = np.mean(v_list)
         filt = [v for v in v_list if not np.isnan(v)]
+        custom_metrics[k + "_mean"] = np.mean(filt)
         if filt:
             custom_metrics[k + "_min"] = np.min(filt)
             custom_metrics[k + "_max"] = np.max(filt)
@@ -158,6 +168,7 @@ def summarize_episodes(episodes, new_episodes):
         policy_reward_max=policy_reward_max,
         policy_reward_mean=policy_reward_mean,
         custom_metrics=dict(custom_metrics),
+        hist_stats=dict(hist_stats),
         sampler_perf=dict(perf_stats),
         off_policy_estimator=dict(estimators))
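
The aggregation pattern above is worth noting: per-episode hist_data lists are concatenated across episodes (hist_stats[k] += v), whereas custom_metrics are reduced to scalar mean/min/max. A toy illustration of the concatenation, with plain dicts standing in for episode objects:

import collections

# Toy stand-ins for MultiAgentEpisode objects; only hist_data matters here.
episodes = [
    {"hist_data": {"pole_angles": [0.10, 0.20]}},
    {"hist_data": {"pole_angles": [0.30]}},
]

hist_stats = collections.defaultdict(list)
for ep in episodes:
    for k, v in ep["hist_data"].items():
        hist_stats[k] += v  # lists are concatenated, not averaged

print(dict(hist_stats))  # {'pole_angles': [0.1, 0.2, 0.3]}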

View file

@@ -3,5 +3,6 @@ import collections
 # Define this in its own file, see #5125
 RolloutMetrics = collections.namedtuple("RolloutMetrics", [
     "episode_length", "episode_reward", "agent_rewards", "custom_metrics",
-    "perf_stats"
+    "perf_stats", "hist_data"
 ])
+RolloutMetrics.__new__.__defaults__ = (0, 0, {}, {}, {}, {})
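
The __new__.__defaults__ line is the standard pre-dataclass trick for giving every namedtuple field a default, which is what lets the _fetch_atari_metrics call further down drop its trailing empty dicts. A small self-contained example of the idiom (Point is purely illustrative):

import collections

Point = collections.namedtuple("Point", ["x", "y", "tags"])
Point.__new__.__defaults__ = (0, 0, ())  # one default per field, right-aligned

print(Point())      # Point(x=0, y=0, tags=())
print(Point(3, 4))  # Point(x=3, y=4, tags=())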

View file

@@ -390,7 +390,8 @@ def _process_observations(base_env, policies, batch_builder_pool,
                 outputs.append(
                     RolloutMetrics(episode.length, episode.total_reward,
                                    dict(episode.agent_rewards),
-                                   episode.custom_metrics, {}))
+                                   episode.custom_metrics, {},
+                                   episode.hist_data))
         else:
             hit_horizon = False
             all_done = False
@@ -620,7 +621,7 @@ def _fetch_atari_metrics(base_env):
     if not monitor:
         return None
     for eps_rew, eps_len in monitor.next_episode_results():
-        atari_out.append(RolloutMetrics(eps_len, eps_rew, {}, {}, {}))
+        atari_out.append(RolloutMetrics(eps_len, eps_rew))
     return atari_out

View file

@@ -15,6 +15,7 @@ def on_episode_start(info):
     episode = info["episode"]
     print("episode {} started".format(episode.episode_id))
     episode.user_data["pole_angles"] = []
+    episode.hist_data["pole_angles"] = []
 
 
 def on_episode_step(info):
@@ -31,6 +32,7 @@ def on_episode_end(info):
     print("episode {} ended with length {} and pole angles {}".format(
         episode.episode_id, episode.length, pole_angle))
     episode.custom_metrics["pole_angle"] = pole_angle
+    episode.hist_data["pole_angles"] = episode.user_data["pole_angles"]
 
 
 def on_sample_end(info):
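
To see the histograms end to end, the callbacks above still need to be registered with the trainer. A hedged sketch of how the example script of this era wires them up (the algorithm, stop condition, and env are arbitrary choices; the callbacks config format is the pre-0.9 callable-dict style):

import ray
from ray import tune

ray.init()
tune.run(
    "PG",  # any algorithm works; PG is just small and fast
    stop={"training_iteration": 5},
    config={
        "env": "CartPole-v0",
        "callbacks": {
            "on_episode_start": on_episode_start,
            "on_episode_step": on_episode_step,
            "on_episode_end": on_episode_end,
        },
    },
)
# Anything placed in episode.hist_data shows up under ray/tune/hist_stats/...
# in TensorBoard's HISTOGRAMS and DISTRIBUTIONS tabs.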