Fix multi-GPU histogram metrics for > 0D tensors.
This commit is contained in:
Sven Mika 2020-05-15 21:43:27 +02:00 committed by GitHub
parent 4633d81c39
commit c9435cad43
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 7 additions and 6 deletions

View file

@ -18,7 +18,7 @@ from ray.tune.utils import flatten_dict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
tf = None tf = None
VALID_SUMMARY_TYPES = [int, float, np.float32, np.float64, np.int32] VALID_SUMMARY_TYPES = [int, float, np.float32, np.float64, np.int32, np.int64]
class Logger: class Logger:

View file

@ -206,9 +206,10 @@ class TrainTFMultiGPU:
self.per_device_batch_size) self.per_device_batch_size)
for k, v in batch_fetches[LEARNER_STATS_KEY].items(): for k, v in batch_fetches[LEARNER_STATS_KEY].items():
iter_extra_fetches[k].append(v) iter_extra_fetches[k].append(v)
logger.debug("{} {}".format(i, if logger.getEffectiveLevel() <= logging.DEBUG:
averaged(iter_extra_fetches))) avg = averaged(iter_extra_fetches)
fetches[policy_id] = averaged(iter_extra_fetches) logger.debug("{} {}".format(i, avg))
fetches[policy_id] = averaged(iter_extra_fetches, axis=0)
load_timer.push_units_processed(samples.count) load_timer.push_units_processed(samples.count)
learn_timer.push_units_processed(samples.count) learn_timer.push_units_processed(samples.count)

View file

@ -13,7 +13,7 @@ from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def averaged(kv): def averaged(kv, axis=None):
"""Average the value lists of a dictionary. """Average the value lists of a dictionary.
For non-scalar values, we simply pick the first value. For non-scalar values, we simply pick the first value.
@ -27,7 +27,7 @@ def averaged(kv):
out = {} out = {}
for k, v in kv.items(): for k, v in kv.items():
if v[0] is not None and not isinstance(v[0], dict): if v[0] is not None and not isinstance(v[0], dict):
out[k] = np.mean(v) out[k] = np.mean(v, axis=axis)
else: else:
out[k] = v[0] out[k] = v[0]
return out return out