Make Monitor remove dead Redis entries from exiting drivers. (#994)
* WIP: removing OL, OI, TT on client exit; no saving yet.
* ray_redis_module.cc: update header comment.
* Cleanup: just the removal.
* Reformat via yapf: use pep8 style instead of google.
* Checkpoint addressing comments (partially).
* Add 'b' marker before strings (py3 compat).
* Add MonitorTest.
* Use `isort` to sort imports.
* Remove some logging calls.
* Fix flake8 noqa marker in runtest.py.
* Try to separate tests out to monitor_test.py.
* Rework cleanup algorithm: correct logic.
* Extend tests to cover multi-shard cases.
* Add some small comments and formatting changes.

(A key-level sketch of the cleanup idea follows the changed-files summary below.)
Parent: 6e9657e696
Commit: 5a50e80b63
7 changed files with 698 additions and 174 deletions
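For orientation before the per-file diffs: the following is a minimal, runnable sketch (not the PR's code) of the key-level idea. When a driver exits, walk each Redis shard and delete the task-table (TT:), object-info (OI:), and object-location (OL:) entries that belong to that driver. The flat hash fields (driver_id, returns) and the shard ports below are illustrative assumptions only; the real implementation, shown in the monitor.py diff further down, reads TaskInfo flatbuffers instead.

import redis

TT, OI, OL = b"TT:", b"OI:", b"OL:"

def clean_up_driver(shards, driver_id):
    """Delete the TT:/OI:/OL: entries of tasks owned by driver_id."""
    for shard in shards:
        doomed = []
        # SCAN with a cursor so the shard is never blocked by KEYS.
        for key in shard.scan_iter(match=TT + b"*"):
            task = shard.hgetall(key)
            if task.get(b"driver_id") != driver_id:  # assumed field name
                continue
            doomed.append(key)  # the task-table entry itself
            # Assumed field: whitespace-separated IDs of returned objects.
            for obj_id in task.get(b"returns", b"").split():
                doomed.extend([OI + obj_id, OL + obj_id])
        if doomed:
            shard.delete(*doomed)  # best-effort bulk removal

if __name__ == "__main__":
    # Two hypothetical shard ports; adjust for a real deployment.
    shards = [redis.StrictRedis(port=p) for p in (6380, 6381)]
    clean_up_driver(shards, b"driver-0")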
.style.yapf (new file, 190 lines)
@@ -0,0 +1,190 @@
[style]
# Align closing bracket with visual indentation.
align_closing_bracket_with_visual_indent=True

# Allow dictionary keys to exist on multiple lines. For example:
#
#   x = {
#       ('this is the first element of a tuple',
#        'this is the second element of a tuple'):
#            value,
#   }
allow_multiline_dictionary_keys=False

# Allow lambdas to be formatted on more than one line.
allow_multiline_lambdas=False

# Insert a blank line before a class-level docstring.
blank_line_before_class_docstring=False

# Insert a blank line before a 'def' or 'class' immediately nested
# within another 'def' or 'class'. For example:
#
#   class Foo:
#                      # <------ this blank line
#       def method():
#           ...
blank_line_before_nested_class_or_def=False

# Do not split consecutive brackets. Only relevant when
# dedent_closing_brackets is set. For example:
#
#   call_func_that_takes_a_dict(
#       {
#           'key1': 'value1',
#           'key2': 'value2',
#       }
#   )
#
# would reformat to:
#
#   call_func_that_takes_a_dict({
#       'key1': 'value1',
#       'key2': 'value2',
#   })
coalesce_brackets=False

# The column limit.
column_limit=79

# Indent width used for line continuations.
continuation_indent_width=4

# Put closing brackets on a separate line, dedented, if the bracketed
# expression can't fit in a single line. Applies to all kinds of brackets,
# including function definitions and calls. For example:
#
#   config = {
#       'key1': 'value1',
#       'key2': 'value2',
#   }        # <--- this bracket is dedented and on a separate line
#
#   time_series = self.remote_client.query_entity_counters(
#       entity='dev3246.region1',
#       key='dns.query_latency_tcp',
#       transform=Transformation.AVERAGE(window=timedelta(seconds=60)),
#       start_ts=now()-timedelta(days=3),
#       end_ts=now(),
#   )        # <--- this bracket is dedented and on a separate line
dedent_closing_brackets=False

# Place each dictionary entry onto its own line.
each_dict_entry_on_separate_line=True

# The regex for an i18n comment. The presence of this comment stops
# reformatting of that line, because the comments are required to be
# next to the string they translate.
i18n_comment=

# The i18n function call names. The presence of this function stops
# reformatting on that line, because the string it has cannot be moved
# away from the i18n comment.
i18n_function_call=

# Indent the dictionary value if it cannot fit on the same line as the
# dictionary key. For example:
#
#   config = {
#       'key1':
#           'value1',
#       'key2': value1 +
#               value2,
#   }
indent_dictionary_value=False

# The number of columns to use for indentation.
indent_width=4

# Join short lines into one line. E.g., single line 'if' statements.
join_multiple_lines=True

# Do not include spaces around selected binary operators. For example:
#
#   1 + 2 * 3 - 4 / 5
#
# will be formatted as follows when configured with a value "*,/":
#
#   1 + 2*3 - 4/5
#
no_spaces_around_selected_binary_operators=set([])

# Use spaces around default or named assigns.
spaces_around_default_or_named_assign=False

# Use spaces around the power operator.
spaces_around_power_operator=False

# The number of spaces required before a trailing comment.
spaces_before_comment=2

# Insert a space between the ending comma and closing bracket of a list,
# etc.
space_between_ending_comma_and_closing_bracket=True

# Split before arguments if the argument list is terminated by a
# comma.
split_arguments_when_comma_terminated=False

# Set to True to prefer splitting before '&', '|' or '^' rather than
# after.
split_before_bitwise_operator=True

# Split before a dictionary or set generator (comp_for). For example, note
# the split before the 'for':
#
#   foo = {
#       variable: 'Hello world, have a nice day!'
#       for variable in bar if variable != 42
#   }
split_before_dict_set_generator=True

# If an argument / parameter list is going to be split, then split before
# the first argument.
split_before_first_argument=False

# Set to True to prefer splitting before 'and' or 'or' rather than
# after.
split_before_logical_operator=True

# Split named assignments onto individual lines.
split_before_named_assigns=True

# The penalty for splitting right after the opening bracket.
split_penalty_after_opening_bracket=30

# The penalty for splitting the line after a unary operator.
split_penalty_after_unary_operator=10000

# The penalty for splitting right before an if expression.
split_penalty_before_if_expr=0

# The penalty of splitting the line around the '&', '|', and '^'
# operators.
split_penalty_bitwise_operator=300

# The penalty for characters over the column limit.
split_penalty_excess_character=4500

# The penalty incurred by adding a line split to the unwrapped line. The
# more line splits added the higher the penalty.
split_penalty_for_added_line_split=30

# The penalty of splitting a list of "import as" names. For example:
#
#   from a_very_long_or_indented_module_name_yada_yad import (long_argument_1,
#                                                             long_argument_2,
#                                                             long_argument_3)
#
# would reformat to something like:
#
#   from a_very_long_or_indented_module_name_yada_yad import (
#       long_argument_1, long_argument_2, long_argument_3)
split_penalty_import_names=0

# The penalty of splitting the line around the 'and' and 'or'
# operators.
split_penalty_logical_operator=300

# Use the Tab character for indentation.
use_tabs=False
.travis.yml
@@ -117,5 +117,6 @@ script:
   - python test/component_failures_test.py
   - python test/multi_node_test.py
   - python test/recursion_test.py
+  - python test/monitor_test.py
 
   - python -m pytest python/ray/rllib/test/test_catalog.py
python/ray/experimental/state.py
@@ -52,12 +52,18 @@ TASK_STATUS_MAPPING = {
 class GlobalState(object):
     """A class used to interface with the Ray control state.
 
+    # TODO(zongheng): In the future move this to use Ray's redis module in the
+    # backend to cut down on # of request RPCs.
+
     Attributes:
         redis_client: The redis client used to query the redis server.
     """
     def __init__(self):
         """Create a GlobalState object."""
+        # The redis server storing metadata, such as function table, client
+        # table, log files, event logs, workers/actions info.
         self.redis_client = None
+        # A list of redis shards, storing the object table & task table.
         self.redis_clients = None
 
     def _check_connected(self):
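The docstring above distinguishes the primary metadata server (redis_client) from the shard clients (redis_clients). A toy sketch of that two-tier layout follows; the consecutive-port scheme is purely an illustrative assumption, since how shard addresses are discovered is not shown in this hunk.

import redis

class TinyGlobalState(object):
    """Toy stand-in for GlobalState's two-tier Redis layout."""

    def __init__(self, host, primary_port, num_shards):
        # Primary server: function table, client table, log files, etc.
        self.redis_client = redis.StrictRedis(host=host, port=primary_port)
        # Shards: object table (OI:/OL:) and task table (TT:).
        self.redis_clients = [
            redis.StrictRedis(host=host, port=primary_port + 1 + i)
            for i in range(num_shards)
        ]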
python/ray/monitor.py
@@ -3,22 +3,22 @@ from __future__ import division
 from __future__ import print_function
 
 import argparse
-from collections import Counter
 import json
 import logging
-import redis
 import time
+from collections import Counter, defaultdict
 
 import ray
 import ray.utils
-from ray.services import get_ip_address, get_port
-from ray.utils import binary_to_object_id, binary_to_hex, hex_to_binary
-from ray.worker import NIL_ACTOR_ID
+import redis
 
 # Import flatbuffer bindings.
-from ray.core.generated.SubscribeToDBClientTableReply \
-    import SubscribeToDBClientTableReply
 from ray.core.generated.DriverTableMessage import DriverTableMessage
+from ray.core.generated.SubscribeToDBClientTableReply import \
+    SubscribeToDBClientTableReply
+from ray.core.generated.TaskInfo import TaskInfo
+from ray.services import get_ip_address, get_port
+from ray.utils import binary_to_hex, binary_to_object_id, hex_to_binary
+from ray.worker import NIL_ACTOR_ID
 
 # These variables must be kept in sync with the C codebase.
 # common/common.h
@@ -26,17 +26,24 @@ HEARTBEAT_TIMEOUT_MILLISECONDS = 100
 NUM_HEARTBEATS_TIMEOUT = 100
+DB_CLIENT_ID_SIZE = 20
+NIL_ID = b"\xff" * DB_CLIENT_ID_SIZE
 
+# common/task.h
+TASK_STATUS_LOST = 32
+
 # common/state/redis.cc
 PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers"
 DRIVER_DEATH_CHANNEL = b"driver_deaths"
 
 # common/redis_module/ray_redis_module.cc
-OBJECT_PREFIX = "OL:"
-DB_CLIENT_PREFIX = "CL:"
+OBJECT_INFO_PREFIX = b"OI:"
+OBJECT_LOCATION_PREFIX = b"OL:"
+TASK_TABLE_PREFIX = b"TT:"
+DB_CLIENT_PREFIX = b"CL:"
 DB_CLIENT_TABLE_NAME = b"db_clients"
 
 # local_scheduler/local_scheduler.h
 LOCAL_SCHEDULER_CLIENT_TYPE = b"local_scheduler"
 
 # plasma/plasma_manager.cc
 PLASMA_MANAGER_CLIENT_TYPE = b"plasma_manager"
@@ -69,12 +76,13 @@ class Monitor(object):
         dead_plasma_managers: A set of the plasma manager IDs of all the plasma
             managers that were up at one point and have died since then.
     """
+
     def __init__(self, redis_address, redis_port):
         # Initialize the Redis clients.
         self.state = ray.experimental.state.GlobalState()
         self.state._initialize_global_state(redis_address, redis_port)
-        self.redis = redis.StrictRedis(host=redis_address, port=redis_port,
-                                       db=0)
+        self.redis = redis.StrictRedis(
+            host=redis_address, port=redis_port, db=0)
         # TODO(swang): Update pubsub client to use ray.experimental.state once
         # subscriptions are implemented there.
         self.subscribe_client = self.redis.pubsub()
@@ -109,8 +117,9 @@ class Monitor(object):
                 info["local_scheduler_id"] in self.dead_local_schedulers):
             # Choose a new local scheduler to run the actor.
             local_scheduler_id = ray.utils.select_local_scheduler(
-                info["driver_id"], self.state.local_schedulers(),
-                info["num_gpus"], self.redis)
+                info["driver_id"],
+                self.state.local_schedulers(), info["num_gpus"],
+                self.redis)
             import sys
             sys.stdout.flush()
             # The new local scheduler should not be the same as the old
@@ -121,8 +130,9 @@ class Monitor(object):
             # Announce to all of the local schedulers that the actor should
             # be recreated on this new local scheduler.
             ray.utils.publish_actor_creation(
-                hex_to_binary(actor_id), hex_to_binary(info["driver_id"]),
-                local_scheduler_id, True, self.redis)
+                hex_to_binary(actor_id),
+                hex_to_binary(info["driver_id"]), local_scheduler_id, True,
+                self.redis)
             log.info("Actor {} for driver {} was on dead local scheduler "
                      "{}. It is being recreated on local scheduler {}"
                      .format(actor_id, info["driver_id"],
@@ -160,7 +170,7 @@ class Monitor(object):
                     # The dummy object should exist on at most one plasma
                     # manager, the manager associated with the local scheduler
                     # that died.
-                    assert(len(manager_ids) <= 1)
+                    assert len(manager_ids) <= 1
                     # Remove the dummy object from the plasma manager
                     # associated with the dead local scheduler, if any.
                     for manager in manager_ids:
@@ -175,7 +185,8 @@ class Monitor(object):
             # task as lost.
             key = binary_to_object_id(hex_to_binary(task_id))
             ok = self.state._execute_command(
-                key, "RAY.TASK_TABLE_UPDATE", hex_to_binary(task_id),
+                key, "RAY.TASK_TABLE_UPDATE",
+                hex_to_binary(task_id),
                 ray.experimental.state.TASK_STATUS_LOST, NIL_ID)
             if ok != b"OK":
                 log.warn("Failed to update lost task for dead scheduler.")
@@ -238,7 +249,7 @@ class Monitor(object):
         log.debug("Subscribed to {}, data was {}".format(channel, data))
         self.subscribed[channel] = True
 
-    def db_client_notification_handler(self, channel, data):
+    def db_client_notification_handler(self, unused_channel, data):
         """Handle a notification from the db_client table from Redis.
 
         This handler processes notifications from the db_client table.
@@ -247,9 +258,8 @@ class Monitor(object):
         the associated state in the state tables should be handled by the
         caller.
         """
-        notification_object = (SubscribeToDBClientTableReply
-                               .GetRootAsSubscribeToDBClientTableReply(data,
-                                                                       0))
+        notification_object = (SubscribeToDBClientTableReply.
+                               GetRootAsSubscribeToDBClientTableReply(data, 0))
         db_client_id = binary_to_hex(notification_object.DbClientId())
         client_type = notification_object.ClientType()
         is_insertion = notification_object.IsInsertion()
@@ -271,7 +281,7 @@ class Monitor(object):
             # already dead.
             del self.live_plasma_managers[db_client_id]
 
-    def plasma_manager_heartbeat_handler(self, channel, data):
+    def plasma_manager_heartbeat_handler(self, unused_channel, data):
         """Handle a plasma manager heartbeat from Redis.
 
         This resets the number of heartbeats that we've missed from this plasma
@@ -283,7 +293,134 @@ class Monitor(object):
         # manager.
         self.live_plasma_managers[db_client_id] = 0
 
-    def driver_removed_handler(self, channel, data):
+    def _entries_for_driver_in_shard(self, driver_id, redis_shard_index):
+        """Collect IDs of control-state entries for a driver from a shard.
+
+        Args:
+            driver_id: The ID of the driver.
+            redis_shard_index: The index of the Redis shard to query.
+
+        Returns:
+            Lists of IDs: (returned_object_ids, task_ids, put_objects). The
+                first two are relevant to the driver and are safe to delete.
+                The last contains all "put" objects in this redis shard; each
+                element is an (object_id, corresponding task_id) pair.
+        """
+        # TODO(zongheng): consider adding save & restore functionalities.
+        redis = self.state.redis_clients[redis_shard_index]
+        task_table_infos = {}  # task id -> TaskInfo messages
+
+        # Scan the task table & filter to get the list of tasks belonging to
+        # this driver. Use a cursor in order not to block the redis shards.
+        for key in redis.scan_iter(match=TASK_TABLE_PREFIX + b"*"):
+            entry = redis.hgetall(key)
+            task_info = TaskInfo.GetRootAsTaskInfo(entry[b"TaskSpec"], 0)
+            if driver_id != task_info.DriverId():
+                # Ignore tasks that aren't from this driver.
+                continue
+            task_table_infos[task_info.TaskId()] = task_info
+
+        # Get the list of objects returned by these tasks. Note these might
+        # not belong to this redis shard.
+        returned_object_ids = []
+        for task_info in task_table_infos.values():
+            returned_object_ids.extend([
+                task_info.Returns(i) for i in range(task_info.ReturnsLength())
+            ])
+
+        # Also record all the ray.put()'d objects.
+        put_objects = []
+        for key in redis.scan_iter(match=OBJECT_INFO_PREFIX + b"*"):
+            entry = redis.hgetall(key)
+            if entry[b"is_put"] == b"0":
+                continue
+            object_id = key.split(OBJECT_INFO_PREFIX)[1]
+            task_id = entry[b"task"]
+            put_objects.append((object_id, task_id))
+
+        return returned_object_ids, task_table_infos.keys(), put_objects
+
+    def _clean_up_entries_from_shard(self, object_ids, task_ids, shard_index):
+        redis = self.state.redis_clients[shard_index]
+        # Clean up (in the future, save) entries for non-empty objects.
+        object_ids_locs = set()
+        object_ids_infos = set()
+        for object_id in object_ids:
+            # OL.
+            obj_loc = redis.zrange(OBJECT_LOCATION_PREFIX + object_id, 0, -1)
+            if obj_loc:
+                object_ids_locs.add(object_id)
+            # OI.
+            obj_info = redis.hgetall(OBJECT_INFO_PREFIX + object_id)
+            if obj_info:
+                object_ids_infos.add(object_id)
+
+        # Form the redis keys to delete.
+        keys = [TASK_TABLE_PREFIX + k for k in task_ids]
+        keys.extend([OBJECT_LOCATION_PREFIX + k for k in object_ids_locs])
+        keys.extend([OBJECT_INFO_PREFIX + k for k in object_ids_infos])
+
+        if not keys:
+            return
+        # Remove with best effort.
+        num_deleted = redis.delete(*keys)
+        log.info(
+            "Removed {} dead redis entries of the driver from redis shard {}.".
+            format(num_deleted, shard_index))
+        if num_deleted != len(keys):
+            log.warning(
+                "Failed to remove {} relevant redis entries"
+                " from redis shard {}.".format(len(keys) - num_deleted,
+                                               shard_index))
+
+    def _clean_up_entries_for_driver(self, driver_id):
+        """Remove this driver's object/task entries from all redis shards.
+
+        Specifically, removes control-state entries of:
+            * all objects (OI and OL entries) created by `ray.put()` from the
+              driver
+            * all tasks belonging to the driver.
+        """
+        # TODO(zongheng): handle function_table, client_table, log_files --
+        # these are in the metadata redis server, not in the shards.
+        driver_object_ids = []
+        driver_task_ids = []
+        all_put_objects = []
+
+        # Collect relevant ids.
+        # TODO(zongheng): consider parallelizing this loop.
+        for shard_index in range(len(self.state.redis_clients)):
+            returned_object_ids, task_ids, put_objects = \
+                self._entries_for_driver_in_shard(driver_id, shard_index)
+            driver_object_ids.extend(returned_object_ids)
+            driver_task_ids.extend(task_ids)
+            all_put_objects.extend(put_objects)
+
+        # For the put objects, keep those from relevant tasks.
+        driver_task_ids_set = set(driver_task_ids)
+        for object_id, task_id in all_put_objects:
+            if task_id in driver_task_ids_set:
+                driver_object_ids.append(object_id)
+
+        # Partition IDs and distribute to shards.
+        object_ids_per_shard = defaultdict(list)
+        task_ids_per_shard = defaultdict(list)
+
+        def ToShardIndex(index):
+            return binary_to_object_id(index).redis_shard_hash() % len(
+                self.state.redis_clients)
+
+        for object_id in driver_object_ids:
+            object_ids_per_shard[ToShardIndex(object_id)].append(object_id)
+        for task_id in driver_task_ids:
+            task_ids_per_shard[ToShardIndex(task_id)].append(task_id)
+
+        # TODO(zongheng): consider parallelizing this loop.
+        for shard_index in range(len(self.state.redis_clients)):
+            self._clean_up_entries_from_shard(
+                object_ids_per_shard[shard_index],
+                task_ids_per_shard[shard_index], shard_index)
+
+    def driver_removed_handler(self, unused_channel, data):
         """Handle a notification that a driver has been removed.
 
         This releases any GPU resources that were reserved for that driver in
@@ -291,8 +428,8 @@ class Monitor(object):
         """
         message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0)
         driver_id = message.DriverId()
-        log.info("Driver {} has been removed."
-                 .format(binary_to_hex(driver_id)))
+        log.info(
+            "Driver {} has been removed.".format(binary_to_hex(driver_id)))
 
         # Get a list of the local schedulers.
         client_table = ray.global_state.client_table()
@@ -302,6 +439,8 @@ class Monitor(object):
             if client["ClientType"] == "local_scheduler":
                 local_schedulers.append(client)
 
+        self._clean_up_entries_for_driver(driver_id)
+
         # Release any GPU resources that have been reserved for this driver in
         # Redis.
         for local_scheduler in local_schedulers:
@@ -321,8 +460,8 @@ class Monitor(object):
 
                     result = pipe.hget(local_scheduler_id,
                                        "gpus_in_use")
-                    gpus_in_use = (dict() if result is None
-                                   else json.loads(result))
+                    gpus_in_use = (dict() if result is None else
                                   json.loads(result))
 
                     driver_id_hex = binary_to_hex(driver_id)
                     if driver_id_hex in gpus_in_use:
@@ -345,9 +484,9 @@ class Monitor(object):
                         continue
 
                     log.info("Driver {} is returning GPU IDs {} to local "
-                             "scheduler {}.".format(binary_to_hex(driver_id),
-                                                    num_gpus_returned,
-                                                    local_scheduler_id))
+                             "scheduler {}.".format(
+                                 binary_to_hex(driver_id), num_gpus_returned,
+                                 local_scheduler_id))
 
     def process_messages(self):
         """Process all messages ready in the subscription channels.
@@ -371,22 +510,23 @@ class Monitor(object):
                 # to an initial subscription request.
                 message_handler = self.subscribe_handler
             elif channel == PLASMA_MANAGER_HEARTBEAT_CHANNEL:
-                assert(self.subscribed[channel])
+                assert self.subscribed[channel]
                 # The message was a heartbeat from a plasma manager.
                 message_handler = self.plasma_manager_heartbeat_handler
             elif channel == DB_CLIENT_TABLE_NAME:
-                assert(self.subscribed[channel])
+                assert self.subscribed[channel]
                 # The message was a notification from the db_client table.
                 message_handler = self.db_client_notification_handler
             elif channel == DRIVER_DEATH_CHANNEL:
-                assert(self.subscribed[channel])
+                assert self.subscribed[channel]
                 # The message was a notification that a driver was removed.
+                log.info("message-handler: driver_removed_handler")
                 message_handler = self.driver_removed_handler
             else:
                 raise Exception("This code should be unreachable.")
 
             # Call the handler.
-            assert(message_handler is not None)
+            assert (message_handler is not None)
             message_handler(channel, data)
 
     def run(self):
@@ -439,8 +579,8 @@ class Monitor(object):
             # Handle plasma managers that timed out during this round.
             plasma_manager_ids = list(self.live_plasma_managers.keys())
             for plasma_manager_id in plasma_manager_ids:
-                if ((self.live_plasma_managers
-                     [plasma_manager_id]) >= NUM_HEARTBEATS_TIMEOUT):
+                if ((self.live_plasma_managers[plasma_manager_id]) >=
+                        NUM_HEARTBEATS_TIMEOUT):
                     log.warn("Timed out {}".format(PLASMA_MANAGER_CLIENT_TYPE))
                     # Remove the plasma manager from the managers whose
                     # heartbeats we're tracking.
@@ -465,8 +605,11 @@ class Monitor(object):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description=("Parse Redis server for the "
                                                   "monitor to connect to."))
-    parser.add_argument("--redis-address", required=True, type=str,
-                        help="the address to use for Redis")
+    parser.add_argument(
+        "--redis-address",
+        required=True,
+        type=str,
+        help="the address to use for Redis")
     args = parser.parse_args()
 
     redis_ip_address = get_ip_address(args.redis_address)
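One detail of the cleanup above worth isolating: IDs are collected from every shard but must be deleted on the shard that owns them, so _clean_up_entries_for_driver routes each ID through binary_to_object_id(id).redis_shard_hash() % num_shards (the ToShardIndex helper). Below is a self-contained sketch of that partitioning step; crc32 stands in for Ray's shard hash, an assumption made only so the sketch runs without Ray installed.

import zlib
from collections import defaultdict

def to_shard_index(binary_id, num_shards):
    # Stand-in for binary_to_object_id(binary_id).redis_shard_hash().
    return zlib.crc32(binary_id) % num_shards

def partition_ids(ids, num_shards):
    """Group IDs by the shard that owns them, as the monitor does."""
    per_shard = defaultdict(list)
    for binary_id in ids:
        per_shard[to_shard_index(binary_id, num_shards)].append(binary_id)
    return per_shard

if __name__ == "__main__":
    ids = [b"object-%d" % n for n in range(8)]
    print(dict(partition_ids(ids, num_shards=3)))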
src/common/redis_module/ray_redis_module.cc
@@ -8,22 +8,25 @@
 
 #include "common_protocol.h"
 
-/**
- * Various tables are maintained in redis:
- *
- * == OBJECT TABLE ==
- *
- * This consists of two parts:
- * - The object location table, indexed by OL:object_id, which is the set of
- *   plasma manager indices that have access to the object.
- * - The object info table, indexed by OI:object_id, which is a hashmap with key
- *   "hash" for the hash of the object and key "data_size" for the size of the
- *   object in bytes.
- *
- * == TASK TABLE ==
- *
- * TODO(pcm): Fill this out.
- */
+// Various tables are maintained in redis:
+//
+// == OBJECT TABLE ==
+//
+// This consists of two parts:
+// - The object location table, indexed by OL:object_id, which is the set of
+//   plasma manager indices that have access to the object.
+//   (In redis this is represented by a zset (sorted set).)
+//
+// - The object info table, indexed by OI:object_id, which is a hashmap of:
+//     "hash" -> the hash of the object,
+//     "data_size" -> the size of the object in bytes,
+//     "task" -> the task ID that generated this object.
+//     "is_put" -> 0 or 1.
+//
+// == TASK TABLE ==
+//
+// TODO(pcm): Fill this out.
+//
 
 #define OBJECT_INFO_PREFIX "OI:"
 #define OBJECT_LOCATION_PREFIX "OL:"
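To make the schema documented in the rewritten header concrete, here is a small redis-py sketch that materializes one object's entries. The key prefixes and field names come from the comment above; every value is made up for illustration.

import redis

r = redis.StrictRedis()
object_id = b"\x01" * 20  # illustrative 20-byte object ID

# OI:<object_id> is a hashmap: hash, data_size, task, is_put.
r.hset(b"OI:" + object_id, mapping={
    b"hash": b"\xab" * 20,
    b"data_size": 1024,
    b"task": b"\x02" * 20,  # the task that generated this object
    b"is_put": 1,
})
# OL:<object_id> is a zset of the plasma managers that hold the object.
r.zadd(b"OL:" + object_id, {b"manager-0": 0})

print(r.hgetall(b"OI:" + object_id))
print(r.zrange(b"OL:" + object_id, 0, -1))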
test/monitor_test.py (new file, 93 lines)
@@ -0,0 +1,93 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import multiprocessing
import subprocess
import time
import unittest

import ray


class MonitorTest(unittest.TestCase):
    def _testCleanupOnDriverExit(self, num_redis_shards):
        stdout = subprocess.check_output([
            "ray",
            "start",
            "--head",
            "--num-redis-shards",
            str(num_redis_shards),
        ]).decode("ascii")
        lines = [m.strip() for m in stdout.split("\n")]
        init_cmd = [m for m in lines if m.startswith("ray.init")]
        self.assertEqual(1, len(init_cmd))
        redis_address = init_cmd[0].split("redis_address=\"")[-1][:-2]

        def StateSummary():
            obj_tbl_len = len(ray.global_state.object_table())
            task_tbl_len = len(ray.global_state.task_table())
            func_tbl_len = len(ray.global_state.function_table())
            return obj_tbl_len, task_tbl_len, func_tbl_len

        def Driver(success):
            success.value = True
            # Start driver.
            ray.init(redis_address=redis_address)
            summary_start = StateSummary()
            if (0, 1) != summary_start[:2]:
                success.value = False

            # Two new objects.
            ray.get(ray.put(1111))
            ray.get(ray.put(1111))
            if (2, 1, summary_start[2]) != StateSummary():
                success.value = False

            @ray.remote
            def f():
                ray.put(1111)  # Yet another object.
                return 1111  # A returned object as well.

            # 1 new function.
            if (2, 1, summary_start[2] + 1) != StateSummary():
                success.value = False

            ray.get(f.remote())
            if (4, 2, summary_start[2] + 1) != StateSummary():
                success.value = False

            ray.worker.cleanup()

        success = multiprocessing.Value('b', False)
        driver = multiprocessing.Process(target=Driver, args=(success, ))
        driver.start()
        # Wait for client to exit.
        driver.join()
        time.sleep(5)

        # Just make sure Driver() is run and succeeded. Note(rkn), if the below
        # assertion starts failing, then the issue may be that the summary
        # values computed in the Driver function are being updated slowly and
        # so the call to StateSummary() is getting outdated values. This could
        # be fixed by looping until StateSummary() returns the desired values.
        self.assertTrue(success.value)
        # Check that objects, tasks, and functions are cleaned up.
        ray.init(redis_address=redis_address)
        # The assertion below can fail if the monitor is too slow to clean up
        # the global state.
        self.assertEqual((0, 1), StateSummary()[:2])

        ray.worker.cleanup()
        subprocess.Popen(["ray", "stop"]).wait()

    def testCleanupOnDriverExitSingleRedisShard(self):
        self._testCleanupOnDriverExit(num_redis_shards=1)

    def testCleanupOnDriverExitManyRedisShards(self):
        self._testCleanupOnDriverExit(num_redis_shards=5)
        self._testCleanupOnDriverExit(num_redis_shards=31)


if __name__ == "__main__":
    unittest.main(verbosity=2)
test/runtest.py (324 lines changed)
@@ -1,18 +1,17 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+from __future__ import absolute_import, division, print_function
 
-from collections import defaultdict, namedtuple
-import numpy as np
 import os
-import ray
 import re
 import shutil
 import string
 import sys
 import time
 import unittest
+from collections import defaultdict, namedtuple
+
+import numpy as np
 
+import ray
 import ray.test.test_functions as test_functions
 import ray.test.test_utils
@@ -21,11 +20,11 @@ if sys.version_info >= (3, 0):
 
 
 def assert_equal(obj1, obj2):
-    module_numpy = (type(obj1).__module__ == np.__name__ or
-                    type(obj2).__module__ == np.__name__)
+    module_numpy = (type(obj1).__module__ == np.__name__
+                    or type(obj2).__module__ == np.__name__)
     if module_numpy:
-        empty_shape = ((hasattr(obj1, "shape") and obj1.shape == ()) or
-                       (hasattr(obj2, "shape") and obj2.shape == ()))
+        empty_shape = ((hasattr(obj1, "shape") and obj1.shape == ())
+                       or (hasattr(obj2, "shape") and obj2.shape == ()))
         if empty_shape:
             # This is a special case because currently np.testing.assert_equal
             # fails because we do not properly handle different numerical
@@ -36,13 +35,11 @@ def assert_equal(obj1, obj2):
             np.testing.assert_equal(obj1, obj2)
     elif hasattr(obj1, "__dict__") and hasattr(obj2, "__dict__"):
         special_keys = ["_pytype_"]
-        assert (set(list(obj1.__dict__.keys()) + special_keys) ==
-                set(list(obj2.__dict__.keys()) + special_keys)), ("Objects {} "
-                                                                  "and {} are "
-                                                                  "different."
-                                                                  .format(
-                                                                      obj1,
-                                                                      obj2))
+        assert (set(list(obj1.__dict__.keys()) + special_keys) == set(
+            list(obj2.__dict__.keys()) + special_keys)), ("Objects {} "
+                                                          "and {} are "
+                                                          "different.".format(
+                                                              obj1, obj2))
         for key in obj1.__dict__.keys():
             if key not in special_keys:
                 assert_equal(obj1.__dict__[key], obj2.__dict__[key])
@@ -52,49 +49,76 @@
             assert_equal(obj1[key], obj2[key])
     elif type(obj1) is list or type(obj2) is list:
         assert len(obj1) == len(obj2), ("Objects {} and {} are lists with "
-                                        "different lengths."
-                                        .format(obj1, obj2))
+                                        "different lengths.".format(
+                                            obj1, obj2))
         for i in range(len(obj1)):
             assert_equal(obj1[i], obj2[i])
     elif type(obj1) is tuple or type(obj2) is tuple:
         assert len(obj1) == len(obj2), ("Objects {} and {} are tuples with "
-                                        "different lengths."
-                                        .format(obj1, obj2))
+                                        "different lengths.".format(
+                                            obj1, obj2))
         for i in range(len(obj1)):
             assert_equal(obj1[i], obj2[i])
-    elif (ray.serialization.is_named_tuple(type(obj1)) or
-          ray.serialization.is_named_tuple(type(obj2))):
+    elif (ray.serialization.is_named_tuple(type(obj1))
+          or ray.serialization.is_named_tuple(type(obj2))):
         assert len(obj1) == len(obj2), ("Objects {} and {} are named tuples "
-                                        "with different lengths."
-                                        .format(obj1, obj2))
+                                        "with different lengths.".format(
+                                            obj1, obj2))
         for i in range(len(obj1)):
             assert_equal(obj1[i], obj2[i])
     else:
-        assert obj1 == obj2, "Objects {} and {} are different.".format(obj1,
-                                                                       obj2)
+        assert obj1 == obj2, "Objects {} and {} are different.".format(
+            obj1, obj2)
 
 
 if sys.version_info >= (3, 0):
     long_extras = [0, np.array([["hi", u"hi"], [1.3, 1]])]
 else:
-    long_extras = [long(0), np.array([["hi", u"hi"], [1.3, long(1)]])]  # noqa: E501,F821
+    long_extras = [
+        long(0),  # noqa: E501,F821
+        np.array([
+            ["hi", u"hi"],
+            [1.3, long(1)]  # noqa: E501,F821
+        ])
+    ]
 
-PRIMITIVE_OBJECTS = [0, 0.0, 0.9, 1 << 62, 1 << 100, 1 << 999,
-                     [1 << 100, [1 << 100]], "a", string.printable, "\u262F",
-                     u"hello world", u"\xff\xfe\x9c\x001\x000\x00", None, True,
-                     False, [], (), {}, np.int8(3), np.int32(4), np.int64(5),
-                     np.uint8(3), np.uint32(4), np.uint64(5), np.float32(1.9),
-                     np.float64(1.9), np.zeros([100, 100]),
-                     np.random.normal(size=[100, 100]), np.array(["hi", 3]),
-                     np.array(["hi", 3], dtype=object)] + long_extras
+PRIMITIVE_OBJECTS = [
+    0, 0.0, 0.9, 1 << 62, 1 << 100, 1 << 999, [1 << 100, [1 << 100]], "a",
+    string.printable, "\u262F", u"hello world", u"\xff\xfe\x9c\x001\x000\x00",
+    None, True, False, [], (), {},
+    np.int8(3),
+    np.int32(4),
+    np.int64(5),
+    np.uint8(3),
+    np.uint32(4),
+    np.uint64(5),
+    np.float32(1.9),
+    np.float64(1.9),
+    np.zeros([100, 100]),
+    np.random.normal(size=[100, 100]),
+    np.array(["hi", 3]),
+    np.array(["hi", 3], dtype=object)
+] + long_extras
 
 COMPLEX_OBJECTS = [
     [[[[[[[[[[[[]]]]]]]]]]]],
-    {"obj{}".format(i): np.random.normal(size=[100, 100]) for i in range(10)},
+    {"obj{}".format(i): np.random.normal(size=[100, 100])
+     for i in range(10)},
     # {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {
     #      (): {(): {}}}}}}}}}}}}},
-    ((((((((((),),),),),),),),),),
-    {"a": {"b": {"c": {"d": {}}}}}]
+    (
+        (((((((((), ), ), ), ), ), ), ), ), ),
+    {
+        "a": {
+            "b": {
+                "c": {
+                    "d": {}
+                }
+            }
+        }
+    }
]
 
 
 class Foo(object):
@@ -141,21 +165,32 @@ Point = namedtuple("Point", ["x", "y"])
 NamedTupleExample = namedtuple("Example",
                                "field1, field2, field3, field4, field5")
 
-CUSTOM_OBJECTS = [Exception("Test object."), CustomError(), Point(11, y=22),
-                  Foo(), Bar(), Baz(),  # Qux(), SubQux(),
-                  NamedTupleExample(1, 1.0, "hi", np.zeros([3, 5]), [1, 2, 3])]
+CUSTOM_OBJECTS = [
+    Exception("Test object."),
+    CustomError(),
+    Point(11, y=22),
+    Foo(),
+    Bar(),
+    Baz(),  # Qux(), SubQux(),
+    NamedTupleExample(1, 1.0, "hi", np.zeros([3, 5]), [1, 2, 3])
+]
 
 BASE_OBJECTS = PRIMITIVE_OBJECTS + COMPLEX_OBJECTS + CUSTOM_OBJECTS
 
 LIST_OBJECTS = [[obj] for obj in BASE_OBJECTS]
-TUPLE_OBJECTS = [(obj,) for obj in BASE_OBJECTS]
+TUPLE_OBJECTS = [(obj, ) for obj in BASE_OBJECTS]
 # The check that type(obj).__module__ != "numpy" should be unnecessary, but
 # otherwise this seems to fail on Mac OS X on Travis.
-DICT_OBJECTS = ([{obj: obj} for obj in PRIMITIVE_OBJECTS
-                 if (obj.__hash__ is not None and
-                     type(obj).__module__ != "numpy")] +
-                [{0: obj} for obj in BASE_OBJECTS] +
-                [{Foo(123): Foo(456)}])
+DICT_OBJECTS = (
+    [{
+        obj: obj
+    } for obj in PRIMITIVE_OBJECTS
+     if (obj.__hash__ is not None and type(obj).__module__ != "numpy")] + [{
+         0:
+         obj
+     } for obj in BASE_OBJECTS] + [{
+         Foo(123): Foo(456)
+     }])
 
 RAY_TEST_OBJECTS = BASE_OBJECTS + LIST_OBJECTS + TUPLE_OBJECTS + DICT_OBJECTS
@@ -171,7 +206,6 @@ except AttributeError:
 
 
 class SerializationTest(unittest.TestCase):
-
     def testRecursiveObjects(self):
         ray.init(num_workers=0)
@@ -253,15 +287,15 @@ class SerializationTest(unittest.TestCase):
 
 
 class WorkerTest(unittest.TestCase):
-
     def testPythonWorkers(self):
         # Test the codepath for starting workers from the Python script,
         # instead of the local scheduler. This codepath is for debugging
         # purposes only.
         num_workers = 4
-        ray.worker._init(num_workers=num_workers,
-                         start_workers_from_local_scheduler=False,
-                         start_ray_local=True)
+        ray.worker._init(
+            num_workers=num_workers,
+            start_workers_from_local_scheduler=False,
+            start_ray_local=True)
 
         @ray.remote
         def f(x):
@@ -275,13 +309,13 @@ class WorkerTest(unittest.TestCase):
         ray.init(num_workers=0)
 
         for i in range(100):
-            value_before = i * 10 ** 6
+            value_before = i * 10**6
             objectid = ray.put(value_before)
             value_after = ray.get(objectid)
             self.assertEqual(value_before, value_after)
 
         for i in range(100):
-            value_before = i * 10 ** 6 * 1.0
+            value_before = i * 10**6 * 1.0
             objectid = ray.put(value_before)
             value_after = ray.get(objectid)
             self.assertEqual(value_before, value_after)
@@ -302,7 +336,6 @@ class WorkerTest(unittest.TestCase):
 
 
 class APITest(unittest.TestCase):
-
     def init_ray(self, kwargs=None):
         if kwargs is None:
             kwargs = {}
@@ -318,6 +351,7 @@ class APITest(unittest.TestCase):
         # throws an exception.
         class TempClass(object):
             pass
+
         ray.get(ray.put(TempClass()))
 
         # Note that the below actually returns a dictionary and not a
@@ -525,14 +559,14 @@ class APITest(unittest.TestCase):
             return x, y, args
 
         self.assertEqual(ray.get(f1.remote()), ())
-        self.assertEqual(ray.get(f1.remote(1)), (1,))
+        self.assertEqual(ray.get(f1.remote(1)), (1, ))
         self.assertEqual(ray.get(f1.remote(1, 2, 3)), (1, 2, 3))
         with self.assertRaises(Exception):
             f2.remote()
         with self.assertRaises(Exception):
             f2.remote(1)
         self.assertEqual(ray.get(f2.remote(1, 2)), (1, 2, ()))
-        self.assertEqual(ray.get(f2.remote(1, 2, 3)), (1, 2, (3,)))
+        self.assertEqual(ray.get(f2.remote(1, 2, 3)), (1, 2, (3, )))
         self.assertEqual(ray.get(f2.remote(1, 2, 3, 4)), (1, 2, (3, 4)))
 
     def testNoArgs(self):
@@ -548,12 +582,14 @@ class APITest(unittest.TestCase):
         @ray.remote
         def f(x):
             return x + 1
+
         self.assertEqual(ray.get(f.remote(0)), 1)
 
         # Test that we can redefine the remote function.
         @ray.remote
         def f(x):
             return x + 10
+
         while True:
             val = ray.get(f.remote(0))
             self.assertTrue(val in [1, 10])
@@ -563,23 +599,29 @@ class APITest(unittest.TestCase):
                 print("Still using old definition of f, trying again.")
 
         # Test that we can close over plain old data.
-        data = [np.zeros([3, 5]), (1, 2, "a"), [0.0, 1.0, 1 << 62], 1 << 60,
-                {"a": np.zeros(3)}]
+        data = [
+            np.zeros([3, 5]), (1, 2, "a"), [0.0, 1.0, 1 << 62], 1 << 60, {
+                "a": np.zeros(3)
+            }
+        ]
 
         @ray.remote
         def g():
             return data
+
         ray.get(g.remote())
 
         # Test that we can close over modules.
         @ray.remote
         def h():
             return np.zeros([3, 5])
+
         assert_equal(ray.get(h.remote()), np.zeros([3, 5]))
 
         @ray.remote
         def j():
             return time.time()
+
         ray.get(j.remote())
 
         # Test that we can define remote functions that call other remote
@@ -595,6 +637,7 @@ class APITest(unittest.TestCase):
         @ray.remote
         def m(x):
             return ray.get(l.remote(x))
+
         self.assertEqual(ray.get(k.remote(1)), 2)
         self.assertEqual(ray.get(l.remote(1)), 2)
         self.assertEqual(ray.get(m.remote(1)), 2)
@@ -618,8 +661,12 @@ class APITest(unittest.TestCase):
             time.sleep(delay)
             return 1
 
-        objectids = [f.remote(1.0), f.remote(0.5), f.remote(0.5),
-                     f.remote(0.5)]
+        objectids = [
+            f.remote(1.0),
+            f.remote(0.5),
+            f.remote(0.5),
+            f.remote(0.5)
+        ]
         ready_ids, remaining_ids = ray.wait(objectids)
         self.assertEqual(len(ready_ids), 1)
         self.assertEqual(len(remaining_ids), 3)
@@ -627,17 +674,25 @@ class APITest(unittest.TestCase):
         self.assertEqual(set(ready_ids), set(objectids))
         self.assertEqual(remaining_ids, [])
 
-        objectids = [f.remote(0.5), f.remote(0.5), f.remote(0.5),
-                     f.remote(0.5)]
+        objectids = [
+            f.remote(0.5),
+            f.remote(0.5),
+            f.remote(0.5),
+            f.remote(0.5)
+        ]
         start_time = time.time()
-        ready_ids, remaining_ids = ray.wait(objectids, timeout=1750,
-                                            num_returns=4)
+        ready_ids, remaining_ids = ray.wait(
+            objectids, timeout=1750, num_returns=4)
         self.assertLess(time.time() - start_time, 2)
         self.assertEqual(len(ready_ids), 3)
         self.assertEqual(len(remaining_ids), 1)
         ray.wait(objectids)
-        objectids = [f.remote(1.0), f.remote(0.5), f.remote(0.5),
-                     f.remote(0.5)]
+        objectids = [
+            f.remote(1.0),
+            f.remote(0.5),
+            f.remote(0.5),
+            f.remote(0.5)
+        ]
         start_time = time.time()
         ready_ids, remaining_ids = ray.wait(objectids, timeout=5000)
         self.assertTrue(time.time() - start_time < 5)
@@ -684,18 +739,22 @@ class APITest(unittest.TestCase):
         # is connected.
         def f(worker_info):
             sys.path.append(1)
+
         ray.worker.global_worker.run_function_on_all_workers(f)
 
         def f(worker_info):
             sys.path.append(2)
+
         ray.worker.global_worker.run_function_on_all_workers(f)
 
         def g(worker_info):
             sys.path.append(3)
+
         ray.worker.global_worker.run_function_on_all_workers(g)
 
         def f(worker_info):
             sys.path.append(4)
+
         ray.worker.global_worker.run_function_on_all_workers(f)
 
         self.init_ray()
@@ -716,6 +775,7 @@ class APITest(unittest.TestCase):
             sys.path.pop()
             sys.path.pop()
             sys.path.pop()
+
         ray.worker.global_worker.run_function_on_all_workers(f)
 
     def testRunningFunctionOnAllWorkers(self):
@@ -723,15 +783,18 @@ class APITest(unittest.TestCase):
 
         def f(worker_info):
             sys.path.append("fake_directory")
+
         ray.worker.global_worker.run_function_on_all_workers(f)
 
         @ray.remote
         def get_path1():
             return sys.path
+
         self.assertEqual("fake_directory", ray.get(get_path1.remote())[-1])
 
         def f(worker_info):
             sys.path.pop(-1)
+
         ray.worker.global_worker.run_function_on_all_workers(f)
 
         # Create a second remote function to guarantee that when we call
@@ -740,6 +803,7 @@ class APITest(unittest.TestCase):
         @ray.remote
         def get_path2():
             return sys.path
+
         self.assertTrue("fake_directory" not in ray.get(get_path2.remote()))
 
     def testLoggingAPI(self):
@@ -751,8 +815,8 @@ class APITest(unittest.TestCase):
             keys = ray.worker.global_worker.redis_client.keys("event_log:*")
             res = []
             for key in keys:
-                res.extend(ray.worker.global_worker.redis_client.zrange(key, 0,
-                                                                        -1))
+                res.extend(
+                    ray.worker.global_worker.redis_client.zrange(key, 0, -1))
             return res
 
         def wait_for_num_events(num_events, timeout=10):
@@ -806,26 +870,31 @@ class APITest(unittest.TestCase):
         @ray.remote
         def f():
             return 1
+
         results1 = [f.remote() for _ in range(num_calls)]
 
         @ray.remote
         def f():
             return 2
+
         results2 = [f.remote() for _ in range(num_calls)]
 
         @ray.remote
        def f():
             return 3
+
         results3 = [f.remote() for _ in range(num_calls)]
 
         @ray.remote
         def f():
             return 4
+
         results4 = [f.remote() for _ in range(num_calls)]
 
         @ray.remote
         def f():
             return 5
+
         results5 = [f.remote() for _ in range(num_calls)]
 
         self.assertEqual(ray.get(results1), num_calls * [1])
@@ -870,7 +939,6 @@ class APITest(unittest.TestCase):
 
 
 class APITestSharded(APITest):
-
     def init_ray(self, kwargs=None):
         if kwargs is None:
             kwargs = {}
@@ -881,7 +949,6 @@ class APITestSharded(APITest):
 
 
 class PythonModeTest(unittest.TestCase):
-
     def testPythonMode(self):
         reload(test_functions)
         ray.init(driver_mode=ray.PYTHON_MODE)
@@ -889,6 +956,7 @@ class PythonModeTest(unittest.TestCase):
         @ray.remote
         def f():
             return np.ones([3, 4, 5])
+
         xref = f.remote()
         # Remote functions should return by value.
         assert_equal(xref, np.ones([3, 4, 5]))
@@ -911,8 +979,8 @@ class PythonModeTest(unittest.TestCase):
         # first list and the remaining values as the second list
         num_returns = 5
         object_ids = [ray.put(i) for i in range(20)]
-        ready, remaining = ray.wait(object_ids, num_returns=num_returns,
-                                    timeout=None)
+        ready, remaining = ray.wait(
+            object_ids, num_returns=num_returns, timeout=None)
         assert_equal(ready, object_ids[:num_returns])
         assert_equal(remaining, object_ids[num_returns:])
@@ -949,7 +1017,6 @@ class PythonModeTest(unittest.TestCase):
 
 
 class UtilsTest(unittest.TestCase):
-
     def testCopyingDirectory(self):
         # The functionality being tested here is really multi-node
         # functionality, but this test just uses a single node.
@@ -999,7 +1066,6 @@ class UtilsTest(unittest.TestCase):
 
 
 class ResourcesTest(unittest.TestCase):
-
     def testResourceConstraints(self):
         num_workers = 20
         ray.init(num_workers=num_workers, num_cpus=10, num_gpus=2)
@@ -1012,9 +1078,13 @@ class ResourcesTest(unittest.TestCase):
         def get_worker_id():
             time.sleep(1)
             return sys.path[-1]
+
         while True:
-            if len(set(ray.get([get_worker_id.remote()
-                                for _ in range(num_workers)]))) == num_workers:
+            if len(
+                    set(
+                        ray.get([
+                            get_worker_id.remote() for _ in range(num_workers)
+                        ]))) == num_workers:
                 break
 
         time_buffer = 0.3
@@ -1088,9 +1158,13 @@ class ResourcesTest(unittest.TestCase):
         def get_worker_id():
             time.sleep(1)
             return sys.path[-1]
+
         while True:
-            if len(set(ray.get([get_worker_id.remote()
-                                for _ in range(num_workers)]))) == num_workers:
+            if len(
+                    set(
+                        ray.get([
+                            get_worker_id.remote() for _ in range(num_workers)
+                        ]))) == num_workers:
                 break
 
         @ray.remote(num_cpus=1, num_gpus=9)
@@ -1192,7 +1266,7 @@ class ResourcesTest(unittest.TestCase):
 
         list_of_ids = ray.get([f1.remote() for _ in range(10)])
         set_of_ids = set([tuple(gpu_ids) for gpu_ids in list_of_ids])
-        self.assertEqual(set_of_ids, set([(i,) for i in range(10)]))
+        self.assertEqual(set_of_ids, set([(i, ) for i in range(10)]))
 
         list_of_ids = ray.get([f2.remote(), f4.remote(), f4.remote()])
         all_ids = [gpu_id for gpu_ids in list_of_ids for gpu_id in gpu_ids]
@@ -1218,11 +1292,12 @@ class ResourcesTest(unittest.TestCase):
         # This test will define a bunch of tasks that can only be assigned to
         # specific local schedulers, and we will check that they are assigned
         # to the correct local schedulers.
-        address_info = ray.worker._init(start_ray_local=True,
-                                        num_local_schedulers=3,
-                                        num_workers=1,
-                                        num_cpus=[100, 5, 10],
-                                        num_gpus=[0, 5, 1])
+        address_info = ray.worker._init(
+            start_ray_local=True,
+            num_local_schedulers=3,
+            num_workers=1,
+            num_cpus=[100, 5, 10],
+            num_gpus=[0, 5, 1])
 
         # Define a bunch of remote functions that all return the socket name of
         # the plasma store. Since there is a one-to-one correspondence between
@@ -1284,8 +1359,10 @@ class ResourcesTest(unittest.TestCase):
                 results.append(run_on_0_2.remote())
             return names, results
 
-        store_names = [object_store_address.name for object_store_address
-                       in address_info["object_store_addresses"]]
+        store_names = [
+            object_store_address.name
+            for object_store_address in address_info["object_store_addresses"]
+        ]
 
         def validate_names_and_results(names, results):
             for name, result in zip(names, ray.get(results)):
@@ -1296,8 +1373,9 @@ class ResourcesTest(unittest.TestCase):
                 elif name == "run_on_2":
                     self.assertIn(result, [store_names[2]])
                 elif name == "run_on_0_1_2":
-                    self.assertIn(result, [store_names[0], store_names[1],
-                                           store_names[2]])
+                    self.assertIn(result, [
+                        store_names[0], store_names[1], store_names[2]
+                    ])
                 elif name == "run_on_1_2":
                     self.assertIn(result, [store_names[1], store_names[2]])
                 elif name == "run_on_0_2":
@@ -1327,8 +1405,11 @@ class ResourcesTest(unittest.TestCase):
         ray.worker.cleanup()
 
     def testCustomResources(self):
-        ray.worker._init(start_ray_local=True, num_local_schedulers=2,
-                         num_cpus=3, num_custom_resource=[0, 1])
+        ray.worker._init(
+            start_ray_local=True,
+            num_local_schedulers=2,
+            num_cpus=3,
+            num_custom_resource=[0, 1])
 
         @ray.remote
         def f():
@@ -1373,13 +1454,12 @@ class ResourcesTest(unittest.TestCase):
         ray.get(ray.remote(num_custom_resource=2)(f).remote())
         ray.get(ray.remote(num_custom_resource=4)(f).remote())
         ray.get(ray.remote(num_custom_resource=8)(f).remote())
-        ray.get(ray.remote(num_custom_resource=(10 ** 10))(f).remote())
+        ray.get(ray.remote(num_custom_resource=(10**10))(f).remote())
 
         ray.worker.cleanup()
 
 
 class WorkerPoolTests(unittest.TestCase):
-
     def tearDown(self):
         ray.worker.cleanup()
@@ -1450,19 +1530,22 @@ class WorkerPoolTests(unittest.TestCase):
 
 
 class SchedulingAlgorithm(unittest.TestCase):
-
-    def attempt_to_load_balance(self, remote_function, args, total_tasks,
-                                num_local_schedulers, minimum_count,
+    def attempt_to_load_balance(self,
+                                remote_function,
+                                args,
+                                total_tasks,
+                                num_local_schedulers,
+                                minimum_count,
                                 num_attempts=20):
         attempts = 0
         while attempts < num_attempts:
-            locations = ray.get([remote_function.remote(*args)
-                                 for _ in range(total_tasks)])
+            locations = ray.get(
+                [remote_function.remote(*args) for _ in range(total_tasks)])
             names = set(locations)
             counts = [locations.count(name) for name in names]
             print("Counts are {}.".format(counts))
-            if (len(names) == num_local_schedulers and
-                    all([count >= minimum_count for count in counts])):
+            if (len(names) == num_local_schedulers
                    and all([count >= minimum_count for count in counts])):
                 break
             attempts += 1
         self.assertLess(attempts, num_attempts)
@@ -1472,9 +1555,10 @@ class SchedulingAlgorithm(unittest.TestCase):
         # schedulers in a roughly equal manner.
         num_local_schedulers = 3
        num_cpus = 7
-        ray.worker._init(start_ray_local=True,
-                         num_local_schedulers=num_local_schedulers,
-                         num_cpus=num_cpus)
+        ray.worker._init(
+            start_ray_local=True,
+            num_local_schedulers=num_local_schedulers,
+            num_cpus=num_cpus)
 
         @ray.remote
         def f():
@@ -1492,8 +1576,10 @@ class SchedulingAlgorithm(unittest.TestCase):
         # dependencies.
         num_workers = 3
         num_local_schedulers = 3
-        ray.worker._init(start_ray_local=True, num_workers=num_workers,
-                         num_local_schedulers=num_local_schedulers)
+        ray.worker._init(
+            start_ray_local=True,
+            num_workers=num_workers,
+            num_local_schedulers=num_local_schedulers)
 
         @ray.remote
         def f(x):
@@ -1528,7 +1614,6 @@ def wait_for_num_objects(num_objects, timeout=10):
 
 
 class GlobalStateAPI(unittest.TestCase):
-
     def testGlobalStateAPI(self):
         with self.assertRaises(Exception):
             ray.global_state.object_table()
@@ -1572,15 +1657,16 @@ class GlobalStateAPI(unittest.TestCase):
                          driver_id)
         self.assertEqual(task_table[driver_task_id]["TaskSpec"]["FunctionID"],
                          ID_SIZE * "ff")
-        self.assertEqual((task_table[driver_task_id]["TaskSpec"]
-                          ["ReturnObjectIDs"]),
-                         [])
+        self.assertEqual(
+            (task_table[driver_task_id]["TaskSpec"]["ReturnObjectIDs"]), [])
 
         client_table = ray.global_state.client_table()
         node_ip_address = ray.worker.global_worker.node_ip_address
         self.assertEqual(len(client_table[node_ip_address]), 3)
-        manager_client = [c for c in client_table[node_ip_address]
-                          if c["ClientType"] == "plasma_manager"][0]
+        manager_client = [
+            c for c in client_table[node_ip_address]
+            if c["ClientType"] == "plasma_manager"
+        ][0]
 
         @ray.remote
         def f(*xs):
@@ -1624,8 +1710,8 @@ class GlobalStateAPI(unittest.TestCase):
         while time.time() - start_time < timeout:
             object_table = ray.global_state.object_table()
             tables_ready = (
-                object_table[x_id]["ManagerIDs"] is not None and
-                object_table[result_id]["ManagerIDs"] is not None)
+                object_table[x_id]["ManagerIDs"] is not None
+                and object_table[result_id]["ManagerIDs"] is not None)
             if tables_ready:
                 return
             time.sleep(0.1)
@@ -1701,8 +1787,8 @@ class GlobalStateAPI(unittest.TestCase):
         while time.time() - start_time < 10:
             profiles = ray.global_state.task_profiles(
                 100, start=0, end=time.time())
-            limited_profiles = ray.global_state.task_profiles(1, start=0,
-                                                              end=time.time())
+            limited_profiles = ray.global_state.task_profiles(
                1, start=0, end=time.time())
             if len(profiles) == num_calls and len(limited_profiles) == 1:
                 break
             time.sleep(0.1)
@@ -1722,8 +1808,10 @@ class GlobalStateAPI(unittest.TestCase):
 
     def testWorkers(self):
         num_workers = 3
-        ray.init(redirect_output=True, num_cpus=num_workers,
-                 num_workers=num_workers)
+        ray.init(
+            redirect_output=True,
+            num_cpus=num_workers,
+            num_workers=num_workers)
 
         @ray.remote
         def f():