mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
Roll forward to run train_small in client mode. (#16610)
This commit is contained in:
parent
c95dea51e9
commit
48599aef9e
3 changed files with 27 additions and 12 deletions
|
@@ -8,6 +8,7 @@ python:
  - pytest
  - xgboost_ray
  - petastorm
  - tblib
conda_packages: []

post_build_cmds:
|
|
@@ -16,8 +16,15 @@ from xgboost_ray import RayParams

from ray.util.xgboost.release_test_util import train_ray


if __name__ == "__main__":
    addr = os.environ.get("RAY_ADDRESS")
    job_name = os.environ.get("RAY_JOB_NAME", "train_small")
    if addr.startswith("anyscale://"):
        ray.client(address=addr).job_name(job_name).connect()
    else:
        ray.init(address="auto")

    output = os.environ["TEST_OUTPUT_JSON"]
    state = os.environ["TEST_STATE_JSON"]
    ray_params = RayParams(
        elastic_training=False,
        max_actor_restarts=2,
|
@@ -26,6 +33,11 @@ if __name__ == "__main__":
        gpus_per_actor=0)

    start = time.time()

    @ray.remote
    def train():
        os.environ["TEST_OUTPUT_JSON"] = output
        os.environ["TEST_STATE_JSON"] = state
        train_ray(
            path="/data/classification.parquet",
            num_workers=4,
|
@@ -36,6 +48,8 @@ if __name__ == "__main__":
            ray_params=ray_params,
            xgboost_params=None,
        )

    ray.get(train.remote())
    taken = time.time() - start

    result = {
|
|
|
@@ -4,7 +4,7 @@
compute_template: tpl_cpu_small.yaml

run:
  # use_connect: True
  use_connect: True
  timeout: 600
  prepare: python wait_cluster.py 4 600
  script: python workloads/train_small.py
|
Loading…
Add table
Reference in a new issue