diff --git a/doc/source/conf.py b/doc/source/conf.py
index 5df6e8be1..c69f73760 100644
--- a/doc/source/conf.py
+++ b/doc/source/conf.py
@@ -94,6 +94,9 @@ for mod_name in MOCK_MODULES:
 sys.modules["tensorflow"].VERSION = "9.9.9"
 sys.modules["tensorflow.keras.callbacks"] = ChildClassMock()
 sys.modules["pytorch_lightning"] = ChildClassMock()
+sys.modules["xgboost"] = ChildClassMock()
+sys.modules["xgboost.core"] = ChildClassMock()
+sys.modules["xgboost.callback"] = ChildClassMock()
 
 
 class SimpleClass(object):
diff --git a/python/ray/tune/integration/xgboost.py b/python/ray/tune/integration/xgboost.py
index fa40fa60a..c3d6c02e9 100644
--- a/python/ray/tune/integration/xgboost.py
+++ b/python/ray/tune/integration/xgboost.py
@@ -1,14 +1,29 @@
 from typing import Dict, List, Union
+from collections import OrderedDict
 from ray import tune
 
 import os
+from ray.tune.utils import flatten_dict
+from xgboost.core import Booster
 
 
-class TuneCallback:
+try:
+    from xgboost.callback import TrainingCallback
+except ImportError:
+
+    class TrainingCallback:
+        pass
+
+
+class TuneCallback(TrainingCallback):
     """Base class for Tune's XGBoost callbacks."""
-    pass
 
     def __call__(self, env):
+        """Compatibility with xgboost<1.3"""
+        return self.after_iteration(env.model, env.iteration,
+                                    env.evaluation_result_list)
+
+    def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
         raise NotImplementedError
 
 
@@ -54,9 +69,15 @@ class TuneReportCallback(TuneCallback):
             metrics = [metrics]
         self._metrics = metrics
 
-    def _get_report_dict(self, env):
-        # Only one worker should report to Tune
-        result_dict = dict(env.evaluation_result_list)
+    def _get_report_dict(self, evals_log):
+        if isinstance(evals_log, OrderedDict):
+            # xgboost>=1.3
+            result_dict = flatten_dict(evals_log, delimiter="-")
+            for k in list(result_dict):
+                result_dict[k] = result_dict[k][-1]
+        else:
+            # xgboost<1.3
+            result_dict = dict(evals_log)
         if not self._metrics:
             report_dict = result_dict
         else:
@@ -69,8 +90,9 @@ class TuneReportCallback(TuneCallback):
                 report_dict[key] = result_dict[metric]
         return report_dict
 
-    def __call__(self, env):
-        report_dict = self._get_report_dict(env)
+    def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
+
+        report_dict = self._get_report_dict(evals_log)
         tune.report(**report_dict)
 
 
@@ -96,14 +118,15 @@ class _TuneCheckpointCallback(TuneCallback):
         self._frequency = frequency
 
     @staticmethod
-    def _create_checkpoint(env, filename: str, frequency: int):
-        if env.iteration % frequency > 0:
+    def _create_checkpoint(model: Booster, epoch: int, filename: str,
+                           frequency: int):
+        if epoch % frequency > 0:
             return
-        with tune.checkpoint_dir(step=env.iteration) as checkpoint_dir:
-            env.model.save_model(os.path.join(checkpoint_dir, filename))
+        with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
+            model.save_model(os.path.join(checkpoint_dir, filename))
 
-    def __call__(self, env):
-        self._create_checkpoint(env, self._filename, self._frequency)
+    def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
+        self._create_checkpoint(model, epoch, self._filename, self._frequency)
 
 
 class TuneReportCheckpointCallback(TuneCallback):
@@ -158,6 +181,6 @@ class TuneReportCheckpointCallback(TuneCallback):
         self._checkpoint = self._checkpoint_callback_cls(filename, frequency)
         self._report = self._report_callbacks_cls(metrics)
 
-    def __call__(self, env):
-        self._checkpoint(env)
-        self._report(env)
+    def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
+        self._checkpoint.after_iteration(model, epoch, evals_log)
+        self._report.after_iteration(model, epoch, evals_log)
diff --git a/python/ray/tune/tests/test_progress_reporter.py b/python/ray/tune/tests/test_progress_reporter.py
index 0df1d89dd..3a65799fc 100644
--- a/python/ray/tune/tests/test_progress_reporter.py
+++ b/python/ray/tune/tests/test_progress_reporter.py
@@ -180,14 +180,18 @@ VERBOSE_TRIAL_DETAIL = """+-------------------+----------+-------+----------+
 VERBOSE_CMD = """from ray import tune
 import random
 import numpy as np
+import time
 
 
 def train(config):
     if config["do"] == "complete":
+        time.sleep(0.1)
         tune.report(acc=5, done=True)
     elif config["do"] == "once":
+        time.sleep(0.5)
         tune.report(6)
     else:
+        time.sleep(1.0)
         tune.report(acc=7)
         tune.report(acc=8)
 