mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[Test] Use Ray client in XGBoost train_small release test (#16319)
This commit is contained in:
parent
0f87eca3e9
commit
2f7d535253
3 changed files with 22 additions and 12 deletions
|
@ -1,4 +1,4 @@
|
|||
base_image: "anyscale/ray:1.2.0"
|
||||
base_image: "anyscale/ray-ml:pinned-nightly-py37"
|
||||
env_vars: {}
|
||||
debian_packages:
|
||||
- curl
|
||||
|
|
|
@ -16,7 +16,12 @@ from xgboost_ray import RayParams
|
|||
from ray.util.xgboost.release_test_util import train_ray
|
||||
|
||||
if __name__ == "__main__":
|
||||
ray.init(address="auto")
|
||||
addr = os.environ.get("RAY_ADDRESS")
|
||||
job_name = os.environ.get("RAY_JOB_NAME", "train_small")
|
||||
if addr.startswith("anyscale://"):
|
||||
ray.client(address=addr).job_name(job_name).connect()
|
||||
else:
|
||||
ray.init(address="auto")
|
||||
|
||||
ray_params = RayParams(
|
||||
elastic_training=False,
|
||||
|
@ -25,17 +30,21 @@ if __name__ == "__main__":
|
|||
cpus_per_actor=4,
|
||||
gpus_per_actor=0)
|
||||
|
||||
@ray.remote
|
||||
def train():
|
||||
train_ray(
|
||||
path="/data/classification.parquet",
|
||||
num_workers=4,
|
||||
num_boost_rounds=100,
|
||||
num_files=25,
|
||||
regression=False,
|
||||
use_gpu=False,
|
||||
ray_params=ray_params,
|
||||
xgboost_params=None,
|
||||
)
|
||||
|
||||
start = time.time()
|
||||
train_ray(
|
||||
path="/data/classification.parquet",
|
||||
num_workers=4,
|
||||
num_boost_rounds=100,
|
||||
num_files=25,
|
||||
regression=False,
|
||||
use_gpu=False,
|
||||
ray_params=ray_params,
|
||||
xgboost_params=None,
|
||||
)
|
||||
ray.get(train.remote())
|
||||
taken = time.time() - start
|
||||
|
||||
result = {
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
compute_template: tpl_cpu_small.yaml
|
||||
|
||||
run:
|
||||
use_connect: True
|
||||
timeout: 600
|
||||
prepare: python wait_cluster.py 4 600
|
||||
script: python workloads/train_small.py
|
||||
|
|
Loading…
Add table
Reference in a new issue