Roll forward to run train_small in client mode. (#16610)

mwtian 2021-06-23 00:52:08 -07:00 committed by GitHub
parent c95dea51e9
commit 48599aef9e
3 changed files with 27 additions and 12 deletions

@@ -8,6 +8,7 @@ python:
 - pytest
 - xgboost_ray
 - petastorm
+- tblib
 conda_packages: []
 post_build_cmds:

View file

@@ -16,8 +16,15 @@ from xgboost_ray import RayParams
from ray.util.xgboost.release_test_util import train_ray

if __name__ == "__main__":
    addr = os.environ.get("RAY_ADDRESS")
    job_name = os.environ.get("RAY_JOB_NAME", "train_small")
    if addr.startswith("anyscale://"):
        ray.client(address=addr).job_name(job_name).connect()
    else:
        ray.init(address="auto")

    output = os.environ["TEST_OUTPUT_JSON"]
    state = os.environ["TEST_STATE_JSON"]

    ray_params = RayParams(
        elastic_training=False,
        max_actor_restarts=2,
@@ -26,6 +33,11 @@ if __name__ == "__main__":
        gpus_per_actor=0)

    start = time.time()

    @ray.remote
    def train():
        os.environ["TEST_OUTPUT_JSON"] = output
        os.environ["TEST_STATE_JSON"] = state
        train_ray(
            path="/data/classification.parquet",
            num_workers=4,
@@ -36,6 +48,8 @@ if __name__ == "__main__":
            ray_params=ray_params,
            xgboost_params=None,
        )

    ray.get(train.remote())

    taken = time.time() - start
    result = {
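
The hunks above are the core of the client-mode change: when RAY_ADDRESS is an anyscale:// URL, the script attaches through Ray Client, so the __main__ block runs on the local machine while @ray.remote tasks run on the cluster. Environment variables set for the driver are therefore not visible to workers, which is why TEST_OUTPUT_JSON and TEST_STATE_JSON are captured locally and re-exported inside train(). A minimal, self-contained sketch of that capture/re-export pattern (generic names, not the release script itself):

import os

import ray

ray.init()  # stand-in for ray.client(...).connect() in client mode

# Read on the driver; under Ray Client this runs on the local machine.
value = os.environ.get("TEST_OUTPUT_JSON", "/tmp/output.json")

@ray.remote
def task():
    # Re-export on the worker, where the driver's environment is not visible.
    os.environ["TEST_OUTPUT_JSON"] = value
    return os.environ["TEST_OUTPUT_JSON"]

assert ray.get(task.remote()) == value

Note that addr.startswith(...) would raise if RAY_ADDRESS were unset; the script assumes the release-test runner always exports it.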

@@ -4,7 +4,7 @@
 compute_template: tpl_cpu_small.yaml
 run:
-  # use_connect: True
+  use_connect: True
   timeout: 600
   prepare: python wait_cluster.py 4 600
   script: python workloads/train_small.py
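
Uncommenting use_connect: True tells the release-test runner to drive the workload through Ray Client rather than directly on the head node, matching the anyscale:// branch added to train_small.py. The prepare step blocks until the cluster is fully up; a minimal sketch of what a wait_cluster.py-style check could look like, inferred only from its arguments (4 nodes, 600-second timeout) rather than from the actual script:

import sys
import time

import ray

num_nodes = int(sys.argv[1])   # e.g. 4
timeout = float(sys.argv[2])   # e.g. 600 seconds

ray.init(address="auto")

deadline = time.time() + timeout
while time.time() < deadline:
    alive = [node for node in ray.nodes() if node["Alive"]]
    if len(alive) >= num_nodes:
        print(f"Cluster ready with {len(alive)} nodes.")
        sys.exit(0)
    time.sleep(5)

print(f"Timed out waiting for {num_nodes} nodes.")
sys.exit(1)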