mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
Clean up when a driver disconnects. (#462)
* Clean up state when drivers exit.
* Remove unnecessary field in ActorMapEntry struct.
* Have monitor release GPU resources in Redis when driver exits.
* Enable multiple drivers in multi-node tests and test driver cleanup.
* Make redis GPU allocation a redis transaction and small cleanups.
* Fix multi-node test.
* Small cleanups.
* Make global scheduler take node_ip_address so it appears in the right place in the client table.
* Cleanups.
* Fix linting and cleanups in local scheduler.
* Fix removed_driver_test.
* Fix bug related to vector -> list.
* Fix linting.
* Cleanup.
* Fix multi node tests.
* Fix jenkins tests.
* Add another multi node test with many drivers.
* Fix linting.
* Make the actor creation notification a flatbuffer message.
* Revert "Make the actor creation notification a flatbuffer message."
  This reverts commit af99099c8084dbf9177fb4e34c0c9b1a12c78f39.
* Add comment explaining flatbuffer problems.
This commit is contained in:
parent 8194b71f32
commit 0ac125e9b2
31 changed files with 1119 additions and 168 deletions
@@ -7,13 +7,14 @@ import inspect
 import json
 import numpy as np
 import random
+import redis
 import traceback
 
 import ray.local_scheduler
 import ray.pickling as pickling
 import ray.signature as signature
 import ray.worker
-import ray.experimental.state as state
+from ray.utils import binary_to_hex, hex_to_binary
 
 # This is a variable used by each actor to indicate the IDs of the GPUs that
 # the worker is currently allowed to use.
@@ -105,6 +106,72 @@ def fetch_and_register_actor(key, worker):
   # the actor.
+
+
+def attempt_to_reserve_gpus(num_gpus, driver_id, local_scheduler, worker):
+  """Attempt to acquire GPUs on a particular local scheduler for an actor.
+
+  Args:
+    num_gpus: The number of GPUs to acquire.
+    driver_id: The ID of the driver responsible for creating the actor.
+    local_scheduler: Information about the local scheduler.
+
+  Returns:
+    A list of the GPU IDs that were successfully acquired. This should have
+      length either equal to num_gpus or equal to 0.
+  """
+  local_scheduler_id = local_scheduler["DBClientID"]
+  local_scheduler_total_gpus = int(local_scheduler["NumGPUs"])
+
+  gpus_to_acquire = []
+
+  # Attempt to acquire GPU IDs atomically.
+  with worker.redis_client.pipeline() as pipe:
+    while True:
+      try:
+        # If this key is changed before the transaction below (the multi/exec
+        # block), then the transaction will not take place.
+        pipe.watch(local_scheduler_id)
+
+        # Figure out which GPUs are currently in use.
+        result = worker.redis_client.hget(local_scheduler_id, "gpus_in_use")
+        gpus_in_use = dict() if result is None else json.loads(result)
+        all_gpu_ids_in_use = []
+        for key in gpus_in_use:
+          all_gpu_ids_in_use += gpus_in_use[key]
+        assert len(all_gpu_ids_in_use) <= local_scheduler_total_gpus
+        assert len(set(all_gpu_ids_in_use)) == len(all_gpu_ids_in_use)
+
+        pipe.multi()
+
+        if local_scheduler_total_gpus - len(all_gpu_ids_in_use) >= num_gpus:
+          # There are enough available GPUs, so try to reserve some.
+          all_gpu_ids = set(range(local_scheduler_total_gpus))
+          for gpu_id in all_gpu_ids_in_use:
+            all_gpu_ids.remove(gpu_id)
+          gpus_to_acquire = list(all_gpu_ids)[:num_gpus]
+
+          # Use the hex driver ID so that the dictionary is JSON serializable.
+          driver_id_hex = binary_to_hex(driver_id)
+          if driver_id_hex not in gpus_in_use:
+            gpus_in_use[driver_id_hex] = []
+          gpus_in_use[driver_id_hex] += gpus_to_acquire
+
+          # Stick the updated GPU IDs back in Redis.
+          pipe.hset(local_scheduler_id, "gpus_in_use", json.dumps(gpus_in_use))
+
+        pipe.execute()
+        # If a WatchError is not raised, then the operations should have gone
+        # through atomically.
+        break
+      except redis.WatchError:
+        # Another client must have changed the watched key between the time we
+        # started WATCHing it and the pipeline's execution. We should just
+        # retry.
+        gpus_to_acquire = []
+        continue
+
+  return gpus_to_acquire
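The function above follows the standard redis-py optimistic-locking recipe: WATCH the key, read the current state, queue the writes with MULTI, and retry whenever EXEC fails because another client wrote first. A minimal self-contained sketch of the same pattern, assuming only a running Redis server and the redis package (the key name, owner label, and ID count below are invented for illustration):

import json

import redis


def reserve_ids(client, key, owner, count, total=8):
  # Atomically move `count` free IDs (out of `total`) into `owner`'s entry of
  # the JSON dict stored in hash field "gpus_in_use", retrying whenever any
  # other client writes `key` between the WATCH and the EXEC.
  with client.pipeline() as pipe:
    while True:
      try:
        pipe.watch(key)  # EXEC below fails if `key` changes after this point.
        raw = pipe.hget(key, "gpus_in_use")  # Immediate-mode read while watching.
        in_use = json.loads(raw) if raw is not None else {}
        taken = {i for ids in in_use.values() for i in ids}
        free = [i for i in range(total) if i not in taken]
        if len(free) < count:
          return []  # Not enough available IDs; acquire nothing.
        in_use.setdefault(owner, []).extend(free[:count])
        pipe.multi()  # Start buffering commands for the transaction.
        pipe.hset(key, "gpus_in_use", json.dumps(in_use))
        pipe.execute()  # Raises redis.WatchError on a conflicting write.
        return free[:count]
      except redis.WatchError:
        continue  # Lost the race; re-read the state and retry.

# Example usage: reserve_ids(redis.StrictRedis(), "scheduler_demo", "driver0", 2)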
 def select_local_scheduler(local_schedulers, num_gpus, worker):
   """Select a local scheduler to assign this actor to.

@@ -121,42 +188,33 @@ def select_local_scheduler(local_schedulers, num_gpus, worker):
     Exception: An exception is raised if no local scheduler can be found with
       sufficient resources.
   """
+  # TODO(rkn): We should change this method to have a list of GPU IDs that we
+  # pop from and push to. The current implementation is not compatible with
+  # actors releasing GPU resources.
+  driver_id = worker.task_driver_id.id()
+
   if num_gpus == 0:
-    local_scheduler_id = random.choice(local_schedulers)[b"ray_client_id"]
-    gpu_ids = []
+    local_scheduler_id = hex_to_binary(
+        random.choice(local_schedulers)["DBClientID"])
+    gpus_aquired = []
   else:
     # All of this logic is for finding a local scheduler that has enough
     # available GPUs.
     local_scheduler_id = None
     # Loop through all of the local schedulers.
     for local_scheduler in local_schedulers:
-      # See if there are enough available GPUs on this local scheduler.
-      local_scheduler_total_gpus = int(float(
-          local_scheduler[b"num_gpus"].decode("ascii")))
-      gpus_in_use = worker.redis_client.hget(local_scheduler[b"ray_client_id"],
-                                             b"gpus_in_use")
-      gpus_in_use = 0 if gpus_in_use is None else int(gpus_in_use)
-      if gpus_in_use + num_gpus <= local_scheduler_total_gpus:
-        # Attempt to reserve some GPUs for this actor.
-        new_gpus_in_use = worker.redis_client.hincrby(
-            local_scheduler[b"ray_client_id"], b"gpus_in_use", num_gpus)
-        if new_gpus_in_use > local_scheduler_total_gpus:
-          # If we failed to reserve the GPUs, undo the increment.
-          worker.redis_client.hincrby(local_scheduler[b"ray_client_id"],
-                                      b"gpus_in_use", num_gpus)
-        else:
-          # We succeeded at reserving the GPUs, so we are done.
-          local_scheduler_id = local_scheduler[b"ray_client_id"]
-          gpu_ids = list(range(new_gpus_in_use - num_gpus, new_gpus_in_use))
-          break
+      # Try to reserve enough GPUs on this local scheduler.
+      gpus_aquired = attempt_to_reserve_gpus(num_gpus, driver_id,
+                                             local_scheduler, worker)
+      if len(gpus_aquired) == num_gpus:
+        local_scheduler_id = hex_to_binary(local_scheduler["DBClientID"])
+        break
+      else:
+        # We should have either acquired as many GPUs as we need or none.
+        assert len(gpus_aquired) == 0
 
   if local_scheduler_id is None:
     raise Exception("Could not find a node with enough GPUs to create this "
                     "actor. The local scheduler information is {}."
                     .format(local_schedulers))
-  return local_scheduler_id, gpu_ids
+  return local_scheduler_id, gpus_aquired
 
 
 def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus,
@@ -183,13 +241,23 @@ def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus,
   worker.function_properties[driver_id][function_id] = (1, num_cpus,
                                                         num_gpus)
 
+  # Get a list of the local schedulers from the client table.
+  client_table = ray.global_state.client_table()
+  local_schedulers = []
+  for ip_address, clients in client_table.items():
+    for client in clients:
+      if client["ClientType"] == "local_scheduler":
+        local_schedulers.append(client)
   # Select a local scheduler for the actor.
-  local_schedulers = state.get_local_schedulers(worker)
   local_scheduler_id, gpu_ids = select_local_scheduler(local_schedulers,
                                                        num_gpus, worker)
 
+  # Really we should encode this message as a flatbuffer object. However, we're
+  # having trouble getting that to work. It almost works, but in Python 2.7,
+  # builder.CreateString fails on byte strings that contain characters outside
+  # range(128).
   worker.redis_client.publish("actor_notifications",
-                              actor_id.id() + local_scheduler_id)
+                              actor_id.id() + driver_id + local_scheduler_id)
 
   d = {"driver_id": driver_id,
        "actor_id": actor_id.id(),
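The publish above sidesteps flatbuffers by concatenating three fixed-width binary IDs; the C subscriber (see the memcpy changes in redis.cc later in this diff) splits the payload apart by offset. A sketch of the same framing in plain Python, with the 20-byte ID width assumed from this codebase:

ID_SIZE = 20  # Assumed width of each binary ID in this codebase.


def pack_notification(actor_id, driver_id, local_scheduler_id):
  # Frame the message exactly as the publish above does: raw concatenation.
  assert len(actor_id) == len(driver_id) == len(local_scheduler_id) == ID_SIZE
  return actor_id + driver_id + local_scheduler_id


def unpack_notification(payload):
  # Mirror of the C-side memcpy calls: slice at fixed offsets.
  assert len(payload) == 3 * ID_SIZE
  return (payload[:ID_SIZE],
          payload[ID_SIZE:2 * ID_SIZE],
          payload[2 * ID_SIZE:])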
@@ -2,12 +2,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import binascii
 import pickle
 import redis
-import sys
 
 import ray.local_scheduler
+from ray.utils import (decode, binary_to_object_id, binary_to_hex,
+                       hex_to_binary)
 
 # Import flatbuffer bindings.
 from ray.core.generated.TaskInfo import TaskInfo
@@ -36,40 +35,6 @@ task_state_mapping = {
 }
 
 
-def decode(byte_str):
-  """Make this unicode in Python 3, otherwise leave it as bytes."""
-  if sys.version_info >= (3, 0):
-    return byte_str.decode("ascii")
-  else:
-    return byte_str
-
-
-def binary_to_object_id(binary_object_id):
-  return ray.local_scheduler.ObjectID(binary_object_id)
-
-
-def binary_to_hex(identifier):
-  hex_identifier = binascii.hexlify(identifier)
-  if sys.version_info >= (3, 0):
-    hex_identifier = hex_identifier.decode()
-  return hex_identifier
-
-
-def hex_to_binary(hex_identifier):
-  return binascii.unhexlify(hex_identifier)
-
-
-def get_local_schedulers(worker):
-  local_schedulers = []
-  for client in worker.redis_client.keys("CL:*"):
-    client_info = worker.redis_client.hgetall(client)
-    if b"client_type" not in client_info:
-      continue
-    if client_info[b"client_type"] == b"local_scheduler":
-      local_schedulers.append(client_info)
-  return local_schedulers
-
-
 class GlobalState(object):
   """A class used to interface with the Ray control state.
@@ -7,13 +7,15 @@ import subprocess
 import time
 
 
-def start_global_scheduler(redis_address, use_valgrind=False,
+def start_global_scheduler(redis_address, node_ip_address, use_valgrind=False,
                            use_profiler=False, stdout_file=None,
                            stderr_file=None):
   """Start a global scheduler process.
 
   Args:
     redis_address (str): The address of the Redis instance.
+    node_ip_address: The IP address of the node that this scheduler will run
+      on.
     use_valgrind (bool): True if the global scheduler should be started inside
       of valgrind. If this is True, use_profiler must be False.
     use_profiler (bool): True if the global scheduler should be started inside

@@ -31,7 +33,9 @@ def start_global_scheduler(redis_address, use_valgrind=False,
   global_scheduler_executable = os.path.join(
       os.path.abspath(os.path.dirname(__file__)),
       "../core/src/global_scheduler/global_scheduler")
-  command = [global_scheduler_executable, "-r", redis_address]
+  command = [global_scheduler_executable,
+             "-r", redis_address,
+             "-h", node_ip_address]
   if use_valgrind:
     pid = subprocess.Popen(["valgrind",
                             "--track-origins=yes",
@@ -71,7 +71,7 @@ class TestGlobalScheduler(unittest.TestCase):
                                          port=redis_port)
     # Start one global scheduler.
     self.p1 = global_scheduler.start_global_scheduler(
-        redis_address, use_valgrind=USE_VALGRIND)
+        redis_address, node_ip_address, use_valgrind=USE_VALGRIND)
     self.plasma_store_pids = []
     self.plasma_manager_pids = []
     self.local_scheduler_pids = []
@@ -4,10 +4,12 @@ from __future__ import print_function
 
 import argparse
+from collections import Counter
 import json
 import logging
 import redis
 import time
 
+import ray
 from ray.services import get_ip_address
 from ray.services import get_port

@@ -15,6 +17,7 @@ from ray.services import get_port
 from ray.core.generated.SubscribeToDBClientTableReply \
     import SubscribeToDBClientTableReply
 from ray.core.generated.TaskReply import TaskReply
+from ray.core.generated.DriverTableMessage import DriverTableMessage
 
 # These variables must be kept in sync with the C codebase.
 # common/common.h
@ -26,6 +29,7 @@ NIL_ID = b"\xff" * DB_CLIENT_ID_SIZE
|
|||
TASK_STATUS_LOST = 32
|
||||
# common/state/redis.cc
|
||||
PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers"
|
||||
DRIVER_DEATH_CHANNEL = b"driver_deaths"
|
||||
# common/redis_module/ray_redis_module.cc
|
||||
TASK_PREFIX = "TT:"
|
||||
OBJECT_PREFIX = "OL:"
|
||||
|
@@ -215,6 +219,65 @@ class Monitor(object):
       # manager.
       self.live_plasma_managers[db_client_id] = 0
 
+  def driver_removed_handler(self, channel, data):
+    """Handle a notification that a driver has been removed.
+
+    This releases any GPU resources that were reserved for that driver in
+    Redis.
+    """
+    message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0)
+    driver_id = message.DriverId()
+    log.info("Driver {} has been removed.".format(driver_id))
+
+    # Get a list of the local schedulers.
+    client_table = ray.global_state.client_table()
+    local_schedulers = []
+    for ip_address, clients in client_table.items():
+      for client in clients:
+        if client["ClientType"] == "local_scheduler":
+          local_schedulers.append(client)
+
+    # Release any GPU resources that have been reserved for this driver in
+    # Redis.
+    for local_scheduler in local_schedulers:
+      if int(local_scheduler["NumGPUs"]) > 0:
+        local_scheduler_id = local_scheduler["DBClientID"]
+
+        returned_gpu_ids = []
+
+        # Perform a transaction to return the GPUs.
+        with self.redis.pipeline() as pipe:
+          while True:
+            try:
+              # If this key is changed before the transaction below (the
+              # multi/exec block), then the transaction will not take place.
+              pipe.watch(local_scheduler_id)
+
+              result = pipe.hget(local_scheduler_id, "gpus_in_use")
+              gpus_in_use = dict() if result is None else json.loads(result)
+
+              driver_id_hex = ray.utils.binary_to_hex(driver_id)
+              if driver_id_hex in gpus_in_use:
+                returned_gpu_ids = gpus_in_use.pop(driver_id_hex)
+
+              pipe.multi()
+
+              pipe.hset(local_scheduler_id, "gpus_in_use",
+                        json.dumps(gpus_in_use))
+
+              pipe.execute()
+              # If a WatchError is not raised, then the operations should have
+              # gone through atomically.
+              break
+            except redis.WatchError:
+              # Another client must have changed the watched key between the
+              # time we started WATCHing it and the pipeline's execution. We
+              # should just retry.
+              continue
+
+        log.info("Driver {} is returning GPU IDs {} to local scheduler {}."
+                 .format(driver_id, returned_gpu_ids, local_scheduler_id))
+
   def process_messages(self):
     """Process all messages ready in the subscription channels.
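Releasing is the mirror image of the reservation transaction in actor.py earlier in this diff: under the same WATCH, pop the driver's entry out of the JSON dict and write the remainder back. A compact sketch under the same assumptions as the reservation sketch above (invented key and owner names):

import json

import redis


def release_ids(client, key, owner):
  # Return every ID held by `owner`, retrying on concurrent writers.
  with client.pipeline() as pipe:
    while True:
      try:
        pipe.watch(key)
        raw = pipe.hget(key, "gpus_in_use")
        in_use = json.loads(raw) if raw is not None else {}
        returned = in_use.pop(owner, [])  # Empty if the owner held nothing.
        pipe.multi()
        pipe.hset(key, "gpus_in_use", json.dumps(in_use))
        pipe.execute()
        return returned
      except redis.WatchError:
        continue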
@@ -244,6 +307,12 @@ class Monitor(object):
       assert(self.subscribed[channel])
       # The message was a notification from the db_client table.
       message_handler = self.db_client_notification_handler
+    elif channel == DRIVER_DEATH_CHANNEL:
+      assert(self.subscribed[channel])
+      # The message was a notification that a driver was removed.
+      message_handler = self.driver_removed_handler
     else:
       raise Exception("This code should be unreachable.")
 
     # Call the handler.
     assert(message_handler is not None)
@@ -258,6 +327,7 @@ class Monitor(object):
     # Initialize the subscription channel.
     self.subscribe(DB_CLIENT_TABLE_NAME)
     self.subscribe(PLASMA_MANAGER_HEARTBEAT_CHANNEL)
+    self.subscribe(DRIVER_DEATH_CHANNEL)
 
     # Scan the database table for dead database clients. NOTE: This must be
     # called before reading any messages from the subscription channel. This
@@ -326,5 +396,8 @@ if __name__ == "__main__":
   redis_ip_address = get_ip_address(args.redis_address)
   redis_port = get_port(args.redis_address)
 
+  # Initialize the global state.
+  ray.global_state._initialize_global_state(redis_ip_address, redis_port)
+
   monitor = Monitor(redis_ip_address, redis_port)
   monitor.run()
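For reference, the monitor runs as a standalone process; a hypothetical invocation (the flag name is assumed from args.redis_address above) might look like:

  python monitor.py --redis-address=127.0.0.1:6379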
@@ -372,7 +372,7 @@ def start_global_scheduler(redis_address, node_ip_address, stdout_file=None,
     this process will be killed by services.cleanup() when the Python process
     that imported services exits.
   """
-  p = global_scheduler.start_global_scheduler(redis_address,
+  p = global_scheduler.start_global_scheduler(redis_address, node_ip_address,
                                               stdout_file=stdout_file,
                                               stderr_file=stderr_file)
   if cleanup:
@@ -767,7 +767,10 @@ def start_ray_processes(address_info=None,
   if num_workers is not None:
     workers_per_local_scheduler = num_local_schedulers * [num_workers]
   else:
-    workers_per_local_scheduler = num_local_schedulers * [psutil.cpu_count()]
+    workers_per_local_scheduler = []
+    for cpus in num_cpus:
+      workers_per_local_scheduler.append(cpus if cpus is not None
+                                         else psutil.cpu_count())
 
   if address_info is None:
     address_info = {}
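The new else-branch above is equivalent to a short list comprehension; a sketch, assuming num_cpus is a per-scheduler list in which None means "use every core":

import psutil

num_cpus = [4, None, 2]  # Hypothetical: one entry per local scheduler.
workers_per_local_scheduler = [cpus if cpus is not None else psutil.cpu_count()
                               for cpus in num_cpus]
# On an 8-core machine this yields [4, 8, 2].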
python/ray/test/multi_node_tests.py (new file, 86 lines)
@@ -0,0 +1,86 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import redis
import time

import ray

EVENT_KEY = "RAY_MULTI_NODE_TEST_KEY"
"""This key is used internally within this file for coordinating drivers."""


def _wait_for_nodes_to_join(num_nodes, timeout=20):
  """Wait until the nodes have joined the cluster.

  This will wait until exactly num_nodes have joined the cluster and each node
  has a local scheduler and a plasma manager.

  Args:
    num_nodes: The number of nodes to wait for.
    timeout: The amount of time in seconds to wait before failing.

  Raises:
    Exception: An exception is raised if too many nodes join the cluster or if
      the timeout expires while we are waiting.
  """
  start_time = time.time()
  while time.time() - start_time < timeout:
    client_table = ray.global_state.client_table()
    num_ready_nodes = len(client_table)
    if num_ready_nodes == num_nodes:
      ready = True
      # Check that for each node, a local scheduler and a plasma manager are
      # present.
      for ip_address, clients in client_table.items():
        client_types = [client["ClientType"] for client in clients]
        if "local_scheduler" not in client_types:
          ready = False
        if "plasma_manager" not in client_types:
          ready = False
      if ready:
        return
    if num_ready_nodes > num_nodes:
      # Too many nodes have joined. Something must be wrong.
      raise Exception("{} nodes have joined the cluster, but we were "
                      "expecting {} nodes.".format(num_ready_nodes, num_nodes))
    time.sleep(0.1)

  # If we get here then we timed out.
  raise Exception("Timed out while waiting for {} nodes to join. Only {} "
                  "nodes have joined so far.".format(num_nodes,
                                                     num_ready_nodes))


def _broadcast_event(event_name, redis_address):
  """Broadcast an event.

  Args:
    event_name: The name of the event to broadcast.
    redis_address: The address of the Redis server to use for synchronization.

  This is used to synchronize drivers for the multi-node tests.
  """
  redis_host, redis_port = redis_address.split(":")
  redis_client = redis.StrictRedis(host=redis_host, port=int(redis_port))
  redis_client.rpush(EVENT_KEY, event_name)


def _wait_for_event(event_name, redis_address, extra_buffer=1):
  """Block until an event has been broadcast.

  Args:
    event_name: The name of the event to wait for.
    redis_address: The address of the Redis server to use for synchronization.
    extra_buffer: An amount of time in seconds to wait after the event.

  This is used to synchronize drivers for the multi-node tests.
  """
  redis_host, redis_port = redis_address.split(":")
  redis_client = redis.StrictRedis(host=redis_host, port=int(redis_port))
  while True:
    event_names = redis_client.lrange(EVENT_KEY, 0, -1)
    if event_name.encode("ascii") in event_names:
      break
    time.sleep(extra_buffer)
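A sketch of how two test drivers might use these helpers to rendezvous; the event name and Redis address are invented for illustration:

redis_address = "127.0.0.1:6379"  # Hypothetical address of the test cluster.

# In the first driver, once its work is finished:
_broadcast_event("DRIVER_1_DONE", redis_address)

# In a second driver, to block until the first driver has finished:
_wait_for_event("DRIVER_1_DONE", redis_address)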
python/ray/utils.py (new file, 31 lines)
@@ -0,0 +1,31 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import binascii
import sys

import ray.local_scheduler


def decode(byte_str):
  """Make this unicode in Python 3, otherwise leave it as bytes."""
  if sys.version_info >= (3, 0):
    return byte_str.decode("ascii")
  else:
    return byte_str


def binary_to_object_id(binary_object_id):
  return ray.local_scheduler.ObjectID(binary_object_id)


def binary_to_hex(identifier):
  hex_identifier = binascii.hexlify(identifier)
  if sys.version_info >= (3, 0):
    hex_identifier = hex_identifier.decode()
  return hex_identifier


def hex_to_binary(hex_identifier):
  return binascii.unhexlify(hex_identifier)
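These helpers are thin wrappers around binascii and round-trip cleanly; for example, in Python 3:

>>> binary_to_hex(b"\x00\xab\xff")
'00abff'
>>> hex_to_binary("00abff")
b'\x00\xab\xff'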
@@ -61,6 +61,7 @@ add_library(common STATIC
   state/object_table.cc
   state/task_table.cc
   state/db_client_table.cc
+  state/driver_table.cc
   state/actor_notification_table.cc
   state/local_scheduler_table.cc
   state/error_table.cc
@@ -47,6 +47,10 @@ bool DBClientID_equal(DBClientID first_id, DBClientID second_id) {
   return UNIQUE_ID_EQ(first_id, second_id);
 }
 
+bool WorkerID_equal(WorkerID first_id, WorkerID second_id) {
+  return UNIQUE_ID_EQ(first_id, second_id);
+}
+
 char *ObjectID_to_string(ObjectID obj_id, char *id_string, int id_length) {
   CHECK(id_length >= ID_STRING_SIZE);
   static const char hex[] = "0123456789abcdef";
@@ -153,7 +153,9 @@ extern const UniqueID NIL_ID;
 UniqueID globally_unique_id(void);
 
 #define NIL_OBJECT_ID NIL_ID
+#define NIL_WORKER_ID NIL_ID
 
 /** The object ID is the type used to identify objects. */
 typedef UniqueID ObjectID;
 
 #ifdef __cplusplus

@@ -202,6 +204,18 @@ bool ObjectID_equal(ObjectID first_id, ObjectID second_id);
  */
 bool ObjectID_is_nil(ObjectID id);
 
+/** The worker ID is the ID of a worker or driver. */
+typedef UniqueID WorkerID;
+
+/**
+ * Compare two worker IDs.
+ *
+ * @param first_id The first worker ID to compare.
+ * @param second_id The second worker ID to compare.
+ * @return True if the worker IDs are the same and false otherwise.
+ */
+bool WorkerID_equal(WorkerID first_id, WorkerID second_id);
+
 typedef UniqueID DBClientID;
 
 /**
@@ -139,3 +139,8 @@ table ResultTableReply {
 }
 
 root_type ResultTableReply;
+
+table DriverTableMessage {
+  // The driver ID of the driver that died.
+  driver_id: string;
+}
@@ -5,20 +5,16 @@
 #include "db.h"
 #include "table.h"
 
-typedef struct {
-  /** The ID of the actor. */
-  ActorID actor_id;
-  /** The ID of the local scheduler that is responsible for the actor. */
-  DBClientID local_scheduler_id;
-} ActorInfo;
-
 /*
  * ==== Subscribing to the actor notification table ====
  */
 
 /* Callback for subscribing to the local scheduler table. */
-typedef void (*actor_notification_table_subscribe_callback)(ActorInfo info,
-                                                            void *user_context);
+typedef void (*actor_notification_table_subscribe_callback)(
+    ActorID actor_id,
+    WorkerID driver_id,
+    DBClientID local_scheduler_id,
+    void *user_context);
 
 /**
  * Register a callback to process actor notification events.
src/common/state/driver_table.cc (new file, 21 lines)
@@ -0,0 +1,21 @@
#include "driver_table.h"
#include "redis.h"

void driver_table_subscribe(DBHandle *db_handle,
                            driver_table_subscribe_callback subscribe_callback,
                            void *subscribe_context,
                            RetryInfo *retry) {
  DriverTableSubscribeData *sub_data =
      (DriverTableSubscribeData *) malloc(sizeof(DriverTableSubscribeData));
  sub_data->subscribe_callback = subscribe_callback;
  sub_data->subscribe_context = subscribe_context;
  init_table_callback(db_handle, NIL_ID, __func__, sub_data, retry, NULL,
                      redis_driver_table_subscribe, NULL);
}

void driver_table_send_driver_death(DBHandle *db_handle,
                                    WorkerID driver_id,
                                    RetryInfo *retry) {
  init_table_callback(db_handle, driver_id, __func__, NULL, retry, NULL,
                      redis_driver_table_send_driver_death, NULL);
}
src/common/state/driver_table.h (new file, 50 lines)
@@ -0,0 +1,50 @@
#ifndef DRIVER_TABLE_H
#define DRIVER_TABLE_H

#include "db.h"
#include "table.h"
#include "task.h"

/*
 * ==== Subscribing to the driver table ====
 */

/* Callback for subscribing to the driver table. */
typedef void (*driver_table_subscribe_callback)(WorkerID driver_id,
                                                void *user_context);

/**
 * Register a callback for a driver table event.
 *
 * @param db_handle Database handle.
 * @param subscribe_callback Callback that will be called when the driver event
 *        happens.
 * @param subscribe_context Context that will be passed into the
 *        subscribe_callback.
 * @param retry Information about retrying the request to the database.
 * @return Void.
 */
void driver_table_subscribe(DBHandle *db_handle,
                            driver_table_subscribe_callback subscribe_callback,
                            void *subscribe_context,
                            RetryInfo *retry);

/* Data that is needed to register driver table subscribe callbacks with the
 * state database. */
typedef struct {
  driver_table_subscribe_callback subscribe_callback;
  void *subscribe_context;
} DriverTableSubscribeData;

/**
 * Send driver death update to all subscribers.
 *
 * @param db_handle Database handle.
 * @param driver_id The ID of the driver that died.
 * @param retry Information about retrying the request to the database.
 */
void driver_table_send_driver_death(DBHandle *db_handle,
                                    WorkerID driver_id,
                                    RetryInfo *retry);

#endif /* DRIVER_TABLE_H */
@@ -17,6 +17,7 @@ extern "C" {
 #include "db.h"
 #include "db_client_table.h"
 #include "actor_notification_table.h"
+#include "driver_table.h"
 #include "local_scheduler_table.h"
 #include "object_table.h"
 #include "task.h"
@@ -1089,6 +1090,90 @@ void redis_local_scheduler_table_send_info(TableCallbackData *callback_data) {
   }
 }
 
+void redis_driver_table_subscribe_callback(redisAsyncContext *c,
+                                           void *r,
+                                           void *privdata) {
+  REDIS_CALLBACK_HEADER(db, callback_data, r);
+
+  redisReply *reply = (redisReply *) r;
+  CHECK(reply->type == REDIS_REPLY_ARRAY);
+  CHECK(reply->elements == 3);
+  redisReply *message_type = reply->element[0];
+  LOG_DEBUG("Driver table subscribe callback, message %s", message_type->str);
+
+  if (strcmp(message_type->str, "message") == 0) {
+    /* Handle a driver heartbeat. Parse the payload and call the subscribe
+     * callback. */
+    auto message =
+        flatbuffers::GetRoot<DriverTableMessage>(reply->element[2]->str);
+    /* Extract the client ID. */
+    WorkerID driver_id = from_flatbuf(message->driver_id());
+
+    /* Call the subscribe callback. */
+    DriverTableSubscribeData *data =
+        (DriverTableSubscribeData *) callback_data->data;
+    if (data->subscribe_callback) {
+      data->subscribe_callback(driver_id, data->subscribe_context);
+    }
+  } else if (strcmp(message_type->str, "subscribe") == 0) {
+    /* The reply for the initial SUBSCRIBE command. */
+    CHECK(callback_data->done_callback == NULL);
+    /* If the initial SUBSCRIBE was successful, clean up the timer, but don't
+     * destroy the callback data. */
+    event_loop_remove_timer(db->loop, callback_data->timer_id);
+  } else {
+    LOG_FATAL("Unexpected reply type from driver subscribe.");
+  }
+}
+
+void redis_driver_table_subscribe(TableCallbackData *callback_data) {
+  DBHandle *db = callback_data->db_handle;
+  int status = redisAsyncCommand(
+      db->sub_context, redis_driver_table_subscribe_callback,
+      (void *) callback_data->timer_id, "SUBSCRIBE driver_deaths");
+  if ((status == REDIS_ERR) || db->sub_context->err) {
+    LOG_REDIS_DEBUG(db->sub_context, "error in redis_driver_table_subscribe");
+  }
+}
+
+void redis_driver_table_send_driver_death_callback(redisAsyncContext *c,
+                                                   void *r,
+                                                   void *privdata) {
+  REDIS_CALLBACK_HEADER(db, callback_data, r);
+
+  redisReply *reply = (redisReply *) r;
+  CHECK(reply->type == REDIS_REPLY_INTEGER);
+  LOG_DEBUG("%" PRId64 " subscribers received this publish.\n", reply->integer);
+  /* At the very least, the local scheduler that publishes this message should
+   * also receive it. */
+  CHECK(reply->integer >= 1);
+
+  CHECK(callback_data->done_callback == NULL);
+  /* Clean up the timer and callback. */
+  destroy_timer_callback(db->loop, callback_data);
+}
+
+void redis_driver_table_send_driver_death(TableCallbackData *callback_data) {
+  DBHandle *db = callback_data->db_handle;
+  WorkerID driver_id = callback_data->id;
+
+  /* Create a flatbuffer object to serialize and publish. */
+  flatbuffers::FlatBufferBuilder fbb;
+  /* Create the flatbuffers message. */
+  auto message = CreateDriverTableMessage(fbb, to_flatbuf(fbb, driver_id));
+  fbb.Finish(message);
+
+  int status = redisAsyncCommand(
+      db->context, redis_driver_table_send_driver_death_callback,
+      (void *) callback_data->timer_id, "PUBLISH driver_deaths %b",
+      fbb.GetBufferPointer(), fbb.GetSize());
+  if ((status == REDIS_ERR) || db->context->err) {
+    LOG_REDIS_DEBUG(db->context,
+                    "error in redis_driver_table_send_driver_death");
+  }
+}
+
 void redis_plasma_manager_send_heartbeat(TableCallbackData *callback_data) {
   DBHandle *db = callback_data->db_handle;
   /* NOTE(swang): We purposefully do not provide a callback, leaving the table
@@ -1124,15 +1209,20 @@ void redis_actor_notification_table_subscribe_callback(redisAsyncContext *c,
     redisReply *payload = reply->element[2];
     ActorNotificationTableSubscribeData *data =
         (ActorNotificationTableSubscribeData *) callback_data->data;
-    ActorInfo info;
-    /* The payload should be the concatenation of these two structs. */
-    CHECK(sizeof(info.actor_id) + sizeof(info.local_scheduler_id) ==
+    /* The payload should be the concatenation of three IDs. */
+    ActorID actor_id;
+    WorkerID driver_id;
+    DBClientID local_scheduler_id;
+    CHECK(sizeof(actor_id) + sizeof(driver_id) + sizeof(local_scheduler_id) ==
           payload->len);
-    memcpy(&info.actor_id, payload->str, sizeof(info.actor_id));
-    memcpy(&info.local_scheduler_id, payload->str + sizeof(info.actor_id),
-           sizeof(info.local_scheduler_id));
+    memcpy(&actor_id, payload->str, sizeof(actor_id));
+    memcpy(&driver_id, payload->str + sizeof(actor_id), sizeof(driver_id));
+    memcpy(&local_scheduler_id,
+           payload->str + sizeof(actor_id) + sizeof(driver_id),
+           sizeof(local_scheduler_id));
     if (data->subscribe_callback) {
-      data->subscribe_callback(info, data->subscribe_context);
+      data->subscribe_callback(actor_id, driver_id, local_scheduler_id,
+                               data->subscribe_context);
     }
   } else if (strcmp(message_type->str, "subscribe") == 0) {
     /* The reply for the initial SUBSCRIBE command. */
@@ -253,6 +253,24 @@ void redis_local_scheduler_table_subscribe(TableCallbackData *callback_data);
  */
 void redis_local_scheduler_table_send_info(TableCallbackData *callback_data);
 
+/**
+ * Subscribe to updates from the driver table.
+ *
+ * @param callback_data Data structure containing redis connection and timeout
+ *        information.
+ * @return Void.
+ */
+void redis_driver_table_subscribe(TableCallbackData *callback_data);
+
+/**
+ * Publish an update to the driver table.
+ *
+ * @param callback_data Data structure containing redis connection and timeout
+ *        information.
+ * @return Void.
+ */
+void redis_driver_table_send_driver_death(TableCallbackData *callback_data);
+
 void redis_plasma_manager_send_heartbeat(TableCallbackData *callback_data);
 
 /**
@@ -16,7 +16,6 @@ struct TaskBuilder;
 #define NIL_TASK_ID NIL_ID
 #define NIL_ACTOR_ID NIL_ID
 #define NIL_FUNCTION_ID NIL_ID
-#define NIL_WORKER_ID NIL_ID
 
 typedef UniqueID FunctionID;
@@ -61,14 +61,15 @@ void assign_task_to_local_scheduler(GlobalSchedulerState *state,
 }
 
 GlobalSchedulerState *GlobalSchedulerState_init(event_loop *loop,
+                                                const char *node_ip_address,
                                                 const char *redis_addr,
                                                 int redis_port) {
   GlobalSchedulerState *state =
       (GlobalSchedulerState *) malloc(sizeof(GlobalSchedulerState));
   /* Must initialize state to 0. Sets hashmap head(s) to NULL. */
   memset(state, 0, sizeof(GlobalSchedulerState));
-  state->db =
-      db_connect(redis_addr, redis_port, "global_scheduler", ":", 0, NULL);
+  state->db = db_connect(redis_addr, redis_port, "global_scheduler",
+                         node_ip_address, 0, NULL);
   db_attach(state->db, loop, false);
   utarray_new(state->local_schedulers, &local_scheduler_icd);
   state->policy_state = GlobalSchedulerPolicyState_init();
@@ -416,9 +417,12 @@ int heartbeat_timeout_handler(event_loop *loop, timer_id id, void *context) {
   return HEARTBEAT_TIMEOUT_MILLISECONDS;
 }
 
-void start_server(const char *redis_addr, int redis_port) {
+void start_server(const char *node_ip_address,
+                  const char *redis_addr,
+                  int redis_port) {
   event_loop *loop = event_loop_create();
-  g_state = GlobalSchedulerState_init(loop, redis_addr, redis_port);
+  g_state =
+      GlobalSchedulerState_init(loop, node_ip_address, redis_addr, redis_port);
   /* TODO(rkn): subscribe to notifications from the object table. */
   /* Subscribe to notifications about new local schedulers. TODO(rkn): this
    * needs to also get all of the clients that registered with the database
@@ -456,12 +460,17 @@ int main(int argc, char *argv[]) {
   signal(SIGTERM, signal_handler);
   /* IP address and port of redis. */
   char *redis_addr_port = NULL;
+  /* The IP address of the node that this global scheduler is running on. */
+  char *node_ip_address = NULL;
   int c;
-  while ((c = getopt(argc, argv, "s:m:h:p:r:")) != -1) {
+  while ((c = getopt(argc, argv, "h:r:")) != -1) {
     switch (c) {
     case 'r':
       redis_addr_port = optarg;
       break;
+    case 'h':
+      node_ip_address = optarg;
+      break;
     default:
       LOG_ERROR("unknown option %c", c);
       exit(-1);
@@ -472,8 +481,11 @@ int main(int argc, char *argv[]) {
   if (!redis_addr_port ||
       parse_ip_addr_port(redis_addr_port, redis_addr, &redis_port) == -1) {
     LOG_ERROR(
-        "need to specify redis address like 127.0.0.1:6379 with -r switch");
+        "specify the redis address like 127.0.0.1:6379 with the -r switch");
     exit(-1);
   }
-  start_server(redis_addr, redis_port);
+  if (!node_ip_address) {
+    LOG_FATAL("specify the node IP address with the -h switch");
+  }
+  start_server(node_ip_address, redis_addr, redis_port);
 }
@@ -18,6 +18,7 @@
 #include "local_scheduler_algorithm.h"
 #include "state/actor_notification_table.h"
 #include "state/db.h"
+#include "state/driver_table.h"
 #include "state/task_table.h"
 #include "state/object_table.h"
 #include "state/error_table.h"
@@ -67,7 +68,8 @@ int force_kill_worker(event_loop *loop, timer_id id, void *context) {
 
 /**
  * Kill a worker, if it is a child process, and clean up all of its associated
- * state.
+ * state. Note that this function is also called on drivers, but it should not
+ * actually send a kill signal to drivers.
  *
  * @param worker A pointer to the worker we want to kill.
 * @param cleanup A bool representing whether we're cleaning up the entire local
@@ -89,7 +91,13 @@ void kill_worker(LocalSchedulerState *state,
   CHECK(it == state->workers.end());
 
   /* Erase the algorithm state's reference to the worker. */
-  handle_worker_removed(state, state->algorithm_state, worker);
+  if (ActorID_equal(worker->actor_id, NIL_ACTOR_ID)) {
+    handle_worker_removed(state, state->algorithm_state, worker);
+  } else {
+    /* Let the scheduling algorithm process the absence of this worker. */
+    handle_actor_worker_disconnect(state, state->algorithm_state,
+                                   worker->actor_id);
+  }
 
   /* Remove the client socket from the event loop so that we don't process the
    * SIGPIPE when the worker is killed. */
@@ -99,6 +107,8 @@ void kill_worker(LocalSchedulerState *state,
    * process, use it to send a kill signal. */
   bool free_worker = true;
   if (worker->is_child && worker->pid != 0) {
+    /* If worker is a driver, we should not enter this condition because
+     * worker->pid should be 0. */
     if (cleanup) {
       /* If we're exiting the local scheduler anyway, it's okay to force kill
        * the worker immediately. Wait for the process to exit. */
@@ -425,10 +435,18 @@ void update_dynamic_resources(LocalSchedulerState *state,
   print_resource_info(state, spec);
 }
 
+bool is_driver_alive(LocalSchedulerState *state, WorkerID driver_id) {
+  return state->removed_drivers.count(driver_id) == 0;
+}
+
 void assign_task_to_worker(LocalSchedulerState *state,
                            TaskSpec *spec,
                            int64_t task_spec_size,
                            LocalSchedulerClient *worker) {
+  /* Make sure the driver for this task is still alive. */
+  WorkerID driver_id = TaskSpec_driver_id(spec);
+  CHECK(is_driver_alive(state, driver_id));
+
   /* Construct a flatbuffer object to send to the worker. */
   flatbuffers::FlatBufferBuilder fbb;
   auto message =
@@ -649,8 +667,15 @@ void send_client_register_reply(LocalSchedulerState *state,
 void handle_client_register(LocalSchedulerState *state,
                             LocalSchedulerClient *worker,
                             const RegisterClientRequest *message) {
+  /* Make sure this worker hasn't already registered. */
+  CHECK(!worker->registered);
+  worker->registered = true;
+  worker->is_worker = message->is_worker();
+  CHECK(WorkerID_equal(worker->client_id, NIL_WORKER_ID));
+  worker->client_id = from_flatbuf(message->client_id());
+
   /* Register the worker or driver. */
-  if (message->is_worker()) {
+  if (worker->is_worker) {
     /* Update the actor mapping with the actor ID of the worker (if an actor is
      * running on the worker). */
     worker->pid = message->worker_pid();
@@ -683,12 +708,74 @@ void handle_client_register(LocalSchedulerState *state,
       state->child_pids.erase(it);
       LOG_DEBUG("Found matching child pid %d", worker->pid);
     }
+
+    /* If the worker is an actor that corresponds to a driver that has been
+     * removed, then kill the worker. */
+    if (!ActorID_equal(actor_id, NIL_ACTOR_ID)) {
+      WorkerID driver_id = state->actor_mapping[actor_id].driver_id;
+      if (state->removed_drivers.count(driver_id) == 1) {
+        kill_worker(state, worker, false);
+      }
+    }
+  } else {
+    /* Register the driver. Currently we don't do anything here. */
   }
 }
 
-/* End of the cleanup code. */
+void handle_driver_removed_callback(WorkerID driver_id, void *user_context) {
+  LocalSchedulerState *state = (LocalSchedulerState *) user_context;
+
+  /* Kill any actors that were created by the removed driver, and kill any
+   * workers that are currently running tasks from the dead driver. */
+  auto it = state->workers.begin();
+  while (it != state->workers.end()) {
+    /* Increment the iterator by one before calling kill_worker, because
+     * kill_worker will invalidate the iterator. Note that this requires
+     * knowledge of the particular container that we are iterating over (in this
+     * case it is a list). */
+    auto next_it = it;
+    next_it++;
+
+    ActorID actor_id = (*it)->actor_id;
+    Task *task = (*it)->task_in_progress;
+
+    if (!ActorID_equal(actor_id, NIL_ACTOR_ID)) {
+      /* This is an actor. */
+      CHECK(state->actor_mapping.count(actor_id) == 1);
+      if (WorkerID_equal(state->actor_mapping[actor_id].driver_id, driver_id)) {
+        /* This actor was created by the removed driver, so kill the actor. */
+        LOG_DEBUG("Killing an actor for a removed driver.");
+        kill_worker(state, *it, false);
+        break;
+      }
+    } else if (task != NULL) {
+      if (WorkerID_equal(TaskSpec_driver_id(Task_task_spec(task)), driver_id)) {
+        LOG_DEBUG("Killing a worker executing a task for a removed driver.");
+        kill_worker(state, *it, false);
+        break;
+      }
+    }
+
+    it = next_it;
+  }
+
+  /* Add the driver to a list of dead drivers. */
+  state->removed_drivers.insert(driver_id);
+
+  /* Notify the scheduling algorithm that the driver has been removed. It should
+   * remove tasks for that driver from its data structures. */
+  handle_driver_removed(state, state->algorithm_state, driver_id);
+}
+
+void handle_client_disconnect(LocalSchedulerState *state,
+                              LocalSchedulerClient *worker) {
+  if (!worker->registered || worker->is_worker) {
+  } else {
+    /* In this case, a driver is disconnecting. */
+    driver_table_send_driver_death(state->db, worker->client_id, NULL);
+  }
+  kill_worker(state, worker, false);
+}
+
 void process_message(event_loop *loop,
                      int client_sock,
@@ -786,12 +873,7 @@ void process_message(event_loop *loop,
     } break;
     case DISCONNECT_CLIENT: {
       LOG_INFO("Disconnecting client on fd %d", client_sock);
-      kill_worker(state, worker, false);
-      if (!ActorID_equal(worker->actor_id, NIL_ACTOR_ID)) {
-        /* Let the scheduling algorithm process the absence of this worker. */
-        handle_actor_worker_disconnect(state, state->algorithm_state,
-                                       worker->actor_id);
-      }
+      handle_client_disconnect(state, worker);
     } break;
     case MessageType_NotifyUnblocked: {
       if (worker->task_in_progress != NULL) {
@@ -827,6 +909,11 @@ void new_client_connection(event_loop *loop,
    * scheduler state. */
   LocalSchedulerClient *worker = new LocalSchedulerClient();
   worker->sock = new_socket;
+  worker->registered = false;
+  /* We don't know whether this is a worker or not, so just initialize is_worker
+   * to false. */
+  worker->is_worker = true;
+  worker->client_id = NIL_WORKER_ID;
   worker->task_in_progress = NULL;
   worker->is_blocked = false;
   worker->pid = 0;
@@ -857,10 +944,21 @@ void signal_handler(int signal) {
   }
 }
 
+/* End of the cleanup code. */
+
 void handle_task_scheduled_callback(Task *original_task,
                                     void *subscribe_context) {
   LocalSchedulerState *state = (LocalSchedulerState *) subscribe_context;
   TaskSpec *spec = Task_task_spec(original_task);
+
+  /* If the driver for this task has been removed, then don't bother telling the
+   * scheduling algorithm. */
+  WorkerID driver_id = TaskSpec_driver_id(spec);
+  if (!is_driver_alive(state, driver_id)) {
+    LOG_DEBUG("Ignoring scheduled task for removed driver.");
+    return;
+  }
+
   if (ActorID_equal(TaskSpec_actor_id(spec), NIL_ACTOR_ID)) {
     /* This task does not involve an actor. Handle it normally. */
     handle_task_scheduled(state, state->algorithm_state, spec,
@@ -884,10 +982,17 @@ void handle_task_scheduled_callback(Task *original_task,
  * for creating the actor.
  * @return Void.
  */
-void handle_actor_creation_callback(ActorInfo info, void *context) {
-  ActorID actor_id = info.actor_id;
-  DBClientID local_scheduler_id = info.local_scheduler_id;
+void handle_actor_creation_callback(ActorID actor_id,
+                                    WorkerID driver_id,
+                                    DBClientID local_scheduler_id,
+                                    void *context) {
   LocalSchedulerState *state = (LocalSchedulerState *) context;
 
+  /* If the driver has been removed, don't bother doing anything. */
+  if (state->removed_drivers.count(driver_id) == 1) {
+    return;
+  }
+
   /* Make sure the actor entry is not already present in the actor map table.
    * TODO(rkn): We will need to remove this check to handle the case where the
    * corresponding publish is retried and the case in which a task that creates
@@ -897,9 +1002,10 @@ void handle_actor_creation_callback(ActorInfo info, void *context) {
    * Currently this is never removed (except when the local scheduler state is
    * deleted). */
   ActorMapEntry entry;
-  entry.actor_id = actor_id;
   entry.local_scheduler_id = local_scheduler_id;
+  entry.driver_id = driver_id;
   state->actor_mapping[actor_id] = entry;
 
   /* If this local scheduler is responsible for the actor, then start a new
    * worker for the actor. */
   if (DBClientID_equal(local_scheduler_id, get_db_client_id(state->db))) {
@@ -960,6 +1066,11 @@ void start_server(const char *node_ip_address,
     actor_notification_table_subscribe(
         g_state->db, handle_actor_creation_callback, g_state, NULL);
   }
+  /* Subscribe to notifications about removed drivers. */
+  if (g_state->db != NULL) {
+    driver_table_subscribe(g_state->db, handle_driver_removed_callback, g_state,
+                           NULL);
+  }
   /* Create a timer for publishing information about the load on the local
    * scheduler to the local scheduler table. This message also serves as a
    * heartbeat. */
@@ -26,6 +26,14 @@ void new_client_connection(event_loop *loop,
                            void *context,
                            int events);
 
+/**
+ * Check if a driver is still alive.
+ *
+ * @param driver_id The ID of the driver.
+ * @return True if the driver is still alive and false otherwise.
+ */
+bool is_driver_alive(WorkerID driver_id);
+
 /**
  * This function can be called by the scheduling algorithm to assign a task
  * to a worker.
@@ -964,6 +964,9 @@ void handle_worker_available(LocalSchedulerState *state,
 void handle_worker_removed(LocalSchedulerState *state,
                            SchedulingAlgorithmState *algorithm_state,
                            LocalSchedulerClient *worker) {
+  /* Make sure this is not an actor. */
+  CHECK(ActorID_equal(worker->actor_id, NIL_ACTOR_ID));
+
   /* Make sure that we remove the worker at most once. */
   int num_times_removed = 0;
@@ -1141,6 +1144,59 @@ void handle_object_removed(LocalSchedulerState *state,
   }
 }
 
+void handle_driver_removed(LocalSchedulerState *state,
+                           SchedulingAlgorithmState *algorithm_state,
+                           WorkerID driver_id) {
+  /* Loop over fetch requests. This must be done before we clean up the waiting
+   * task queue and the dispatch task queue because this map contains iterators
+   * for those lists, which will be invalidated when we clean up those lists. */
+  for (auto it = algorithm_state->remote_objects.begin();
+       it != algorithm_state->remote_objects.end();) {
+    /* Loop over the tasks that are waiting for this object and remove the tasks
+     * for the removed driver. */
+    auto task_it_it = it->second.dependent_tasks.begin();
+    while (task_it_it != it->second.dependent_tasks.end()) {
+      /* If the dependent task was a task for the removed driver, remove it from
+       * this vector. */
+      TaskSpec *spec = (*task_it_it)->spec;
+      if (WorkerID_equal(TaskSpec_driver_id(spec), driver_id)) {
+        task_it_it = it->second.dependent_tasks.erase(task_it_it);
+      } else {
+        task_it_it++;
+      }
+    }
+    /* If there are no more dependent tasks for this object, then remove the
+     * ObjectEntry. */
+    if (it->second.dependent_tasks.size() == 0) {
+      it = algorithm_state->remote_objects.erase(it);
+    } else {
+      it++;
+    }
+  }
+
+  /* Remove this driver's tasks from the waiting task queue. */
+  auto it = algorithm_state->waiting_task_queue->begin();
+  while (it != algorithm_state->waiting_task_queue->end()) {
+    if (WorkerID_equal(TaskSpec_driver_id(it->spec), driver_id)) {
+      it = algorithm_state->waiting_task_queue->erase(it);
+    } else {
+      it++;
+    }
+  }
+
+  /* Remove this driver's tasks from the dispatch task queue. */
+  it = algorithm_state->dispatch_task_queue->begin();
+  while (it != algorithm_state->dispatch_task_queue->end()) {
+    if (WorkerID_equal(TaskSpec_driver_id(it->spec), driver_id)) {
+      it = algorithm_state->dispatch_task_queue->erase(it);
+    } else {
+      it++;
+    }
+  }
+
+  /* TODO(rkn): Should we clean up the actor data structures? */
+}
+
 int num_waiting_tasks(SchedulingAlgorithmState *algorithm_state) {
   return algorithm_state->waiting_task_queue->size();
 }
@@ -230,6 +230,20 @@ void handle_worker_unblocked(LocalSchedulerState *state,
                              SchedulingAlgorithmState *algorithm_state,
                              LocalSchedulerClient *worker);
 
+/**
+ * Process the fact that a driver has been removed. This will remove all of the
+ * tasks for that driver from the scheduling algorithm's internal data
+ * structures.
+ *
+ * @param state The state of the local scheduler.
+ * @param algorithm_state State maintained by the scheduling algorithm.
+ * @param driver_id The ID of the driver that was removed.
+ * @return Void.
+ */
+void handle_driver_removed(LocalSchedulerState *state,
+                           SchedulingAlgorithmState *algorithm_state,
+                           WorkerID driver_id);
+
 /**
  * This function fetches queued task's missing object dependencies. It is
  * called every LOCAL_SCHEDULER_FETCH_TIMEOUT_MILLISECONDS.
@@ -8,7 +8,9 @@
 #include "utarray.h"
 #include "uthash.h"
 
+#include <list>
 #include <unordered_map>
+#include <unordered_set>
 #include <vector>
 
 /* These are needed to define the UT_arrays. */
@@ -17,8 +19,8 @@ extern UT_icd task_ptr_icd;
 /** This struct is used to maintain a mapping from actor IDs to the ID of the
  * local scheduler that is responsible for the actor. */
 struct ActorMapEntry {
-  /** The ID of the actor. This is used as a key in the hash table. */
-  ActorID actor_id;
+  /** The ID of the driver that created the actor. */
+  WorkerID driver_id;
   /** The ID of the local scheduler that is responsible for the actor. */
   DBClientID local_scheduler_id;
 };
@@ -47,7 +49,10 @@ struct LocalSchedulerState {
   /** List of workers available to this node. This is used to free the worker
    * structs when we free the scheduler state and also to access the worker
    * structs in the tests. */
-  std::vector<LocalSchedulerClient *> workers;
+  std::list<LocalSchedulerClient *> workers;
+  /** A set of driver IDs corresponding to drivers that have been removed. This
+   * is used to make sure we don't execute any tasks belonging to dead drivers. */
+  std::unordered_set<WorkerID, UniqueIDHasher> removed_drivers;
   /** List of the process IDs for child processes (workers) started by the
    * local scheduler that have not sent a REGISTER_PID message yet. */
   std::vector<pid_t> child_pids;
@@ -75,6 +80,13 @@ struct LocalSchedulerState {
 struct LocalSchedulerClient {
   /** The socket used to communicate with the client. */
   int sock;
+  /** True if the client has registered and false otherwise. */
+  bool registered;
+  /** True if the client is a worker and false if it is a driver. */
+  bool is_worker;
+  /** The worker ID if the client is a worker and the driver ID if the client is
+   * a driver. */
+  WorkerID client_id;
   /** A pointer to the task object that is currently running on this client. If
    * no task is running on the worker, this will be NULL. This is used to
    * update the task table. */
@@ -64,9 +64,7 @@ static void register_clients(int num_mock_workers, LocalSchedulerMock *mock) {
   for (int i = 0; i < num_mock_workers; ++i) {
     new_client_connection(mock->loop, mock->local_scheduler_fd,
                           (void *) mock->local_scheduler_state, 0);
-
-    LocalSchedulerClient *worker = mock->local_scheduler_state->workers[i];
-
+    LocalSchedulerClient *worker = mock->local_scheduler_state->workers.back();
     process_message(mock->local_scheduler_state->loop, worker->sock, worker, 0);
   }
 }
@@ -644,7 +642,7 @@ TEST start_kill_workers_test(void) {
   ASSERT_EQ(local_scheduler->local_scheduler_state->workers.size(),
             num_workers);
   /* Make sure that the new worker registers its process ID. */
-  worker = local_scheduler->local_scheduler_state->workers[num_workers - 1];
+  worker = local_scheduler->local_scheduler_state->workers.back();
   process_message(local_scheduler->local_scheduler_state->loop, worker->sock,
                   worker, 0);
   ASSERT_EQ(local_scheduler->local_scheduler_state->child_pids.size(), 0);
@@ -3,6 +3,7 @@ from __future__ import division
 from __future__ import print_function
 
 import argparse
+import numpy as np
 import os
 import re
 import subprocess
@@ -85,8 +86,8 @@ class DockerRunner(object):
     else:
       return m.group(1)
 
-  def _start_head_node(self, docker_image, mem_size, shm_size,
-                       development_mode):
+  def _start_head_node(self, docker_image, mem_size, shm_size, num_cpus,
+                       num_gpus, development_mode):
     """Start the Ray head node inside a docker container."""
     mem_arg = ["--memory=" + mem_size] if mem_size else []
     shm_arg = ["--shm-size=" + shm_size] if shm_size else []
@@ -94,10 +95,15 @@ class DockerRunner(object):
                   "{}:{}".format(os.path.dirname(os.path.realpath(__file__)),
                                  "/ray/test/jenkins_tests")]
                  if development_mode else [])
    proc = subprocess.Popen(["docker", "run", "-d"] + mem_arg + shm_arg +
                            volume_arg +
                            [docker_image, "/ray/scripts/start_ray.sh",
                             "--head", "--redis-port=6379"],
    command = (["docker", "run", "-d"] + mem_arg + shm_arg + volume_arg +
               [docker_image, "/ray/scripts/start_ray.sh", "--head",
                "--redis-port=6379",
                "--num-cpus={}".format(num_cpus),
                "--num-gpus={}".format(num_gpus)])
    print("Starting head node with command: {}".format(command))
    proc = subprocess.Popen(command,
                            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout_data, _ = wait_for_output(proc)
    container_id = self._get_container_id(stdout_data)
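For reference, a minimal sketch of the list this assembles, under assumed values (a hypothetical image name "ray-test", mem_size="2G", shm_size="1G", num_cpus=4, num_gpus=1, development_mode=False; none of these come from the diff):

    command = ["docker", "run", "-d", "--memory=2G", "--shm-size=1G",
               "ray-test", "/ray/scripts/start_ray.sh", "--head",
               "--redis-port=6379", "--num-cpus=4", "--num-gpus=1"]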
@@ -105,33 +111,34 @@ class DockerRunner(object):
      raise RuntimeError("Failed to find container ID.")
    self.head_container_id = container_id
    self.head_container_ip = self._get_container_ip(container_id)
    print("start_node", {"container_id": container_id,
                         "is_head": True,
                         "shm_size": shm_size,
                         "ip_address": self.head_container_ip})
    return container_id
  def _start_worker_node(self, docker_image, mem_size, shm_size):
  def _start_worker_node(self, docker_image, mem_size, shm_size, num_cpus,
                         num_gpus, development_mode):
    """Start a Ray worker node inside a docker container."""
    mem_arg = ["--memory=" + mem_size] if mem_size else []
    shm_arg = ["--shm-size=" + shm_size] if shm_size else []
    proc = subprocess.Popen(["docker", "run", "-d"] + mem_arg + shm_arg +
                            ["--shm-size=" + shm_size, docker_image,
                             "/ray/scripts/start_ray.sh",
                             "--redis-address={:s}:6379".format(
                                 self.head_container_ip)],
                            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    volume_arg = (["-v",
                   "{}:{}".format(os.path.dirname(os.path.realpath(__file__)),
                                  "/ray/test/jenkins_tests")]
                  if development_mode else [])
    command = (["docker", "run", "-d"] + mem_arg + shm_arg + volume_arg +
               ["--shm-size=" + shm_size, docker_image,
                "/ray/scripts/start_ray.sh",
                "--redis-address={:s}:6379".format(self.head_container_ip),
                "--num-cpus={}".format(num_cpus),
                "--num-gpus={}".format(num_gpus)])
    print("Starting worker node with command: {}".format(command))
    proc = subprocess.Popen(command, stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE)
    stdout_data, _ = wait_for_output(proc)
    container_id = self._get_container_id(stdout_data)
    if container_id is None:
      raise RuntimeError("Failed to find container ID.")
    self.worker_container_ids.append(container_id)
    print("start_node", {"container_id": container_id,
                         "is_head": False,
                         "shm_size": shm_size})
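The wait_for_output helper used throughout this class is defined earlier in multi_node_docker_test.py and is not part of this diff; a minimal sketch of the behavior these calls assume (the decoding is an assumption for illustration):

    def wait_for_output(proc):
        # Block until the subprocess exits, then return its decoded output.
        stdout_data, stderr_data = proc.communicate()
        return stdout_data.decode("ascii"), stderr_data.decode("ascii")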

  def start_ray(self, docker_image, mem_size, shm_size, num_nodes,
                development_mode):
  def start_ray(self, docker_image=None, mem_size=None, shm_size=None,
                num_nodes=None, num_cpus=None, num_gpus=None,
                development_mode=None):
    """Start a Ray cluster within docker.

    This starts one docker container running the head node and num_nodes - 1

@@ -146,15 +153,23 @@ class DockerRunner(object):
        with. This will be passed into `docker run` as the `--shm-size` flag.
      num_nodes: The number of nodes to use in the cluster (this counts the
        head node as well).
      num_cpus: A list of the number of CPUs to start each node with.
      num_gpus: A list of the number of GPUs to start each node with.
      development_mode: True if you want to mount the local copy of
        test/jenkins_test on the head node so we can avoid rebuilding docker
        images during development.
    """
    assert len(num_cpus) == num_nodes
    assert len(num_gpus) == num_nodes

    # Launch the head node.
    self._start_head_node(docker_image, mem_size, shm_size, development_mode)
    self._start_head_node(docker_image, mem_size, shm_size, num_cpus[0],
                          num_gpus[0], development_mode)
    # Start the worker nodes.
    for _ in range(num_nodes - 1):
      self._start_worker_node(docker_image, mem_size, shm_size)
    for i in range(num_nodes - 1):
      self._start_worker_node(docker_image, mem_size, shm_size,
                              num_cpus[1 + i], num_gpus[1 + i],
                              development_mode)

  def _stop_node(self, container_id):
    """Stop a node in the Ray cluster."""
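As a usage sketch of the new per-node lists (the image name and resource values here are hypothetical, not taken from the diff): the head node consumes the first entry of each list and each worker consumes one of the remaining entries.

    d = DockerRunner()
    d.start_ray(docker_image="ray-test", mem_size="2G", shm_size="1G",
                num_nodes=3, num_cpus=[4, 2, 2], num_gpus=[1, 0, 0],
                development_mode=False)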

@@ -181,32 +196,51 @@ class DockerRunner(object):
    for container_id in self.worker_container_ids:
      self._stop_node(container_id)

  def run_test(self, test_script, run_in_docker=False):
  def run_test(self, test_script, num_drivers, driver_locations=None):
    """Run a test script.

    Run a test using the Ray cluster.

    Args:
      test_script: The test script to run.
      run_in_docker: If true then the test script will be run in a docker
        container. If false, it will be run regularly.
      num_drivers: The number of copies of the test script to run.
      driver_locations: A list of the indices of the containers that the
        different copies of the test script should be run on. If this is None,
        then the containers will be chosen randomly.

    Returns:
      A list of dictionaries, one per driver, with information about each
        test script run.
    """
    print("Starting to run test script {}.".format(test_script))
    proc = subprocess.Popen(["docker", "exec", self.head_container_id,
                             "/bin/bash", "-c",
                             "RAY_REDIS_ADDRESS={}:6379 "
                             "python {}".format(self.head_container_ip,
                                                test_script)],
                            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout_data, stderr_data = wait_for_output(proc)
    print("STDOUT:")
    print(stdout_data)
    print("STDERR:")
    print(stderr_data)
    return {"success": proc.returncode == 0, "return_code": proc.returncode}
    all_container_ids = [self.head_container_id] + self.worker_container_ids
    if driver_locations is None:
      driver_locations = [np.random.randint(0, len(all_container_ids))
                          for _ in range(num_drivers)]

    # Start the different drivers.
    driver_processes = []
    for i in range(len(driver_locations)):
      # Get the container ID to run the ith driver in.
      container_id = all_container_ids[driver_locations[i]]
      command = ["docker", "exec", container_id, "/bin/bash", "-c",
                 ("RAY_REDIS_ADDRESS={}:6379 RAY_DRIVER_INDEX={} python {}"
                  .format(self.head_container_ip, i, test_script))]
      print("Starting driver with command {}.".format(command))
      # Start the driver.
      p = subprocess.Popen(command, stdout=subprocess.PIPE,
                           stderr=subprocess.PIPE)
      driver_processes.append(p)

    # Wait for the drivers to finish.
    results = []
    for p in driver_processes:
      stdout_data, stderr_data = wait_for_output(p)
      print("STDOUT:")
      print(stdout_data)
      print("STDERR:")
      print(stderr_data)
      results.append({"success": p.returncode == 0,
                      "return_code": p.returncode})
    return results
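A hedged usage sketch of the new interface (the script path is the one used elsewhere in this commit; the driver locations are illustrative):

    # Run three copies of the driver script, pinning them to the head node
    # (index 0) and the first two workers (indices 1 and 2).
    results = d.run_test("/ray/test/jenkins_tests/multi_node_tests/test_0.py",
                         num_drivers=3, driver_locations=[0, 1, 2])
    assert all(r["success"] for r in results)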

if __name__ == "__main__":
@@ -218,23 +252,49 @@ if __name__ == "__main__":
  parser.add_argument("--shm-size", default="1G", help="shared memory size")
  parser.add_argument("--num-nodes", default=1, type=int,
                      help="number of nodes to use in the cluster")
  parser.add_argument("--num-cpus", type=str,
                      help=("a comma separated list of values representing "
                            "the number of CPUs to start each node with"))
  parser.add_argument("--num-gpus", type=str,
                      help=("a comma separated list of values representing "
                            "the number of GPUs to start each node with"))
  parser.add_argument("--num-drivers", default=1, type=int,
                      help="number of drivers to run")
  parser.add_argument("--driver-locations", type=str,
                      help=("a comma separated list of indices of the "
                            "containers to run the drivers in"))
  parser.add_argument("--test-script", required=True, help="test script")
  parser.add_argument("--development-mode", action="store_true",
                      help="use local copies of the test scripts")
  args = parser.parse_args()

  # Parse the number of CPUs and GPUs to start each node with.
  num_nodes = args.num_nodes
  num_cpus = ([int(i) for i in args.num_cpus.split(",")]
              if args.num_cpus is not None else num_nodes * [10])
  num_gpus = ([int(i) for i in args.num_gpus.split(",")]
              if args.num_gpus is not None else num_nodes * [0])
  # Parse the driver locations, which run_test expects as a list of indices.
  driver_locations = ([int(i) for i in args.driver_locations.split(",")]
                      if args.driver_locations is not None else None)
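As a quick worked example of the parsing above, using the flag values passed later in run_multi_node_tests.sh:

    # "--num-gpus=0,1,2,3,4" with --num-nodes=5 parses to:
    [int(i) for i in "0,1,2,3,4".split(",")]  # == [0, 1, 2, 3, 4]
    # Omitting --num-cpus falls back to the default num_nodes * [10].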

  d = DockerRunner()
  d.start_ray(mem_size=args.mem_size, shm_size=args.shm_size,
              num_nodes=args.num_nodes, docker_image=args.docker_image,
  d.start_ray(docker_image=args.docker_image, mem_size=args.mem_size,
              shm_size=args.shm_size, num_nodes=num_nodes,
              num_cpus=num_cpus, num_gpus=num_gpus,
              development_mode=args.development_mode)
  try:
    run_result = d.run_test(args.test_script)
    run_results = d.run_test(args.test_script, args.num_drivers,
                             driver_locations=driver_locations)
  finally:
    d.stop_ray()

  if "success" in run_result and run_result["success"]:
    print("RESULT: Test {} succeeded.".format(args.test_script))
    sys.exit(0)
  else:
    print("RESULT: Test {} failed.".format(args.test_script))
  any_failed = False
  for run_result in run_results:
    if "success" in run_result and run_result["success"]:
      print("RESULT: Test {} succeeded.".format(args.test_script))
    else:
      print("RESULT: Test {} failed.".format(args.test_script))
      any_failed = True

  if any_failed:
    sys.exit(1)
  else:
    sys.exit(0)

80  test/jenkins_tests/multi_node_tests/many_drivers_test.py  Normal file

@@ -0,0 +1,80 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

import ray
from ray.test.multi_node_tests import (_wait_for_nodes_to_join,
                                       _broadcast_event,
                                       _wait_for_event)

# This test should be run with 5 nodes, which have 0, 0, 5, 6, and 50 GPUs for
# a total of 61 GPUs. It should be run with a large number of drivers (e.g.,
# 100). At most 15 drivers will run at a time, and each driver will use at
# most 5 GPUs (this is ceil(61 / 15), which guarantees that we will always be
# able to make progress).
total_num_nodes = 5
max_concurrent_drivers = 15
num_gpus_per_driver = 5
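# Worked numbers for the progress guarantee above: if all 15 concurrent
# drivers were stuck waiting, each would hold at most 4 GPUs, accounting for
# at most 60 of the 61 available GPUs. At least one GPU is therefore always
# free, so some driver can acquire its fifth GPU and finish, which is why
# num_gpus_per_driver is set to ceil(61 / 15) = 5.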


@ray.actor(num_gpus=1)
class Actor1(object):
  def __init__(self):
    assert len(ray.get_gpu_ids()) == 1

  def check_ids(self):
    assert len(ray.get_gpu_ids()) == 1

def driver(redis_address, driver_index):
  """The script run by each driver.

  Each driver creates five actors that each use one GPU, checks their GPU IDs
  for a while, and then exits.
  """
  ray.init(redis_address=redis_address)

  # Wait for all the nodes to join the cluster.
  _wait_for_nodes_to_join(total_num_nodes)

  # Limit the number of drivers running concurrently.
  for i in range(driver_index - max_concurrent_drivers + 1):
    _wait_for_event("DRIVER_{}_DONE".format(i), redis_address)
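  # Worked example of the gate above: with max_concurrent_drivers = 15,
  # driver 20 evaluates range(20 - 15 + 1) == range(6) and waits for
  # DRIVER_0_DONE through DRIVER_5_DONE, while drivers 0 through 14 wait on
  # an empty range and start immediately. At most 15 drivers can therefore
  # be past this point at once.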

  def try_to_create_actor(actor_class, timeout=100):
    # Try to create an actor, but allow failures while we wait for the monitor
    # to release the resources for the removed drivers.
    start_time = time.time()
    while time.time() - start_time < timeout:
      try:
        actor = actor_class()
      except Exception:
        time.sleep(0.1)
      else:
        return actor
    # If we are here, then we timed out while looping.
    raise Exception("Timed out while trying to create actor.")

  # Create some actors that require one GPU.
  actors_one_gpu = []
  for _ in range(num_gpus_per_driver):
    actors_one_gpu.append(try_to_create_actor(Actor1))

  for _ in range(100):
    ray.get([actor.check_ids() for actor in actors_one_gpu])

  _broadcast_event("DRIVER_{}_DONE".format(driver_index), redis_address)

if __name__ == "__main__":
  driver_index = int(os.environ["RAY_DRIVER_INDEX"])
  redis_address = os.environ["RAY_REDIS_ADDRESS"]
  print("Driver {} started at {}.".format(driver_index, time.time()))

  # In this test, all drivers will run the same script.
  driver(redis_address, driver_index)

  print("Driver {} finished at {}.".format(driver_index, time.time()))
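The _broadcast_event and _wait_for_event helpers come from ray.test.multi_node_tests and are not shown in this diff. A minimal sketch of how such Redis-backed event helpers could work (the list key name and polling interval are assumptions, not the actual implementation):

    import time
    import redis

    def _broadcast_event(event_name, redis_address):
        host, port = redis_address.split(":")
        redis_client = redis.StrictRedis(host=host, port=int(port))
        redis_client.rpush("events", event_name)

    def _wait_for_event(event_name, redis_address):
        host, port = redis_address.split(":")
        redis_client = redis.StrictRedis(host=host, port=int(port))
        while True:
            events = redis_client.lrange("events", 0, -1)
            if event_name.encode("ascii") in events:
                return
            time.sleep(0.1)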

153  test/jenkins_tests/multi_node_tests/remove_driver_test.py  Normal file

@@ -0,0 +1,153 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

import ray
from ray.test.multi_node_tests import (_wait_for_nodes_to_join,
                                       _broadcast_event,
                                       _wait_for_event)

# This test should be run with 5 nodes, which have 0, 1, 2, 3, and 4 GPUs for a
# total of 10 GPUs. It should be run with 3 drivers.
total_num_nodes = 5


@ray.actor
class Actor0(object):
  def __init__(self):
    assert len(ray.get_gpu_ids()) == 0

  def check_ids(self):
    assert len(ray.get_gpu_ids()) == 0


@ray.actor(num_gpus=1)
class Actor1(object):
  def __init__(self):
    assert len(ray.get_gpu_ids()) == 1

  def check_ids(self):
    assert len(ray.get_gpu_ids()) == 1


@ray.actor(num_gpus=2)
class Actor2(object):
  def __init__(self):
    assert len(ray.get_gpu_ids()) == 2

  def check_ids(self):
    assert len(ray.get_gpu_ids()) == 2


def driver_0(redis_address):
  """The script for driver 0.

  This driver should create five actors that each use one GPU and some actors
  that use no GPUs. After a while, it should exit.
  """
  ray.init(redis_address=redis_address)

  # Wait for all the nodes to join the cluster.
  _wait_for_nodes_to_join(total_num_nodes)

  # Create some actors that require one GPU.
  actors_one_gpu = [Actor1() for _ in range(5)]
  # Create some actors that don't require any GPUs.
  actors_no_gpus = [Actor0() for _ in range(5)]

  for _ in range(1000):
    ray.get([actor.check_ids() for actor in actors_one_gpu])
    ray.get([actor.check_ids() for actor in actors_no_gpus])

  _broadcast_event("DRIVER_0_DONE", redis_address)


def driver_1(redis_address):
  """The script for driver 1.

  This driver should create one actor that uses two GPUs, three actors that
  each use one GPU (the one requiring two must be created first), and some
  actors that don't use any GPUs. After a while, it should exit.
  """
  ray.init(redis_address=redis_address)

  # Wait for all the nodes to join the cluster.
  _wait_for_nodes_to_join(total_num_nodes)

  # Create an actor that requires two GPUs.
  actors_two_gpus = [Actor2() for _ in range(1)]
  # Create some actors that require one GPU.
  actors_one_gpu = [Actor1() for _ in range(3)]
  # Create some actors that don't require any GPUs.
  actors_no_gpus = [Actor0() for _ in range(5)]

  for _ in range(1000):
    ray.get([actor.check_ids() for actor in actors_two_gpus])
    ray.get([actor.check_ids() for actor in actors_one_gpu])
    ray.get([actor.check_ids() for actor in actors_no_gpus])

  _broadcast_event("DRIVER_1_DONE", redis_address)


def driver_2(redis_address):
  """The script for driver 2.

  This driver should wait for the first two drivers to finish. Then it should
  create some actors that use a total of ten GPUs.
  """
  ray.init(redis_address=redis_address)

  _wait_for_event("DRIVER_0_DONE", redis_address)
  _wait_for_event("DRIVER_1_DONE", redis_address)

  def try_to_create_actor(actor_class, timeout=20):
    # Try to create an actor, but allow failures while we wait for the monitor
    # to release the resources for the removed drivers.
    start_time = time.time()
    while time.time() - start_time < timeout:
      try:
        actor = actor_class()
      except Exception:
        time.sleep(0.1)
      else:
        return actor
    # If we are here, then we timed out while looping.
    raise Exception("Timed out while trying to create actor.")

  # Create some actors that require two GPUs.
  actors_two_gpus = []
  for _ in range(3):
    actors_two_gpus.append(try_to_create_actor(Actor2))
  # Create some actors that require one GPU.
  actors_one_gpu = []
  for _ in range(4):
    actors_one_gpu.append(try_to_create_actor(Actor1))
  # Create some actors that don't require any GPUs.
  actors_no_gpus = [Actor0() for _ in range(5)]

  for _ in range(1000):
    ray.get([actor.check_ids() for actor in actors_two_gpus])
    ray.get([actor.check_ids() for actor in actors_one_gpu])
    ray.get([actor.check_ids() for actor in actors_no_gpus])

  _broadcast_event("DRIVER_2_DONE", redis_address)


if __name__ == "__main__":
  driver_index = int(os.environ["RAY_DRIVER_INDEX"])
  redis_address = os.environ["RAY_REDIS_ADDRESS"]
  print("Driver {} started at {}.".format(driver_index, time.time()))

  if driver_index == 0:
    driver_0(redis_address)
  elif driver_index == 1:
    driver_1(redis_address)
  elif driver_index == 2:
    driver_2(redis_address)
  else:
    raise Exception("This code should be unreachable.")

  print("Driver {} finished at {}.".format(driver_index, time.time()))

@@ -15,7 +15,11 @@ def f():

if __name__ == "__main__":
  ray.init(redis_address=os.environ["RAY_REDIS_ADDRESS"])
  driver_index = int(os.environ["RAY_DRIVER_INDEX"])
  redis_address = os.environ["RAY_REDIS_ADDRESS"]
  print("Driver {} started at {}.".format(driver_index, time.time()))

  ray.init(redis_address=redis_address)
  # Check that tasks are scheduled on all nodes.
  num_attempts = 30
  for i in range(num_attempts):
@@ -26,3 +30,5 @@ if __name__ == "__main__":
    if len(counts) == 5:
      break
  assert len(counts) == 5

  print("Driver {} finished at {}.".format(driver_index, time.time()))

@@ -13,6 +13,20 @@ python $ROOT_DIR/multi_node_docker_test.py \
    --num-nodes=5 \
    --test-script=/ray/test/jenkins_tests/multi_node_tests/test_0.py

python $ROOT_DIR/multi_node_docker_test.py \
    --docker-image=$DOCKER_SHA \
    --num-nodes=5 \
    --num-gpus=0,1,2,3,4 \
    --num-drivers=3 \
    --test-script=/ray/test/jenkins_tests/multi_node_tests/remove_driver_test.py

python $ROOT_DIR/multi_node_docker_test.py \
    --docker-image=$DOCKER_SHA \
    --num-nodes=5 \
    --num-gpus=0,0,5,6,50 \
    --num-drivers=100 \
    --test-script=/ray/test/jenkins_tests/multi_node_tests/many_drivers_test.py

python $ROOT_DIR/multi_node_docker_test.py \
    --docker-image=$DOCKER_SHA \
    --num-nodes=1 \

@@ -1433,8 +1433,7 @@ class GlobalStateAPI(unittest.TestCase):

    client_table = ray.global_state.client_table()
    node_ip_address = ray.worker.global_worker.node_ip_address
    self.assertEqual(len(client_table[node_ip_address]), 2)
    self.assertEqual(len(client_table[":"]), 1)
    self.assertEqual(len(client_table[node_ip_address]), 3)
    manager_client = [c for c in client_table[node_ip_address]
                      if c["ClientType"] == "plasma_manager"][0]
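A hedged sketch of the client-table shape this test relies on (only the "ClientType" field is shown, the IP address is illustrative, and the local_scheduler and global_scheduler entries are inferred rather than shown in the diff): the change registers the global scheduler under the node's IP address instead of under the unknown-address key ":".

    client_table = {
        "127.0.0.1": [{"ClientType": "local_scheduler"},
                      {"ClientType": "plasma_manager"},
                      {"ClientType": "global_scheduler"}],
    }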