2018-07-01 00:05:08 -07:00
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import collections
|
|
|
|
|
|
|
|
import ray
|
2018-08-20 15:28:03 -07:00
|
|
|
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
2018-07-01 00:05:08 -07:00
|
|
|
|
|
|
|
|
|
|
|
def collect_metrics(local_evaluator, remote_evaluators=None):
    """Gathers episode metrics from PolicyEvaluator instances.

    Args:
        local_evaluator: Evaluator living in this process; its
            ``sampler.get_metrics()`` is called directly.
        remote_evaluators: Optional iterable of remote evaluator handles.
            Each one is queried via ``apply.remote`` and the results are
            fetched with ``ray.get``. ``None`` or an empty list means only
            the local evaluator is queried.

    Returns:
        dict with the keys ``episode_reward_max``, ``episode_reward_min``,
        ``episode_reward_mean``, ``episode_len_mean``, ``episodes_total``
        and ``policy_reward_mean`` (a dict mapping policy id to the mean of
        the rewards recorded for that policy). When no episodes were
        collected, the four reward/length statistics are NaN,
        ``episodes_total`` is 0 and ``policy_reward_mean`` is empty.
    """
    if remote_evaluators:
        metric_lists = ray.get([
            a.apply.remote(lambda ev: ev.sampler.get_metrics())
            for a in remote_evaluators
        ])
    else:
        # No remote evaluators: nothing to fetch through ray.
        metric_lists = []
    metric_lists.append(local_evaluator.sampler.get_metrics())

    episode_rewards = []
    episode_lengths = []
    policy_rewards = collections.defaultdict(list)
    for metrics in metric_lists:
        for episode in metrics:
            episode_lengths.append(episode.episode_length)
            episode_rewards.append(episode.episode_reward)
            # agent_rewards is keyed by a 2-tuple whose second element is
            # the policy id; the first element is not needed here.
            for (_, policy_id), reward in episode.agent_rewards.items():
                # Rewards recorded under the default policy id are left out
                # of the per-policy breakdown.
                if policy_id != DEFAULT_POLICY_ID:
                    policy_rewards[policy_id].append(reward)

    if episode_rewards:
        min_reward = min(episode_rewards)
        max_reward = max(episode_rewards)
        avg_reward = np.mean(episode_rewards)
        avg_length = np.mean(episode_lengths)
    else:
        # np.mean([]) also yields NaN, but emits a "Mean of empty slice"
        # RuntimeWarning; report NaN explicitly instead.
        min_reward = max_reward = avg_reward = avg_length = float('nan')

    # Every list in policy_rewards has at least one entry (lists are only
    # created by the append above), so these means are always well defined.
    policy_reward_mean = {
        policy_id: np.mean(rewards)
        for policy_id, rewards in policy_rewards.items()
    }

    return dict(
        episode_reward_max=max_reward,
        episode_reward_min=min_reward,
        episode_reward_mean=avg_reward,
        episode_len_mean=avg_length,
        episodes_total=len(episode_lengths),
        policy_reward_mean=policy_reward_mean)
|