mirror of
https://github.com/vale981/ray
synced 2025-03-11 13:46:40 -04:00
66 lines
1.3 KiB
Python
66 lines
1.3 KiB
Python
![]() |
import os
|
||
|
import time
|
||
|
import json
|
||
|
|
||
|
import ray
|
||
|
from ray.util.placement_group import placement_group
|
||
|
|
||
|
# Total wall-clock duration of the test loop, in seconds (10 minutes).
RUNTIME = 10 * 60
# Number of single-CPU placement-group bundles, one per Worker actor.
NUM_CPU_BUNDLES = 30
|
||
|
|
||
|
|
||
|
@ray.remote(num_cpus=1)
class Worker:
    """Ray actor reserving one CPU that simulates a short unit of work."""

    def __init__(self, i):
        # Index used to label this actor's printed output.
        self.i = i

    def work(self):
        """Sleep for 0.1 s, then print this actor's index."""
        time.sleep(0.1)
        print("work ", self.i)
|
||
|
|
||
|
|
||
|
@ray.remote(num_cpus=1, num_gpus=1)
class Trainer:
    """Ray actor reserving one CPU and one GPU that simulates a train step."""

    def __init__(self, i):
        # Index used to label this actor's printed output.
        self.i = i

    def train(self):
        """Sleep for 0.2 s, then print this actor's index."""
        time.sleep(0.2)
        print("train ", self.i)
|
||
|
|
||
|
|
||
|
def main():
    """Run the placement-group long-running test.

    Reserves one CPU+GPU bundle (index 0) for a Trainer actor plus
    ``NUM_CPU_BUNDLES`` single-CPU bundles (indices 1..N) for Worker
    actors in one PACK placement group. It then repeatedly runs one round
    of worker tasks followed by one trainer task until more than
    ``RUNTIME`` seconds have elapsed. If the ``TEST_OUTPUT_JSON``
    environment variable is set, an empty JSON object is written to
    that path when the loop finishes.
    """
    ray.init(address="auto")

    # Bundle 0 is the only bundle with a GPU; the rest are CPU-only.
    bundles = [{"CPU": 1, "GPU": 1}]
    bundles += [{"CPU": 1} for _ in range(NUM_CPU_BUNDLES)]

    pg = placement_group(bundles, strategy="PACK")
    ray.get(pg.ready())

    # Pin every actor to its own bundle. Without an explicit index Ray
    # may place a Worker on the GPU bundle and consume its CPU, which
    # could leave the Trainer (needing CPU + GPU) unschedulable.
    workers = [
        Worker.options(
            placement_group=pg, placement_group_bundle_index=i + 1
        ).remote(i)
        for i in range(NUM_CPU_BUNDLES)
    ]
    trainer = Trainer.options(
        placement_group=pg, placement_group_bundle_index=0
    ).remote(0)

    start = time.time()
    while True:
        # One round: all workers in parallel, then a single trainer step.
        ray.get([worker.work.remote() for worker in workers])
        ray.get(trainer.train.remote())
        if time.time() - start > RUNTIME:
            break

    if "TEST_OUTPUT_JSON" in os.environ:
        results = {}
        # Context manager guarantees the file is flushed and closed.
        with open(os.environ["TEST_OUTPUT_JSON"], "w") as out_file:
            json.dump(results, out_file)
|
||
|
|
||
|
|
||
|
# Script entry point: run the test only when executed directly.
if __name__ == "__main__":
    main()
|