mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
parent
4633d81c39
commit
c9435cad43
3 changed files with 7 additions and 6 deletions
|
@ -18,7 +18,7 @@ from ray.tune.utils import flatten_dict
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
tf = None
|
tf = None
|
||||||
VALID_SUMMARY_TYPES = [int, float, np.float32, np.float64, np.int32]
|
VALID_SUMMARY_TYPES = [int, float, np.float32, np.float64, np.int32, np.int64]
|
||||||
|
|
||||||
|
|
||||||
class Logger:
|
class Logger:
|
||||||
|
|
|
@ -206,9 +206,10 @@ class TrainTFMultiGPU:
|
||||||
self.per_device_batch_size)
|
self.per_device_batch_size)
|
||||||
for k, v in batch_fetches[LEARNER_STATS_KEY].items():
|
for k, v in batch_fetches[LEARNER_STATS_KEY].items():
|
||||||
iter_extra_fetches[k].append(v)
|
iter_extra_fetches[k].append(v)
|
||||||
logger.debug("{} {}".format(i,
|
if logger.getEffectiveLevel() <= logging.DEBUG:
|
||||||
averaged(iter_extra_fetches)))
|
avg = averaged(iter_extra_fetches)
|
||||||
fetches[policy_id] = averaged(iter_extra_fetches)
|
logger.debug("{} {}".format(i, avg))
|
||||||
|
fetches[policy_id] = averaged(iter_extra_fetches, axis=0)
|
||||||
|
|
||||||
load_timer.push_units_processed(samples.count)
|
load_timer.push_units_processed(samples.count)
|
||||||
learn_timer.push_units_processed(samples.count)
|
learn_timer.push_units_processed(samples.count)
|
||||||
|
|
|
@ -13,7 +13,7 @@ from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def averaged(kv):
|
def averaged(kv, axis=None):
|
||||||
"""Average the value lists of a dictionary.
|
"""Average the value lists of a dictionary.
|
||||||
|
|
||||||
For non-scalar values, we simply pick the first value.
|
For non-scalar values, we simply pick the first value.
|
||||||
|
@ -27,7 +27,7 @@ def averaged(kv):
|
||||||
out = {}
|
out = {}
|
||||||
for k, v in kv.items():
|
for k, v in kv.items():
|
||||||
if v[0] is not None and not isinstance(v[0], dict):
|
if v[0] is not None and not isinstance(v[0], dict):
|
||||||
out[k] = np.mean(v)
|
out[k] = np.mean(v, axis=axis)
|
||||||
else:
|
else:
|
||||||
out[k] = v[0]
|
out[k] = v[0]
|
||||||
return out
|
return out
|
||||||
|
|
Loading…
Add table
Reference in a new issue