Summary of this patch (text before the first file header is ignored by
patch tools):

  * python/ray/tune/logger.py: add np.int64 to VALID_SUMMARY_TYPES.
  * rllib/utils/sgd.py: averaged() gains an optional `axis` argument that is
    passed to np.mean, and its docstring describes it.
  * rllib/execution/train_ops.py: the per-iteration debug average is only
    computed when DEBUG logging is enabled, and the value stored in
    fetches[policy_id] is averaged with axis=0.

NOTE(review): the source text of this patch had all line breaks and leading
whitespace collapsed. Hunk line counts below match the @@ headers, but the
indentation inside the hunks is reconstructed -- confirm it against the
target files before applying. The blob ids on the "index" lines are carried
over unchanged from the original patch.

diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py
index 29775b6cd..a996e2d0f 100644
--- a/python/ray/tune/logger.py
+++ b/python/ray/tune/logger.py
@@ -18,7 +18,7 @@ from ray.tune.utils import flatten_dict
 logger = logging.getLogger(__name__)
 
 tf = None
-VALID_SUMMARY_TYPES = [int, float, np.float32, np.float64, np.int32]
+VALID_SUMMARY_TYPES = [int, float, np.float32, np.float64, np.int32, np.int64]
 
 
 class Logger:
diff --git a/rllib/execution/train_ops.py b/rllib/execution/train_ops.py
index 10168cfde..c70c1f114 100644
--- a/rllib/execution/train_ops.py
+++ b/rllib/execution/train_ops.py
@@ -206,9 +206,10 @@ class TrainTFMultiGPU:
                             self.per_device_batch_size)
                         for k, v in batch_fetches[LEARNER_STATS_KEY].items():
                             iter_extra_fetches[k].append(v)
-                    logger.debug("{} {}".format(i,
-                                                averaged(iter_extra_fetches)))
-                fetches[policy_id] = averaged(iter_extra_fetches)
+                    if logger.isEnabledFor(logging.DEBUG):
+                        avg = averaged(iter_extra_fetches)
+                        logger.debug("{} {}".format(i, avg))
+                fetches[policy_id] = averaged(iter_extra_fetches, axis=0)
 
         load_timer.push_units_processed(samples.count)
         learn_timer.push_units_processed(samples.count)
diff --git a/rllib/utils/sgd.py b/rllib/utils/sgd.py
index 25d9240eb..bc2e920b0 100644
--- a/rllib/utils/sgd.py
+++ b/rllib/utils/sgd.py
@@ -13,7 +13,11 @@ from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
 logger = logging.getLogger(__name__)
 
 
-def averaged(kv):
+def averaged(kv, axis=None):
     """Average the value lists of a dictionary.
 
+    ``axis`` is forwarded to ``np.mean``: the default ``None`` averages over
+    all elements of a value list, while ``axis=0`` averages its entries
+    element-wise.
+
     For non-scalar values, we simply pick the first value.
 
@@ -27,7 +31,7 @@ def averaged(kv):
     out = {}
     for k, v in kv.items():
         if v[0] is not None and not isinstance(v[0], dict):
-            out[k] = np.mean(v)
+            out[k] = np.mean(v, axis=axis)
         else:
             out[k] = v[0]
     return out