mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
parent
4633d81c39
commit
c9435cad43
3 changed files with 7 additions and 6 deletions
|
@@ -18,7 +18,7 @@ from ray.tune.utils import flatten_dict
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
tf = None
|
||||
VALID_SUMMARY_TYPES = [int, float, np.float32, np.float64, np.int32]
|
||||
VALID_SUMMARY_TYPES = [int, float, np.float32, np.float64, np.int32, np.int64]
|
||||
|
||||
|
||||
class Logger:
|
||||
|
|
|
@@ -206,9 +206,10 @@ class TrainTFMultiGPU:
|
|||
self.per_device_batch_size)
|
||||
for k, v in batch_fetches[LEARNER_STATS_KEY].items():
|
||||
iter_extra_fetches[k].append(v)
|
||||
logger.debug("{} {}".format(i,
|
||||
averaged(iter_extra_fetches)))
|
||||
fetches[policy_id] = averaged(iter_extra_fetches)
|
||||
if logger.getEffectiveLevel() <= logging.DEBUG:
|
||||
avg = averaged(iter_extra_fetches)
|
||||
logger.debug("{} {}".format(i, avg))
|
||||
fetches[policy_id] = averaged(iter_extra_fetches, axis=0)
|
||||
|
||||
load_timer.push_units_processed(samples.count)
|
||||
learn_timer.push_units_processed(samples.count)
|
||||
|
|
|
@@ -13,7 +13,7 @@ from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def averaged(kv):
|
||||
def averaged(kv, axis=None):
|
||||
"""Average the value lists of a dictionary.
|
||||
|
||||
For non-scalar values, we simply pick the first value.
|
||||
|
@@ -27,7 +27,7 @@ def averaged(kv):
|
|||
out = {}
|
||||
for k, v in kv.items():
|
||||
if v[0] is not None and not isinstance(v[0], dict):
|
||||
out[k] = np.mean(v)
|
||||
out[k] = np.mean(v, axis=axis)
|
||||
else:
|
||||
out[k] = v[0]
|
||||
return out
|
||||
|
|
Loading…
Add table
Reference in a new issue