mirror of
https://github.com/vale981/ray
synced 2025-03-08 19:41:38 -05:00
58 lines
1.4 KiB
Python
58 lines
1.4 KiB
Python
![]() |
from collections import Counter
|
||
|
import time
|
||
|
|
||
|
import ray
|
||
|
|
||
|
from ray import tune
|
||
|
|
||
|
|
||
|
def train(config):
    """Report one metric per minute for a node-dependent number of minutes.

    The IP address of the node executing this trial is compared with
    ``config["head_node_ip"]``. On a match the trial performs 30
    report/sleep rounds (about 30 minutes); on any other node it performs
    3 rounds (about 3 minutes).

    Args:
        config: Trial configuration; must contain the key
            ``"head_node_ip"``.
    """
    local_ip = ray.util.get_node_ip_address()
    on_head_node = local_ip == config["head_node_ip"]

    # Head node: 30 one-minute rounds. Worker nodes: 3 one-minute rounds.
    rounds = 30 if on_head_node else 3

    for step in range(rounds):
        tune.report(metric=step)
        time.sleep(60)
|
||
|
|
||
|
|
||
|
class NodeCountCallback(tune.Callback):
    """Tune callback that records how many Ray nodes are alive at each step.

    Every ``on_step_begin`` invocation appends one integer to
    ``node_counts``, so the list holds the alive-node counts in the order
    they were sampled.
    """

    def __init__(self):
        # One entry per on_step_begin call, oldest first.
        self.node_counts = []

    def on_step_begin(self, iteration, trials, **info):
        """Sample the number of alive nodes and append it to the history."""
        alive = sum(1 for node in ray.nodes() if node["Alive"])
        self.node_counts.append(alive)
|
||
|
|
||
|
|
||
|
def main():
    """Run three 4-CPU trials and check the observed node counts.

    The script asserts that exactly one node is alive before the run
    starts, that a node count of 3 was observed at some step during the
    run, and that the last observed node count is 1.
    """
    ray.init()

    # The driver's IP is handed to every trial so each one can tell
    # whether it is running on this node.
    head_node_ip = ray.util.get_node_ip_address()

    assert (
        sum(1 for node in ray.nodes() if node["Alive"]) == 1
    ), "Too many nodes available at start of script"

    node_counter = NodeCountCallback()
    trial_config = {"head_node_ip": head_node_ip}

    tune.run(
        train,
        num_samples=3,
        config=trial_config,
        callbacks=[node_counter],
        resources_per_trial={"cpu": 4},
    )

    # Histogram of the node counts sampled at the start of every step.
    count_histogram = Counter(node_counter.node_counts)
    assert count_histogram[3] > 0, "Cluster never scaled to 3 nodes"
    assert node_counter.node_counts[-1] == 1, "Cluster didn't scale down to 1 node."
|
||
|
|
||
|
|
||
|
# Run the check only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|