Mirror of https://github.com/vale981/ray, synced 2025-03-08 19:41:38 -05:00
[tune] migrate xgboost callback api (#12745)

* Migrate to new-style xgboost callbacks
* Fix flaky progress reporter test
* Fix import error
* Take last value (not first)

Parent: 42c70be073
Commit: 905652cdd6

3 changed files with 46 additions and 16 deletions
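Background: xgboost 1.3 replaced plain-callable callbacks, which received a CallbackEnv record, with subclasses of xgboost.callback.TrainingCallback that implement hooks such as after_iteration(model, epoch, evals_log). The diff below moves Tune's callbacks to the new interface while keeping a __call__ shim for xgboost<1.3.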
@@ -94,6 +94,9 @@ for mod_name in MOCK_MODULES:
 sys.modules["tensorflow"].VERSION = "9.9.9"
 sys.modules["tensorflow.keras.callbacks"] = ChildClassMock()
 sys.modules["pytorch_lightning"] = ChildClassMock()
+sys.modules["xgboost"] = ChildClassMock()
+sys.modules["xgboost.core"] = ChildClassMock()
+sys.modules["xgboost.callback"] = ChildClassMock()


 class SimpleClass(object):
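The three new mocks let the docs build import ray.tune.integration.xgboost, which as of this commit imports xgboost.core and xgboost.callback at module load, without requiring xgboost to be installed.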
@@ -1,14 +1,29 @@
 from typing import Dict, List, Union
+from collections import OrderedDict

 from ray import tune

 import os

+from ray.tune.utils import flatten_dict
+from xgboost.core import Booster

-class TuneCallback:
+try:
+    from xgboost.callback import TrainingCallback
+except ImportError:
+
+    class TrainingCallback:
+        pass
+
+
+class TuneCallback(TrainingCallback):
     """Base class for Tune's XGBoost callbacks."""
-    pass

     def __call__(self, env):
+        """Compatibility with xgboost<1.3"""
+        return self.after_iteration(env.model, env.iteration,
+                                    env.evaluation_result_list)
+
+    def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
         raise NotImplementedError

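For readers coming from one side of the API: xgboost<1.3 invokes callbacks as plain callables that receive a CallbackEnv record, while xgboost>=1.3 calls after_iteration() on TrainingCallback subclasses. A rough sketch of the legacy dispatch that the __call__ shim above bridges (illustrative only; the real CallbackEnv carries more fields than the three used here):

from collections import namedtuple

# Minimal stand-in for xgboost<1.3's CallbackEnv; the real namedtuple
# has additional fields (rank, begin_iteration, ...) not needed here.
CallbackEnv = namedtuple("CallbackEnv",
                         ["model", "iteration", "evaluation_result_list"])


def run_legacy_callbacks(callbacks, model, iteration, eval_results):
    # Old-style xgboost invokes each callback as callback(env); the
    # __call__ shim unpacks env and forwards to after_iteration().
    env = CallbackEnv(
        model=model,
        iteration=iteration,
        evaluation_result_list=eval_results)
    for callback in callbacks:
        callback(env)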
@@ -54,9 +69,15 @@ class TuneReportCallback(TuneCallback):
             metrics = [metrics]
         self._metrics = metrics

-    def _get_report_dict(self, env):
-        # Only one worker should report to Tune
-        result_dict = dict(env.evaluation_result_list)
+    def _get_report_dict(self, evals_log):
+        if isinstance(evals_log, OrderedDict):
+            # xgboost>=1.3
+            result_dict = flatten_dict(evals_log, delimiter="-")
+            for k in list(result_dict):
+                result_dict[k] = result_dict[k][-1]
+        else:
+            # xgboost<1.3
+            result_dict = dict(evals_log)
         if not self._metrics:
             report_dict = result_dict
         else:
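The [-1] indexing is the "take last value (not first)" fix from the commit message: in xgboost>=1.3, evals_log accumulates the full per-round metric history, so the current round's value is the last list element. A small illustration with made-up numbers:

from collections import OrderedDict

from ray.tune.utils import flatten_dict

# Shape of evals_log after three boosting rounds (values are made up).
evals_log = OrderedDict([
    ("eval", OrderedDict([("error", [0.35, 0.31, 0.28])])),
])

result_dict = flatten_dict(evals_log, delimiter="-")
for k in list(result_dict):
    result_dict[k] = result_dict[k][-1]  # latest round, not the first

print(result_dict)  # {'eval-error': 0.28}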
@@ -69,8 +90,9 @@ class TuneReportCallback(TuneCallback):
                 report_dict[key] = result_dict[metric]
         return report_dict

-    def __call__(self, env):
-        report_dict = self._get_report_dict(env)
+    def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
+        report_dict = self._get_report_dict(evals_log)
         tune.report(**report_dict)

@@ -96,14 +118,15 @@ class _TuneCheckpointCallback(TuneCallback):
         self._frequency = frequency

     @staticmethod
-    def _create_checkpoint(env, filename: str, frequency: int):
-        if env.iteration % frequency > 0:
+    def _create_checkpoint(model: Booster, epoch: int, filename: str,
+                           frequency: int):
+        if epoch % frequency > 0:
             return
-        with tune.checkpoint_dir(step=env.iteration) as checkpoint_dir:
-            env.model.save_model(os.path.join(checkpoint_dir, filename))
+        with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
+            model.save_model(os.path.join(checkpoint_dir, filename))

-    def __call__(self, env):
-        self._create_checkpoint(env, self._filename, self._frequency)
+    def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
+        self._create_checkpoint(model, epoch, self._filename, self._frequency)


 class TuneReportCheckpointCallback(TuneCallback):
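A checkpoint produced by _create_checkpoint can later be restored by loading the saved file back into a Booster. A minimal sketch (the directory and filename come from the call above; load_checkpoint itself is a hypothetical helper, not part of this commit):

import os

import xgboost as xgb


def load_checkpoint(checkpoint_dir, filename):
    # Counterpart of _create_checkpoint: read back the file written by
    # model.save_model() into the Tune checkpoint directory.
    booster = xgb.Booster()
    booster.load_model(os.path.join(checkpoint_dir, filename))
    return booster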
@@ -158,6 +181,6 @@ class TuneReportCheckpointCallback(TuneCallback):
         self._checkpoint = self._checkpoint_callback_cls(filename, frequency)
         self._report = self._report_callbacks_cls(metrics)

-    def __call__(self, env):
-        self._checkpoint(env)
-        self._report(env)
+    def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
+        self._checkpoint.after_iteration(model, epoch, evals_log)
+        self._report.after_iteration(model, epoch, evals_log)
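Putting the migrated callbacks to use: a minimal end-to-end sketch (assumptions: xgboost>=1.3 and scikit-learn are installed; the dataset, parameters, and metric names are illustrative and not part of this commit — only TuneReportCheckpointCallback is):

import sklearn.datasets
import sklearn.model_selection
import xgboost as xgb

from ray import tune
from ray.tune.integration.xgboost import TuneReportCheckpointCallback


def train_breast_cancer(config):
    data, labels = sklearn.datasets.load_breast_cancer(return_X_y=True)
    train_x, test_x, train_y, test_y = sklearn.model_selection.train_test_split(
        data, labels, test_size=0.25)
    train_set = xgb.DMatrix(train_x, label=train_y)
    test_set = xgb.DMatrix(test_x, label=test_y)
    # With xgboost>=1.3 the after_iteration() hook fires once per round;
    # with older versions the __call__(env) shim is used instead.
    xgb.train(
        config,
        train_set,
        evals=[(test_set, "eval")],
        callbacks=[TuneReportCheckpointCallback({"mean_error": "eval-error"})])


tune.run(
    train_breast_cancer,
    config={
        "objective": "binary:logistic",
        "eval_metric": ["error"],
        "max_depth": tune.randint(1, 9),
        "eta": tune.loguniform(1e-4, 1e-1),
    },
    metric="mean_error",
    mode="min",
    num_samples=4)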
@@ -180,14 +180,18 @@ VERBOSE_TRIAL_DETAIL = """+-------------------+----------+-------+----------+
 VERBOSE_CMD = """from ray import tune
 import random
 import numpy as np
+import time


 def train(config):
     if config["do"] == "complete":
+        time.sleep(0.1)
         tune.report(acc=5, done=True)
     elif config["do"] == "once":
+        time.sleep(0.5)
         tune.report(6)
     else:
+        time.sleep(1.0)
         tune.report(acc=7)
         tune.report(acc=8)
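The added time.sleep() calls pace the dummy trainable so trials do not all finish before the progress reporter gets a chance to print, which appears to be the "fix flaky progress reporter test" item from the commit message.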