ray/ci/jenkins_tests/multi_node_tests/remove_driver_test.py

274 lines
9.6 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import ray
from ray.tests.utils import (_wait_for_nodes_to_join, _broadcast_event,
_wait_for_event, wait_for_pid_to_exit)
# This test should be run with 5 nodes, which have 0, 1, 2, 3, and 4 GPUs for a
# total of 10 GPUs. It should be run with 7 drivers. Drivers 2 through 6 must
# run on different nodes so they can check if all the relevant workers on all
# the nodes have been killed.
total_num_nodes = 5
def actor_event_name(driver_index, actor_index):
return "DRIVER_{}_ACTOR_{}_RUNNING".format(driver_index, actor_index)
def remote_function_event_name(driver_index, task_index):
return "DRIVER_{}_TASK_{}_RUNNING".format(driver_index, task_index)
@ray.remote
def long_running_task(driver_index, task_index, redis_address):
_broadcast_event(
remote_function_event_name(driver_index, task_index),
redis_address,
data=(ray.services.get_node_ip_address(), os.getpid()))
# Loop forever.
while True:
time.sleep(100)
num_long_running_tasks_per_driver = 2
@ray.remote
class Actor0(object):
def __init__(self, driver_index, actor_index, redis_address):
_broadcast_event(
actor_event_name(driver_index, actor_index),
redis_address,
data=(ray.services.get_node_ip_address(), os.getpid()))
assert len(ray.get_gpu_ids()) == 0
def check_ids(self):
assert len(ray.get_gpu_ids()) == 0
def long_running_method(self):
# Loop forever.
while True:
time.sleep(100)
@ray.remote(num_gpus=1)
class Actor1(object):
def __init__(self, driver_index, actor_index, redis_address):
_broadcast_event(
actor_event_name(driver_index, actor_index),
redis_address,
data=(ray.services.get_node_ip_address(), os.getpid()))
assert len(ray.get_gpu_ids()) == 1
def check_ids(self):
assert len(ray.get_gpu_ids()) == 1
def long_running_method(self):
# Loop forever.
while True:
time.sleep(100)
@ray.remote(num_gpus=2)
class Actor2(object):
def __init__(self, driver_index, actor_index, redis_address):
_broadcast_event(
actor_event_name(driver_index, actor_index),
redis_address,
data=(ray.services.get_node_ip_address(), os.getpid()))
assert len(ray.get_gpu_ids()) == 2
def check_ids(self):
assert len(ray.get_gpu_ids()) == 2
def long_running_method(self):
# Loop forever.
while True:
time.sleep(100)
def driver_0(redis_address, driver_index):
"""The script for driver 0.
This driver should create five actors that each use one GPU and some actors
that use no GPUs. After a while, it should exit.
"""
ray.init(redis_address=redis_address)
# Wait for all the nodes to join the cluster.
_wait_for_nodes_to_join(total_num_nodes)
# Start some long running task. Driver 2 will make sure the worker running
# this task has been killed.
for i in range(num_long_running_tasks_per_driver):
long_running_task.remote(driver_index, i, redis_address)
# Create some actors that require one GPU.
actors_one_gpu = [
Actor1.remote(driver_index, i, redis_address) for i in range(5)
]
# Create some actors that don't require any GPUs.
actors_no_gpus = [
Actor0.remote(driver_index, 5 + i, redis_address) for i in range(5)
]
for _ in range(1000):
ray.get([actor.check_ids.remote() for actor in actors_one_gpu])
ray.get([actor.check_ids.remote() for actor in actors_no_gpus])
# Start a long-running method on one actor and make sure this doesn't
# affect anything.
actors_no_gpus[0].long_running_method.remote()
_broadcast_event("DRIVER_0_DONE", redis_address)
def driver_1(redis_address, driver_index):
"""The script for driver 1.
This driver should create one actor that uses two GPUs, three actors that
each use one GPU (the one requiring two must be created first), and some
actors that don't use any GPUs. After a while, it should exit.
"""
ray.init(redis_address=redis_address)
# Wait for all the nodes to join the cluster.
_wait_for_nodes_to_join(total_num_nodes)
# Start some long running task. Driver 2 will make sure the worker running
# this task has been killed.
for i in range(num_long_running_tasks_per_driver):
long_running_task.remote(driver_index, i, redis_address)
# Create an actor that requires two GPUs.
actors_two_gpus = [
Actor2.remote(driver_index, i, redis_address) for i in range(1)
]
# Create some actors that require one GPU.
actors_one_gpu = [
Actor1.remote(driver_index, 1 + i, redis_address) for i in range(3)
]
# Create some actors that don't require any GPUs.
actors_no_gpus = [
Actor0.remote(driver_index, 1 + 3 + i, redis_address) for i in range(5)
]
for _ in range(1000):
ray.get([actor.check_ids.remote() for actor in actors_two_gpus])
ray.get([actor.check_ids.remote() for actor in actors_one_gpu])
ray.get([actor.check_ids.remote() for actor in actors_no_gpus])
# Start a long-running method on one actor and make sure this doesn't
# affect anything.
actors_one_gpu[0].long_running_method.remote()
_broadcast_event("DRIVER_1_DONE", redis_address)
def cleanup_driver(redis_address, driver_index):
"""The script for drivers 2 through 6.
This driver should wait for the first two drivers to finish. Then it should
create some actors that use a total of ten GPUs.
"""
ray.init(redis_address=redis_address)
# Only one of the cleanup drivers should create more actors.
if driver_index == 2:
# We go ahead and create some actors that don't require any GPUs. We
# don't need to wait for the other drivers to finish. We call methods
# on these actors later to make sure they haven't been killed.
actors_no_gpus = [
Actor0.remote(driver_index, i, redis_address) for i in range(10)
]
_wait_for_event("DRIVER_0_DONE", redis_address)
_wait_for_event("DRIVER_1_DONE", redis_address)
def try_to_create_actor(actor_class, driver_index, actor_index,
timeout=20):
# Try to create an actor, but allow failures while we wait for the
# monitor to release the resources for the removed drivers.
start_time = time.time()
while time.time() - start_time < timeout:
try:
actor = actor_class.remote(driver_index, actor_index,
redis_address)
except Exception:
time.sleep(0.1)
else:
return actor
# If we are here, then we timed out while looping.
raise Exception("Timed out while trying to create actor.")
# Only one of the cleanup drivers should create more actors.
if driver_index == 2:
# Create some actors that require one GPU.
actors_one_gpu = []
for i in range(10):
actors_one_gpu.append(
try_to_create_actor(Actor1, driver_index, 10 + 3 + i))
removed_workers = 0
# Make sure that the PIDs for the long-running tasks from driver 0 and
# driver 1 have been killed.
for i in range(num_long_running_tasks_per_driver):
node_ip_address, pid = _wait_for_event(
remote_function_event_name(0, i), redis_address)
if node_ip_address == ray.services.get_node_ip_address():
wait_for_pid_to_exit(pid)
removed_workers += 1
for i in range(num_long_running_tasks_per_driver):
node_ip_address, pid = _wait_for_event(
remote_function_event_name(1, i), redis_address)
if node_ip_address == ray.services.get_node_ip_address():
wait_for_pid_to_exit(pid)
removed_workers += 1
# Make sure that the PIDs for the actors from driver 0 and driver 1 have
# been killed.
for i in range(10):
node_ip_address, pid = _wait_for_event(
actor_event_name(0, i), redis_address)
if node_ip_address == ray.services.get_node_ip_address():
wait_for_pid_to_exit(pid)
removed_workers += 1
for i in range(9):
node_ip_address, pid = _wait_for_event(
actor_event_name(1, i), redis_address)
if node_ip_address == ray.services.get_node_ip_address():
wait_for_pid_to_exit(pid)
removed_workers += 1
print("{} workers/actors were removed on this node."
.format(removed_workers))
# Only one of the cleanup drivers should create and use more actors.
if driver_index == 2:
for _ in range(1000):
ray.get([actor.check_ids.remote() for actor in actors_one_gpu])
ray.get([actor.check_ids.remote() for actor in actors_no_gpus])
_broadcast_event("DRIVER_{}_DONE".format(driver_index), redis_address)
if __name__ == "__main__":
driver_index = int(os.environ["RAY_DRIVER_INDEX"])
redis_address = os.environ["RAY_REDIS_ADDRESS"]
print("Driver {} started at {}.".format(driver_index, time.time()))
if driver_index == 0:
driver_0(redis_address, driver_index)
elif driver_index == 1:
driver_1(redis_address, driver_index)
elif driver_index in [2, 3, 4, 5, 6]:
cleanup_driver(redis_address, driver_index)
else:
raise Exception("This code should be unreachable.")
print("Driver {} finished at {}.".format(driver_index, time.time()))