mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[tune] migrate xgboost callback api (#12745)
* Migrate to new-style xgboost callbacks * Fix flaky progress reporter test * Fix import error * Take last value (not first)
This commit is contained in:
parent
42c70be073
commit
905652cdd6
3 changed files with 46 additions and 16 deletions
|
@ -94,6 +94,9 @@ for mod_name in MOCK_MODULES:
|
|||
sys.modules["tensorflow"].VERSION = "9.9.9"
|
||||
sys.modules["tensorflow.keras.callbacks"] = ChildClassMock()
|
||||
sys.modules["pytorch_lightning"] = ChildClassMock()
|
||||
sys.modules["xgboost"] = ChildClassMock()
|
||||
sys.modules["xgboost.core"] = ChildClassMock()
|
||||
sys.modules["xgboost.callback"] = ChildClassMock()
|
||||
|
||||
|
||||
class SimpleClass(object):
|
||||
|
|
|
@ -1,14 +1,29 @@
|
|||
from typing import Dict, List, Union
|
||||
from collections import OrderedDict
|
||||
from ray import tune
|
||||
|
||||
import os
|
||||
|
||||
from ray.tune.utils import flatten_dict
|
||||
from xgboost.core import Booster
|
||||
|
||||
class TuneCallback:
|
||||
try:
|
||||
from xgboost.callback import TrainingCallback
|
||||
except ImportError:
|
||||
|
||||
class TrainingCallback:
|
||||
pass
|
||||
|
||||
|
||||
class TuneCallback(TrainingCallback):
|
||||
"""Base class for Tune's XGBoost callbacks."""
|
||||
pass
|
||||
|
||||
def __call__(self, env):
|
||||
"""Compatibility with xgboost<1.3"""
|
||||
return self.after_iteration(env.model, env.iteration,
|
||||
env.evaluation_result_list)
|
||||
|
||||
def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
@ -54,9 +69,15 @@ class TuneReportCallback(TuneCallback):
|
|||
metrics = [metrics]
|
||||
self._metrics = metrics
|
||||
|
||||
def _get_report_dict(self, env):
|
||||
# Only one worker should report to Tune
|
||||
result_dict = dict(env.evaluation_result_list)
|
||||
def _get_report_dict(self, evals_log):
|
||||
if isinstance(evals_log, OrderedDict):
|
||||
# xgboost>=1.3
|
||||
result_dict = flatten_dict(evals_log, delimiter="-")
|
||||
for k in list(result_dict):
|
||||
result_dict[k] = result_dict[k][-1]
|
||||
else:
|
||||
# xgboost<1.3
|
||||
result_dict = dict(evals_log)
|
||||
if not self._metrics:
|
||||
report_dict = result_dict
|
||||
else:
|
||||
|
@ -69,8 +90,9 @@ class TuneReportCallback(TuneCallback):
|
|||
report_dict[key] = result_dict[metric]
|
||||
return report_dict
|
||||
|
||||
def __call__(self, env):
|
||||
report_dict = self._get_report_dict(env)
|
||||
def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
|
||||
|
||||
report_dict = self._get_report_dict(evals_log)
|
||||
tune.report(**report_dict)
|
||||
|
||||
|
||||
|
@ -96,14 +118,15 @@ class _TuneCheckpointCallback(TuneCallback):
|
|||
self._frequency = frequency
|
||||
|
||||
@staticmethod
|
||||
def _create_checkpoint(env, filename: str, frequency: int):
|
||||
if env.iteration % frequency > 0:
|
||||
def _create_checkpoint(model: Booster, epoch: int, filename: str,
|
||||
frequency: int):
|
||||
if epoch % frequency > 0:
|
||||
return
|
||||
with tune.checkpoint_dir(step=env.iteration) as checkpoint_dir:
|
||||
env.model.save_model(os.path.join(checkpoint_dir, filename))
|
||||
with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
|
||||
model.save_model(os.path.join(checkpoint_dir, filename))
|
||||
|
||||
def __call__(self, env):
|
||||
self._create_checkpoint(env, self._filename, self._frequency)
|
||||
def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
|
||||
self._create_checkpoint(model, epoch, self._filename, self._frequency)
|
||||
|
||||
|
||||
class TuneReportCheckpointCallback(TuneCallback):
|
||||
|
@ -158,6 +181,6 @@ class TuneReportCheckpointCallback(TuneCallback):
|
|||
self._checkpoint = self._checkpoint_callback_cls(filename, frequency)
|
||||
self._report = self._report_callbacks_cls(metrics)
|
||||
|
||||
def __call__(self, env):
|
||||
self._checkpoint(env)
|
||||
self._report(env)
|
||||
def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
|
||||
self._checkpoint.after_iteration(model, epoch, evals_log)
|
||||
self._report.after_iteration(model, epoch, evals_log)
|
||||
|
|
|
@ -180,14 +180,18 @@ VERBOSE_TRIAL_DETAIL = """+-------------------+----------+-------+----------+
|
|||
VERBOSE_CMD = """from ray import tune
|
||||
import random
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
|
||||
def train(config):
|
||||
if config["do"] == "complete":
|
||||
time.sleep(0.1)
|
||||
tune.report(acc=5, done=True)
|
||||
elif config["do"] == "once":
|
||||
time.sleep(0.5)
|
||||
tune.report(6)
|
||||
else:
|
||||
time.sleep(1.0)
|
||||
tune.report(acc=7)
|
||||
tune.report(acc=8)
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue