diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py
index 45e2c1539..e1d2ae8ae 100644
--- a/python/ray/rllib/agents/agent.py
+++ b/python/ray/rllib/agents/agent.py
@@ -292,13 +292,18 @@ class Agent(Trainable):
             logger.debug("synchronized filters: {}".format(
                 self.local_evaluator.filters))
 
+        return result
+
+    @override(Trainable)
+    def _log_result(self, result):
         if self.config["callbacks"].get("on_train_result"):
             self.config["callbacks"]["on_train_result"]({
                 "agent": self,
                 "result": result,
             })
-
-        return result
+        # log after the callback is invoked, so that the user has a chance
+        # to mutate the result
+        Trainable._log_result(self, result)
 
     @override(Trainable)
     def _setup(self, config):
diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py
index 581085fd1..7bcf76858 100644
--- a/python/ray/tune/trainable.py
+++ b/python/ray/tune/trainable.py
@@ -202,7 +202,7 @@ class Trainable(object):
             timesteps_since_restore=self._timesteps_since_restore,
             iterations_since_restore=self._iterations_since_restore)
 
-        self._result_logger.on_result(result)
+        self._log_result(result)
 
         return result
 
@@ -415,6 +415,15 @@ class Trainable(object):
         """
         pass
 
+    def _log_result(self, result):
+        """Subclasses can optionally override this to customize logging.
+
+        Args:
+            result (dict): Training result returned by _train().
+        """
+
+        self._result_logger.on_result(result)
+
     def _stop(self):
         """Subclasses should override this for any cleanup on stop."""
         pass
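
For context, the new _log_result hook lets a Trainable subclass customize or mutate a result before it reaches the result logger, which is the same ordering the Agent override above relies on for the on_train_result callback. Below is a minimal sketch of overriding the hook; the MyTrainable class and the custom_metric key are hypothetical and only for illustration, not part of this change:

    from ray.tune.trainable import Trainable


    class MyTrainable(Trainable):  # hypothetical subclass, for illustration only
        def _train(self):
            # Run one iteration of training and return a result dict.
            return {"episode_reward_mean": 0.0}

        def _log_result(self, result):
            # Mutate the result before it is logged.
            result["custom_metric"] = 1.0  # hypothetical key
            # Delegate to the base implementation, which forwards the
            # (possibly mutated) result to self._result_logger.on_result().
            Trainable._log_result(self, result)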