mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[RLlib] move evaluation to trainer.step() such that the result is properly logged (#12708)
This commit is contained in:
parent
964689b280
commit
b4702de1c2
2 changed files with 12 additions and 8 deletions
|
@ -535,14 +535,6 @@ class Trainer(Trainable):
|
|||
if hasattr(self, "workers") and isinstance(self.workers, WorkerSet):
|
||||
self._sync_filters_if_needed(self.workers)
|
||||
|
||||
if self.config["evaluation_interval"] == 1 or (
|
||||
self._iteration > 0 and self.config["evaluation_interval"]
|
||||
and self._iteration % self.config["evaluation_interval"] == 0):
|
||||
evaluation_metrics = self._evaluate()
|
||||
assert isinstance(evaluation_metrics, dict), \
|
||||
"_evaluate() needs to return a dict."
|
||||
result.update(evaluation_metrics)
|
||||
|
||||
return result
|
||||
|
||||
def _sync_filters_if_needed(self, workers: WorkerSet):
|
||||
|
|
|
@ -146,6 +146,18 @@ def build_trainer(
|
|||
@override(Trainer)
|
||||
def step(self):
|
||||
res = next(self.train_exec_impl)
|
||||
|
||||
# self._iteration gets incremented after this function returns,
|
||||
# meaning that e. g. the first time this function is called,
|
||||
# self._iteration will be 0. We check `self._iteration+1` in the
|
||||
# if-statement below to reflect that the first training iteration
|
||||
# is already over.
|
||||
if (self.config["evaluation_interval"] and (self._iteration + 1) %
|
||||
self.config["evaluation_interval"] == 0):
|
||||
evaluation_metrics = self._evaluate()
|
||||
assert isinstance(evaluation_metrics, dict), \
|
||||
"_evaluate() needs to return a dict."
|
||||
res.update(evaluation_metrics)
|
||||
return res
|
||||
|
||||
@override(Trainer)
|
||||
|
|
Loading…
Add table
Reference in a new issue