# Mirror of https://github.com/vale981/ray
# Synced 2025-03-06 02:21:39 -05:00
# 64 lines, 1.3 KiB, Python
import os
|
|
import time
|
|
import json
|
|
|
|
import ray
|
|
from ray.util.placement_group import placement_group
|
|
|
|
# The test loop is meant to keep running for 10 minutes (value in seconds).
RUNTIME = 10 * 60
# Number of CPU-only bundles requested, and of Worker actors started.
NUM_CPU_BUNDLES = 30
|
|
|
|
|
|
@ray.remote(num_cpus=1)
class Worker:
    """Ray actor (declared with ``num_cpus=1``) that simulates a unit of work."""

    def __init__(self, i):
        """Remember ``i``, the index printed by :meth:`work`."""
        self.i = i

    def work(self):
        """Sleep for 0.1 seconds and then print ``"work "`` and the index."""
        time.sleep(0.1)
        print("work ", self.i)
|
|
|
|
|
|
@ray.remote(num_cpus=1, num_gpus=1)
class Trainer:
    """Ray actor (declared with ``num_cpus=1, num_gpus=1``) simulating training."""

    def __init__(self, i):
        """Remember ``i``, the index printed by :meth:`train`."""
        self.i = i

    def train(self):
        """Sleep for 0.2 seconds and then print ``"train "`` and the index."""
        time.sleep(0.2)
        print("train ", self.i)
|
|
|
|
|
|
def main():
    """Run the placement-group workload until RUNTIME seconds have elapsed.

    Steps:
      1. Initialize Ray with ``address="auto"``.
      2. Create one PACK placement group made of a single CPU+GPU bundle
         followed by NUM_CPU_BUNDLES CPU-only bundles, and wait for it.
      3. Start NUM_CPU_BUNDLES ``Worker`` actors and one ``Trainer`` actor
         inside that placement group.
      4. Repeatedly run one ``work()`` call on every worker followed by one
         ``train()`` call, stopping after the first round that finishes more
         than RUNTIME seconds after the loop started.
      5. If the ``TEST_OUTPUT_JSON`` environment variable is set, write an
         empty JSON object to the file path it names.
    """
    ray.init(address="auto")

    bundles = [{"CPU": 1, "GPU": 1}]
    bundles += [{"CPU": 1} for _ in range(NUM_CPU_BUNDLES)]

    pg = placement_group(bundles, strategy="PACK")
    # Wait for the placement group to report ready before creating actors.
    ray.get(pg.ready())

    workers = [
        Worker.options(placement_group=pg).remote(i) for i in range(NUM_CPU_BUNDLES)
    ]
    trainer = Trainer.options(placement_group=pg).remote(0)

    # Use a monotonic clock so that wall-clock adjustments during the run
    # cannot shorten or lengthen the measured duration.
    start = time.monotonic()
    while True:
        # One round: every worker finishes work() before the trainer runs.
        ray.get([worker.work.remote() for worker in workers])
        ray.get(trainer.train.remote())
        if time.monotonic() - start > RUNTIME:
            break

    if "TEST_OUTPUT_JSON" in os.environ:
        results = {}
        # The context manager guarantees the file is flushed and closed.
        with open(os.environ["TEST_OUTPUT_JSON"], "w") as out_file:
            json.dump(results, out_file)
|
|
|
|
|
|
# Run the workload only when this file is executed as a script.
if __name__ == "__main__":
    main()
|