[tune] migrate xgboost callback api (#12745)

* Migrate to new-style xgboost callbacks

* Fix flaky progress reporter test

* Fix import error

* Take last value (not first)
This commit is contained in:
Kai Fricke 2020-12-12 10:42:20 +01:00 committed by GitHub
parent 42c70be073
commit 905652cdd6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 46 additions and 16 deletions

View file

@@ -94,6 +94,9 @@ for mod_name in MOCK_MODULES:
sys.modules["tensorflow"].VERSION = "9.9.9" sys.modules["tensorflow"].VERSION = "9.9.9"
sys.modules["tensorflow.keras.callbacks"] = ChildClassMock() sys.modules["tensorflow.keras.callbacks"] = ChildClassMock()
sys.modules["pytorch_lightning"] = ChildClassMock() sys.modules["pytorch_lightning"] = ChildClassMock()
sys.modules["xgboost"] = ChildClassMock()
sys.modules["xgboost.core"] = ChildClassMock()
sys.modules["xgboost.callback"] = ChildClassMock()
class SimpleClass(object): class SimpleClass(object):

View file

@@ -1,14 +1,29 @@
from typing import Dict, List, Union from typing import Dict, List, Union
from collections import OrderedDict
from ray import tune from ray import tune
import os import os
from ray.tune.utils import flatten_dict
from xgboost.core import Booster
class TuneCallback: try:
from xgboost.callback import TrainingCallback
except ImportError:
class TrainingCallback:
pass
class TuneCallback(TrainingCallback):
"""Base class for Tune's XGBoost callbacks.""" """Base class for Tune's XGBoost callbacks."""
pass
def __call__(self, env): def __call__(self, env):
"""Compatibility with xgboost<1.3"""
return self.after_iteration(env.model, env.iteration,
env.evaluation_result_list)
def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
raise NotImplementedError raise NotImplementedError
@@ -54,9 +69,15 @@ class TuneReportCallback(TuneCallback):
metrics = [metrics] metrics = [metrics]
self._metrics = metrics self._metrics = metrics
def _get_report_dict(self, env): def _get_report_dict(self, evals_log):
# Only one worker should report to Tune if isinstance(evals_log, OrderedDict):
result_dict = dict(env.evaluation_result_list) # xgboost>=1.3
result_dict = flatten_dict(evals_log, delimiter="-")
for k in list(result_dict):
result_dict[k] = result_dict[k][-1]
else:
# xgboost<1.3
result_dict = dict(evals_log)
if not self._metrics: if not self._metrics:
report_dict = result_dict report_dict = result_dict
else: else:
@@ -69,8 +90,9 @@ class TuneReportCallback(TuneCallback):
report_dict[key] = result_dict[metric] report_dict[key] = result_dict[metric]
return report_dict return report_dict
def __call__(self, env): def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
report_dict = self._get_report_dict(env)
report_dict = self._get_report_dict(evals_log)
tune.report(**report_dict) tune.report(**report_dict)
@@ -96,14 +118,15 @@ class _TuneCheckpointCallback(TuneCallback):
self._frequency = frequency self._frequency = frequency
@staticmethod @staticmethod
def _create_checkpoint(env, filename: str, frequency: int): def _create_checkpoint(model: Booster, epoch: int, filename: str,
if env.iteration % frequency > 0: frequency: int):
if epoch % frequency > 0:
return return
with tune.checkpoint_dir(step=env.iteration) as checkpoint_dir: with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
env.model.save_model(os.path.join(checkpoint_dir, filename)) model.save_model(os.path.join(checkpoint_dir, filename))
def __call__(self, env): def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
self._create_checkpoint(env, self._filename, self._frequency) self._create_checkpoint(model, epoch, self._filename, self._frequency)
class TuneReportCheckpointCallback(TuneCallback): class TuneReportCheckpointCallback(TuneCallback):
@@ -158,6 +181,6 @@ class TuneReportCheckpointCallback(TuneCallback):
self._checkpoint = self._checkpoint_callback_cls(filename, frequency) self._checkpoint = self._checkpoint_callback_cls(filename, frequency)
self._report = self._report_callbacks_cls(metrics) self._report = self._report_callbacks_cls(metrics)
def __call__(self, env): def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
self._checkpoint(env) self._checkpoint.after_iteration(model, epoch, evals_log)
self._report(env) self._report.after_iteration(model, epoch, evals_log)

View file

@@ -180,14 +180,18 @@ VERBOSE_TRIAL_DETAIL = """+-------------------+----------+-------+----------+
VERBOSE_CMD = """from ray import tune VERBOSE_CMD = """from ray import tune
import random import random
import numpy as np import numpy as np
import time
def train(config): def train(config):
if config["do"] == "complete": if config["do"] == "complete":
time.sleep(0.1)
tune.report(acc=5, done=True) tune.report(acc=5, done=True)
elif config["do"] == "once": elif config["do"] == "once":
time.sleep(0.5)
tune.report(6) tune.report(6)
else: else:
time.sleep(1.0)
tune.report(acc=7) tune.report(acc=7)
tune.report(acc=8) tune.report(acc=8)