[RLlib] move evaluation to trainer.step() such that the result is properly logged (#12708)

Maltimore authored 2021-01-25 12:56:00 +01:00 (committed by GitHub)
parent 964689b280
commit b4702de1c2
2 changed files with 12 additions and 8 deletions

@@ -535,14 +535,6 @@ class Trainer(Trainable):
         if hasattr(self, "workers") and isinstance(self.workers, WorkerSet):
             self._sync_filters_if_needed(self.workers)
 
-        if self.config["evaluation_interval"] == 1 or (
-                self._iteration > 0 and self.config["evaluation_interval"]
-                and self._iteration % self.config["evaluation_interval"] == 0):
-            evaluation_metrics = self._evaluate()
-            assert isinstance(evaluation_metrics, dict), \
-                "_evaluate() needs to return a dict."
-            result.update(evaluation_metrics)
-
         return result
 
     def _sync_filters_if_needed(self, workers: WorkerSet):
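
The block removed above is re-added inside step() in the hunk below, so the evaluation metrics land in the result dict that actually gets logged. A minimal plain-Python sketch (not from the commit) of the off-by-one interval check used there:

# Sketch of the interval check in the new step(): _iteration is still 0 the
# first time step() runs, hence the "+ 1".
evaluation_interval = 2
for _iteration in range(6):
    run_eval = bool(evaluation_interval) and \
        (_iteration + 1) % evaluation_interval == 0
    print(_iteration, run_eval)  # 0 False, 1 True, 2 False, 3 True, ...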

@@ -146,6 +146,18 @@ def build_trainer(
         @override(Trainer)
         def step(self):
             res = next(self.train_exec_impl)
+
+            # self._iteration gets incremented after this function returns,
+            # meaning that e.g. the first time this function is called,
+            # self._iteration will be 0. We check `self._iteration + 1` in the
+            # if-statement below to reflect that the first training iteration
+            # is already over.
+            if (self.config["evaluation_interval"] and (self._iteration + 1) %
+                    self.config["evaluation_interval"] == 0):
+                evaluation_metrics = self._evaluate()
+                assert isinstance(evaluation_metrics, dict), \
+                    "_evaluate() needs to return a dict."
+                res.update(evaluation_metrics)
             return res
 
         @override(Trainer)
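
A minimal usage sketch of what the change enables, assuming a 2021-era RLlib install (ray.rllib.agents.ppo.PPOTrainer, gym's CartPole-v0) and that _evaluate() nests its metrics under an "evaluation" key; the exact config and result keys are assumptions, not taken from the commit. Because evaluation now runs inside step(), its metrics are part of the same result dict that Trainable logs:

import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()

trainer = PPOTrainer(
    env="CartPole-v0",
    config={
        "num_workers": 1,
        "evaluation_interval": 2,      # evaluate every 2nd training iteration
        "evaluation_num_episodes": 5,  # episodes per evaluation run
    })

for i in range(4):
    result = trainer.train()
    # Evaluation metrics (assumed to sit under result["evaluation"]) now show
    # up in the logged result on every 2nd iteration, i.e. when (i + 1) % 2 == 0.
    print(i, "evaluation" in result)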