Fix multi-GPU histogram metrics for > 0D tensors.
This commit is contained in:
Sven Mika 2020-05-15 21:43:27 +02:00 committed by GitHub
parent 4633d81c39
commit c9435cad43
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 7 additions and 6 deletions

View file

@ -18,7 +18,7 @@ from ray.tune.utils import flatten_dict
logger = logging.getLogger(__name__)
tf = None
VALID_SUMMARY_TYPES = [int, float, np.float32, np.float64, np.int32]
VALID_SUMMARY_TYPES = [int, float, np.float32, np.float64, np.int32, np.int64]
class Logger:

View file

@ -206,9 +206,10 @@ class TrainTFMultiGPU:
self.per_device_batch_size)
for k, v in batch_fetches[LEARNER_STATS_KEY].items():
iter_extra_fetches[k].append(v)
logger.debug("{} {}".format(i,
averaged(iter_extra_fetches)))
fetches[policy_id] = averaged(iter_extra_fetches)
if logger.getEffectiveLevel() <= logging.DEBUG:
avg = averaged(iter_extra_fetches)
logger.debug("{} {}".format(i, avg))
fetches[policy_id] = averaged(iter_extra_fetches, axis=0)
load_timer.push_units_processed(samples.count)
learn_timer.push_units_processed(samples.count)

View file

@ -13,7 +13,7 @@ from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
logger = logging.getLogger(__name__)
def averaged(kv):
def averaged(kv, axis=None):
"""Average the value lists of a dictionary.
For non-scalar values, we simply pick the first value.
@ -27,7 +27,7 @@ def averaged(kv):
out = {}
for k, v in kv.items():
if v[0] is not None and not isinstance(v[0], dict):
out[k] = np.mean(v)
out[k] = np.mean(v, axis=axis)
else:
out[k] = v[0]
return out