mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] move evaluation to trainer.step() such that the result is properly logged (#12708)
This commit is contained in:
parent
964689b280
commit
b4702de1c2
2 changed files with 12 additions and 8 deletions
|
@@ -535,14 +535,6 @@ class Trainer(Trainable):
|
||||||
if hasattr(self, "workers") and isinstance(self.workers, WorkerSet):
|
if hasattr(self, "workers") and isinstance(self.workers, WorkerSet):
|
||||||
self._sync_filters_if_needed(self.workers)
|
self._sync_filters_if_needed(self.workers)
|
||||||
|
|
||||||
if self.config["evaluation_interval"] == 1 or (
|
|
||||||
self._iteration > 0 and self.config["evaluation_interval"]
|
|
||||||
and self._iteration % self.config["evaluation_interval"] == 0):
|
|
||||||
evaluation_metrics = self._evaluate()
|
|
||||||
assert isinstance(evaluation_metrics, dict), \
|
|
||||||
"_evaluate() needs to return a dict."
|
|
||||||
result.update(evaluation_metrics)
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _sync_filters_if_needed(self, workers: WorkerSet):
|
def _sync_filters_if_needed(self, workers: WorkerSet):
|
||||||
|
|
|
@@ -146,6 +146,18 @@ def build_trainer(
|
||||||
@override(Trainer)
|
@override(Trainer)
|
||||||
def step(self):
|
def step(self):
|
||||||
res = next(self.train_exec_impl)
|
res = next(self.train_exec_impl)
|
||||||
|
|
||||||
|
# self._iteration gets incremented after this function returns,
|
||||||
|
# meaning that, e.g., the first time this function is called,
|
||||||
|
# self._iteration will be 0. We check `self._iteration+1` in the
|
||||||
|
# if-statement below to reflect that the first training iteration
|
||||||
|
# is already over.
|
||||||
|
if (self.config["evaluation_interval"] and (self._iteration + 1) %
|
||||||
|
self.config["evaluation_interval"] == 0):
|
||||||
|
evaluation_metrics = self._evaluate()
|
||||||
|
assert isinstance(evaluation_metrics, dict), \
|
||||||
|
"_evaluate() needs to return a dict."
|
||||||
|
res.update(evaluation_metrics)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@override(Trainer)
|
@override(Trainer)
|
||||||
|
|
Loading…
Add table
Reference in a new issue