mirror of
https://github.com/vale981/ray
synced 2025-03-09 04:46:38 -04:00
[rllib] on_train_result results do not get logged (#3865)
This commit is contained in:
parent
e0f82fd260
commit
0f81bc9a33
2 changed files with 17 additions and 3 deletions
|
@ -292,13 +292,18 @@ class Agent(Trainable):
|
||||||
logger.debug("synchronized filters: {}".format(
|
logger.debug("synchronized filters: {}".format(
|
||||||
self.local_evaluator.filters))
|
self.local_evaluator.filters))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
@override(Trainable)
|
||||||
|
def _log_result(self, result):
|
||||||
if self.config["callbacks"].get("on_train_result"):
|
if self.config["callbacks"].get("on_train_result"):
|
||||||
self.config["callbacks"]["on_train_result"]({
|
self.config["callbacks"]["on_train_result"]({
|
||||||
"agent": self,
|
"agent": self,
|
||||||
"result": result,
|
"result": result,
|
||||||
})
|
})
|
||||||
|
# log after the callback is invoked, so that the user has a chance
|
||||||
return result
|
# to mutate the result
|
||||||
|
Trainable._log_result(self, result)
|
||||||
|
|
||||||
@override(Trainable)
|
@override(Trainable)
|
||||||
def _setup(self, config):
|
def _setup(self, config):
|
||||||
|
|
|
@ -202,7 +202,7 @@ class Trainable(object):
|
||||||
timesteps_since_restore=self._timesteps_since_restore,
|
timesteps_since_restore=self._timesteps_since_restore,
|
||||||
iterations_since_restore=self._iterations_since_restore)
|
iterations_since_restore=self._iterations_since_restore)
|
||||||
|
|
||||||
self._result_logger.on_result(result)
|
self._log_result(result)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -415,6 +415,15 @@ class Trainable(object):
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def _log_result(self, result):
|
||||||
|
"""Subclasses can optionally override this to customize logging.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result (dict): Training result returned by _train().
|
||||||
|
"""
|
||||||
|
|
||||||
|
self._result_logger.on_result(result)
|
||||||
|
|
||||||
def _stop(self):
|
def _stop(self):
|
||||||
"""Subclasses should override this for any cleanup on stop."""
|
"""Subclasses should override this for any cleanup on stop."""
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Add table
Reference in a new issue