mirror of
https://github.com/vale981/ray
synced 2025-03-09 21:06:39 -04:00
27 lines
642 B
Python
27 lines
642 B
Python
![]() |
"""Distributed XGBoost API test
|
||
|
|
||
|
This test runs unit tests on a distributed cluster. This will confirm that
|
||
|
XGBoost API features like custom metrics/objectives work with remote
|
||
|
trainables.
|
||
|
|
||
|
Test owner: krfricke
|
||
|
|
||
|
Acceptance criteria: Unit tests should pass (requires pytest).
|
||
|
"""
|
||
|
|
||
|
import ray
|
||
|
|
||
|
from xgboost_ray.tests.test_xgboost_api import XGBoostAPITest
|
||
|
|
||
|
|
||
|
class XGBoostDistributedAPITest(XGBoostAPITest):
    """Run the ``XGBoostAPITest`` unit tests against an existing Ray cluster.

    All test methods are inherited unchanged from ``XGBoostAPITest``. Only
    ``_init_ray`` is overridden, so that Ray attaches to a running cluster
    via ``address="auto"``.
    """

    def _init_ray(self):
        """Attach to the running Ray cluster unless Ray is already initialized.

        Repeated calls are no-ops once Ray is up, because the guard below
        skips ``ray.init`` in that case.
        """
        # Ray is already connected (e.g. by an earlier test), so nothing to do.
        if ray.is_initialized():
            return
        # "auto" makes Ray look up an already-running cluster rather than
        # taking an explicit head-node address.
        ray.init(address="auto")
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
|
||
|
import pytest
|
||
|
import sys
|
||
|
sys.exit(pytest.main(["-v", f"{__file__}::XGBoostDistributedAPITest"]))
|