From 0f81bc9a33d02085311b5b04f07cf91323a8c362 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Fri, 1 Feb 2019 20:32:07 -0800 Subject: [PATCH] [rllib] on_train_result results do not get logged (#3865) --- python/ray/rllib/agents/agent.py | 9 +++++++-- python/ray/tune/trainable.py | 11 ++++++++++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index 45e2c1539..e1d2ae8ae 100644 --- a/python/ray/rllib/agents/agent.py +++ b/python/ray/rllib/agents/agent.py @@ -292,13 +292,18 @@ class Agent(Trainable): logger.debug("synchronized filters: {}".format( self.local_evaluator.filters)) + return result + + @override(Trainable) + def _log_result(self, result): if self.config["callbacks"].get("on_train_result"): self.config["callbacks"]["on_train_result"]({ "agent": self, "result": result, }) - - return result + # log after the callback is invoked, so that the user has a chance + # to mutate the result + Trainable._log_result(self, result) @override(Trainable) def _setup(self, config): diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 581085fd1..7bcf76858 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -202,7 +202,7 @@ class Trainable(object): timesteps_since_restore=self._timesteps_since_restore, iterations_since_restore=self._iterations_since_restore) - self._result_logger.on_result(result) + self._log_result(result) return result @@ -415,6 +415,15 @@ class Trainable(object): """ pass + def _log_result(self, result): + """Subclasses can optionally override this to customize logging. + + Args: + result (dict): Training result returned by _train(). + """ + + self._result_logger.on_result(result) + def _stop(self): """Subclasses should override this for any cleanup on stop.""" pass