[tune] migrate xgboost callback api (#12745)

* Migrate to new-style xgboost callbacks

* Fix flaky progress reporter test

* Fix import error

* Take last value (not first)
This commit is contained in:
Kai Fricke 2020-12-12 10:42:20 +01:00 committed by GitHub
parent 42c70be073
commit 905652cdd6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 46 additions and 16 deletions

View file

@@ -94,6 +94,9 @@ for mod_name in MOCK_MODULES:
# Stub out heavy optional dependencies so the docs build can import Ray
# modules without having them installed.
sys.modules["tensorflow"].VERSION = "9.9.9"
sys.modules["tensorflow.keras.callbacks"] = ChildClassMock()
sys.modules["pytorch_lightning"] = ChildClassMock()
# xgboost submodules are mocked individually because the Tune integration
# imports from xgboost.core and xgboost.callback directly.
sys.modules["xgboost"] = ChildClassMock()
sys.modules["xgboost.core"] = ChildClassMock()
sys.modules["xgboost.callback"] = ChildClassMock()
class SimpleClass(object):

View file

@@ -1,14 +1,29 @@
from typing import Dict, List, Union
from collections import OrderedDict
from ray import tune
import os
from ray.tune.utils import flatten_dict
from xgboost.core import Booster
# xgboost < 1.3 has no new-style callback base class; import it when
# available and fall back to a no-op stub so the subclasses below still
# have a valid base. (The stray pre-change `class TuneCallback:` line left
# over from the diff render is dropped — the class is defined below.)
try:
    from xgboost.callback import TrainingCallback
except ImportError:

    class TrainingCallback:
        pass
class TuneCallback(TrainingCallback):
    """Base class for Tune's XGBoost callbacks."""

    def __call__(self, env):
        """Legacy entry point (xgboost < 1.3): forward to ``after_iteration``."""
        return self.after_iteration(env.model, env.iteration,
                                    env.evaluation_result_list)

    def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
        """New-style per-iteration hook (xgboost >= 1.3); subclasses override."""
        raise NotImplementedError
@@ -54,9 +69,15 @@ class TuneReportCallback(TuneCallback):
metrics = [metrics]
self._metrics = metrics
def _get_report_dict(self, env):
# Only one worker should report to Tune
result_dict = dict(env.evaluation_result_list)
def _get_report_dict(self, evals_log):
if isinstance(evals_log, OrderedDict):
# xgboost>=1.3
result_dict = flatten_dict(evals_log, delimiter="-")
for k in list(result_dict):
result_dict[k] = result_dict[k][-1]
else:
# xgboost<1.3
result_dict = dict(evals_log)
if not self._metrics:
report_dict = result_dict
else:
@@ -69,8 +90,9 @@ class TuneReportCallback(TuneCallback):
report_dict[key] = result_dict[metric]
return report_dict
def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
    """Build the metric dict for this iteration and report it to Tune.

    Post-change version only; the stale pre-change ``__call__`` lines left
    interleaved by the diff render are removed (the base class now adapts
    the legacy hook to ``after_iteration``).
    """
    report_dict = self._get_report_dict(evals_log)
    tune.report(**report_dict)
@@ -96,14 +118,15 @@ class _TuneCheckpointCallback(TuneCallback):
self._frequency = frequency
@staticmethod
def _create_checkpoint(model: Booster, epoch: int, filename: str,
                       frequency: int):
    """Save ``model`` into a Tune checkpoint directory every ``frequency`` epochs.

    Post-change version only; the stale pre-change ``env``-based lines left
    interleaved by the diff render are removed.
    """
    # Only checkpoint on epochs that fall on the frequency grid.
    if epoch % frequency > 0:
        return
    with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
        model.save_model(os.path.join(checkpoint_dir, filename))
def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
    """Create a checkpoint for this iteration if one is due.

    Post-change version only; the stale pre-change ``__call__`` lines left
    interleaved by the diff render are removed.
    """
    self._create_checkpoint(model, epoch, self._filename, self._frequency)
class TuneReportCheckpointCallback(TuneCallback):
@@ -158,6 +181,6 @@ class TuneReportCheckpointCallback(TuneCallback):
self._checkpoint = self._checkpoint_callback_cls(filename, frequency)
self._report = self._report_callbacks_cls(metrics)
def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
    """Delegate the hook: checkpoint first, then report metrics to Tune.

    Post-change version only; the stale pre-change ``__call__`` lines left
    interleaved by the diff render are removed.
    """
    # Checkpoint before reporting so the reported step has a checkpoint.
    self._checkpoint.after_iteration(model, epoch, evals_log)
    self._report.after_iteration(model, epoch, evals_log)

View file

@@ -180,14 +180,18 @@ VERBOSE_TRIAL_DETAIL = """+-------------------+----------+-------+----------+
VERBOSE_CMD = """from ray import tune
import random
import numpy as np
import time
def train(config):
if config["do"] == "complete":
time.sleep(0.1)
tune.report(acc=5, done=True)
elif config["do"] == "once":
time.sleep(0.5)
tune.report(6)
else:
time.sleep(1.0)
tune.report(acc=7)
tune.report(acc=8)