ray/release/nightly_tests/placement_group_tests/pg_run.py
Balaji Veeramani 7f1bacc7dc
[CI] Format Python code with Black (#21975)
See #21316 and #21311 for the motivation behind these changes.
2022-01-29 18:41:57 -08:00

64 lines
1.3 KiB
Python

import os
import time
import json
import ray
from ray.util.placement_group import placement_group
# The test is expected to keep running for 10 minutes; value is in seconds.
RUNTIME = 10 * 60
# Number of CPU-only bundles requested in the placement group; main() also
# starts this many Worker actors.
NUM_CPU_BUNDLES = 30
@ray.remote(num_cpus=1)
class Worker:
    """Ray actor declared with one CPU that simulates a short unit of work."""

    def __init__(self, i):
        # Index echoed in the line printed by work().
        self.i = i

    def work(self):
        """Sleep for 0.1 seconds, then print this worker's index."""
        time.sleep(0.1)
        print("work ", self.i)
@ray.remote(num_cpus=1, num_gpus=1)
class Trainer:
    """Ray actor declared with one CPU and one GPU that simulates training."""

    def __init__(self, i):
        # Index echoed in the line printed by train().
        self.i = i

    def train(self):
        """Sleep for 0.2 seconds, then print this trainer's index."""
        time.sleep(0.2)
        print("train ", self.i)
def main():
    """Drive actors inside one placement group for about RUNTIME seconds.

    Calls ``ray.init(address="auto")``, requests a single PACK placement
    group made of one CPU+GPU bundle plus NUM_CPU_BUNDLES CPU-only bundles,
    starts NUM_CPU_BUNDLES Worker actors and one Trainer actor in that
    group, and then repeatedly runs one round of work followed by one
    training step until more than RUNTIME seconds have elapsed.

    If the TEST_OUTPUT_JSON environment variable is set, an empty JSON
    object is written to the path it names.
    """
    ray.init(address="auto")

    # First bundle carries the GPU; the rest are CPU-only.
    bundles = [{"CPU": 1, "GPU": 1}]
    bundles += [{"CPU": 1} for _ in range(NUM_CPU_BUNDLES)]

    pg = placement_group(bundles, strategy="PACK")
    # Wait for the placement group to report ready before creating actors.
    ray.get(pg.ready())

    workers = [
        Worker.options(placement_group=pg).remote(i) for i in range(NUM_CPU_BUNDLES)
    ]
    trainer = Trainer.options(placement_group=pg).remote(0)

    # A monotonic clock keeps the measured duration unaffected by system
    # clock adjustments during the run.
    start = time.monotonic()
    while True:
        # One round: every worker finishes, then the trainer runs once.
        ray.get([worker.work.remote() for worker in workers])
        ray.get(trainer.train.remote())
        if time.monotonic() - start > RUNTIME:
            break

    output_path = os.environ.get("TEST_OUTPUT_JSON")
    if output_path is not None:
        # The context manager guarantees the file is flushed and closed.
        with open(output_path, "w") as out_file:
            json.dump({}, out_file)
# Run the test only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()