Make Monitor remove dead Redis entries from exiting drivers. (#994)

* WIP: removing OL, OI, TT on client exit; no saving yet.

* ray_redis_module.cc: update header comment.

* Cleanup: just the removal.

* Reformat via yapf: use pep8 style instead of google.

* Checkpoint addressing comments (partially)

* Add 'b' marker before strings (py3 compat)

* Add MonitorTest.

* Use `isort` to sort imports.

* Remove some logging statements

* Fix flake8 noqa marker in runtest.py

* Try to separate tests out to monitor_test.py

* Rework cleanup algorithm: correct logic

* Extend tests to cover multi-shard cases

* Add some small comments and formatting changes.
Authored by Zongheng Yang on 2017-09-26 00:11:38 -07:00; committed by Robert Nishihara
parent 6e9657e696
commit 5a50e80b63
7 changed files with 698 additions and 174 deletions
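For orientation, the cleanup this commit adds runs in two phases: the monitor first scans every Redis shard to collect the exiting driver's task IDs and object IDs (returned objects plus ray.put() objects), then hashes each ID back to the shard that owns it and deletes the corresponding TT:, OI: and OL: keys there. The snippet below is a simplified, hedged sketch of the second phase only, written with plain redis-py; shard_index is a stand-in for the real redis_shard_hash()-based routing in monitor.py, and the shard ports in the example wiring are hypothetical.

import redis

TASK_TABLE_PREFIX = b"TT:"
OBJECT_INFO_PREFIX = b"OI:"
OBJECT_LOCATION_PREFIX = b"OL:"


def shard_index(binary_id, num_shards):
    # Stand-in for the real routing: monitor.py uses
    # binary_to_object_id(id).redis_shard_hash() % num_shards.
    return hash(binary_id) % num_shards


def delete_driver_entries(shards, task_ids, object_ids):
    """Deletion phase only: route each collected ID to the shard that owns
    it, then remove its TT:/OI:/OL: keys with one DEL call per shard."""
    num_shards = len(shards)
    for index, shard in enumerate(shards):
        keys = [TASK_TABLE_PREFIX + t for t in task_ids
                if shard_index(t, num_shards) == index]
        for obj in object_ids:
            if shard_index(obj, num_shards) == index:
                keys.append(OBJECT_INFO_PREFIX + obj)
                keys.append(OBJECT_LOCATION_PREFIX + obj)
        if keys:
            deleted = shard.delete(*keys)
            print("shard {}: removed {} of {} keys".format(
                index, deleted, len(keys)))


# Example wiring (hypothetical shard ports):
# shards = [redis.StrictRedis(host="127.0.0.1", port=p) for p in (6380, 6381)]
# delete_driver_entries(shards, [b"some-task-id"], [b"some-object-id"])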

.style.yapf (new file, 190 lines)

@@ -0,0 +1,190 @@
[style]
# Align closing bracket with visual indentation.
align_closing_bracket_with_visual_indent=True
# Allow dictionary keys to exist on multiple lines. For example:
#
#   x = {
#       ('this is the first element of a tuple',
#        'this is the second element of a tuple'):
#           value,
#   }
allow_multiline_dictionary_keys=False
# Allow lambdas to be formatted on more than one line.
allow_multiline_lambdas=False
# Insert a blank line before a class-level docstring.
blank_line_before_class_docstring=False
# Insert a blank line before a 'def' or 'class' immediately nested
# within another 'def' or 'class'. For example:
#
#   class Foo:
#                      # <------ this blank line
#     def method():
#       ...
blank_line_before_nested_class_or_def=False
# Do not split consecutive brackets. Only relevant when
# dedent_closing_brackets is set. For example:
#
#   call_func_that_takes_a_dict(
#       {
#           'key1': 'value1',
#           'key2': 'value2',
#       }
#   )
#
# would reformat to:
#
#   call_func_that_takes_a_dict({
#       'key1': 'value1',
#       'key2': 'value2',
#   })
coalesce_brackets=False
# The column limit.
column_limit=79
# Indent width used for line continuations.
continuation_indent_width=4
# Put closing brackets on a separate line, dedented, if the bracketed
# expression can't fit in a single line. Applies to all kinds of brackets,
# including function definitions and calls. For example:
#
#   config = {
#       'key1': 'value1',
#       'key2': 'value2',
#   }        # <--- this bracket is dedented and on a separate line
#
#   time_series = self.remote_client.query_entity_counters(
#       entity='dev3246.region1',
#       key='dns.query_latency_tcp',
#       transform=Transformation.AVERAGE(window=timedelta(seconds=60)),
#       start_ts=now()-timedelta(days=3),
#       end_ts=now(),
#   )        # <--- this bracket is dedented and on a separate line
dedent_closing_brackets=False
# Place each dictionary entry onto its own line.
each_dict_entry_on_separate_line=True
# The regex for an i18n comment. The presence of this comment stops
# reformatting of that line, because the comments are required to be
# next to the string they translate.
i18n_comment=
# The i18n function call names. The presence of this function stops
# reformatting on that line, because the string it has cannot be moved
# away from the i18n comment.
i18n_function_call=
# Indent the dictionary value if it cannot fit on the same line as the
# dictionary key. For example:
#
#   config = {
#       'key1':
#           'value1',
#       'key2': value1 +
#               value2,
#   }
indent_dictionary_value=False
# The number of columns to use for indentation.
indent_width=4
# Join short lines into one line. E.g., single line 'if' statements.
join_multiple_lines=True
# Do not include spaces around selected binary operators. For example:
#
# 1 + 2 * 3 - 4 / 5
#
# will be formatted as follows when configured with a value "*,/":
#
# 1 + 2*3 - 4/5
#
no_spaces_around_selected_binary_operators=set([])
# Use spaces around default or named assigns.
spaces_around_default_or_named_assign=False
# Use spaces around the power operator.
spaces_around_power_operator=False
# The number of spaces required before a trailing comment.
spaces_before_comment=2
# Insert a space between the ending comma and closing bracket of a list,
# etc.
space_between_ending_comma_and_closing_bracket=True
# Split before arguments if the argument list is terminated by a
# comma.
split_arguments_when_comma_terminated=False
# Set to True to prefer splitting before '&', '|' or '^' rather than
# after.
split_before_bitwise_operator=True
# Split before a dictionary or set generator (comp_for). For example, note
# the split before the 'for':
#
#   foo = {
#       variable: 'Hello world, have a nice day!'
#       for variable in bar if variable != 42
#   }
split_before_dict_set_generator=True
# If an argument / parameter list is going to be split, then split before
# the first argument.
split_before_first_argument=False
# Set to True to prefer splitting before 'and' or 'or' rather than
# after.
split_before_logical_operator=True
# Split named assignments onto individual lines.
split_before_named_assigns=True
# The penalty for splitting right after the opening bracket.
split_penalty_after_opening_bracket=30
# The penalty for splitting the line after a unary operator.
split_penalty_after_unary_operator=10000
# The penalty for splitting right before an if expression.
split_penalty_before_if_expr=0
# The penalty of splitting the line around the '&', '|', and '^'
# operators.
split_penalty_bitwise_operator=300
# The penalty for characters over the column limit.
split_penalty_excess_character=4500
# The penalty incurred by adding a line split to the unwrapped line. The
# more line splits added the higher the penalty.
split_penalty_for_added_line_split=30
# The penalty of splitting a list of "import as" names. For example:
#
#   from a_very_long_or_indented_module_name_yada_yad import (long_argument_1,
#                                                              long_argument_2,
#                                                              long_argument_3)
#
# would reformat to something like:
#
#   from a_very_long_or_indented_module_name_yada_yad import (
#       long_argument_1, long_argument_2, long_argument_3)
split_penalty_import_names=0
# The penalty of splitting the line around the 'and' and 'or'
# operators.
split_penalty_logical_operator=300
# Use the Tab character for indentation.
use_tabs=False
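The settings above are picked up by yapf when it is pointed at this file. Below is a minimal sketch of exercising them from Python, assuming yapf is installed and .style.yapf sits in the working directory; note that newer yapf releases return a (source, changed) tuple from FormatCode rather than a plain string.

# Minimal sketch: reformat a snippet using the project's .style.yapf.
from yapf.yapflib.yapf_api import FormatCode

snippet = "x = {  'a':37,'b':42,\n'c':927}\n"
result = FormatCode(snippet, style_config=".style.yapf")
print(result)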


@ -117,5 +117,6 @@ script:
- python test/component_failures_test.py
- python test/multi_node_test.py
- python test/recursion_test.py
- python test/monitor_test.py
- python -m pytest python/ray/rllib/test/test_catalog.py


@ -52,12 +52,18 @@ TASK_STATUS_MAPPING = {
class GlobalState(object):
"""A class used to interface with the Ray control state.
# TODO(zongheng): In the future move this to use Ray's redis module in the
# backend to cut down on # of request RPCs.
Attributes:
redis_client: The redis client used to query the redis server.
redis_client: The redis client used to query the redis server.
"""
def __init__(self):
"""Create a GlobalState object."""
# The redis server storing metadata, such as function table, client
# table, log files, event logs, workers/actions info.
self.redis_client = None
# A list of redis shards, storing the object table & task table.
self.redis_clients = None
def _check_connected(self):
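As a rough illustration of the split between the two kinds of clients, here is a hedged sketch, assuming a Ray head node is already running with its primary Redis server reachable at 127.0.0.1:6379:

import ray.experimental.state as state

s = state.GlobalState()
s._initialize_global_state("127.0.0.1", 6379)

# The primary client holds metadata: client table, function table, logs.
print("primary redis keys:", s.redis_client.dbsize())
# Each shard client holds a slice of the object and task tables.
for i, shard in enumerate(s.redis_clients):
    print("shard {} keys: {}".format(i, shard.dbsize()))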


@ -3,22 +3,22 @@ from __future__ import division
from __future__ import print_function
import argparse
from collections import Counter
import json
import logging
import redis
import time
from collections import Counter, defaultdict
import ray
import ray.utils
from ray.services import get_ip_address, get_port
from ray.utils import binary_to_object_id, binary_to_hex, hex_to_binary
from ray.worker import NIL_ACTOR_ID
import redis
# Import flatbuffer bindings.
from ray.core.generated.SubscribeToDBClientTableReply \
import SubscribeToDBClientTableReply
from ray.core.generated.DriverTableMessage import DriverTableMessage
from ray.core.generated.SubscribeToDBClientTableReply import \
SubscribeToDBClientTableReply
from ray.core.generated.TaskInfo import TaskInfo
from ray.services import get_ip_address, get_port
from ray.utils import binary_to_hex, binary_to_object_id, hex_to_binary
from ray.worker import NIL_ACTOR_ID
# These variables must be kept in sync with the C codebase.
# common/common.h
@ -26,17 +26,24 @@ HEARTBEAT_TIMEOUT_MILLISECONDS = 100
NUM_HEARTBEATS_TIMEOUT = 100
DB_CLIENT_ID_SIZE = 20
NIL_ID = b"\xff" * DB_CLIENT_ID_SIZE
# common/task.h
TASK_STATUS_LOST = 32
# common/state/redis.cc
PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers"
DRIVER_DEATH_CHANNEL = b"driver_deaths"
# common/redis_module/ray_redis_module.cc
OBJECT_PREFIX = "OL:"
DB_CLIENT_PREFIX = "CL:"
OBJECT_INFO_PREFIX = b"OI:"
OBJECT_LOCATION_PREFIX = b"OL:"
TASK_TABLE_PREFIX = b"TT:"
DB_CLIENT_PREFIX = b"CL:"
DB_CLIENT_TABLE_NAME = b"db_clients"
# local_scheduler/local_scheduler.h
LOCAL_SCHEDULER_CLIENT_TYPE = b"local_scheduler"
# plasma/plasma_manager.cc
PLASMA_MANAGER_CLIENT_TYPE = b"plasma_manager"
@ -69,12 +76,13 @@ class Monitor(object):
dead_plasma_managers: A set of the plasma manager IDs of all the plasma
managers that were up at one point and have died since then.
"""
def __init__(self, redis_address, redis_port):
# Initialize the Redis clients.
self.state = ray.experimental.state.GlobalState()
self.state._initialize_global_state(redis_address, redis_port)
self.redis = redis.StrictRedis(host=redis_address, port=redis_port,
db=0)
self.redis = redis.StrictRedis(
host=redis_address, port=redis_port, db=0)
# TODO(swang): Update pubsub client to use ray.experimental.state once
# subscriptions are implemented there.
self.subscribe_client = self.redis.pubsub()
@ -109,8 +117,9 @@ class Monitor(object):
info["local_scheduler_id"] in self.dead_local_schedulers):
# Choose a new local scheduler to run the actor.
local_scheduler_id = ray.utils.select_local_scheduler(
info["driver_id"], self.state.local_schedulers(),
info["num_gpus"], self.redis)
info["driver_id"],
self.state.local_schedulers(), info["num_gpus"],
self.redis)
import sys
sys.stdout.flush()
# The new local scheduler should not be the same as the old
@ -121,8 +130,9 @@ class Monitor(object):
# Announce to all of the local schedulers that the actor should
# be recreated on this new local scheduler.
ray.utils.publish_actor_creation(
hex_to_binary(actor_id), hex_to_binary(info["driver_id"]),
local_scheduler_id, True, self.redis)
hex_to_binary(actor_id),
hex_to_binary(info["driver_id"]), local_scheduler_id, True,
self.redis)
log.info("Actor {} for driver {} was on dead local scheduler "
"{}. It is being recreated on local scheduler {}"
.format(actor_id, info["driver_id"],
@ -160,7 +170,7 @@ class Monitor(object):
# The dummy object should exist on at most one plasma
# manager, the manager associated with the local scheduler
# that died.
assert(len(manager_ids) <= 1)
assert len(manager_ids) <= 1
# Remove the dummy object from the plasma manager
# associated with the dead local scheduler, if any.
for manager in manager_ids:
@ -175,7 +185,8 @@ class Monitor(object):
# task as lost.
key = binary_to_object_id(hex_to_binary(task_id))
ok = self.state._execute_command(
key, "RAY.TASK_TABLE_UPDATE", hex_to_binary(task_id),
key, "RAY.TASK_TABLE_UPDATE",
hex_to_binary(task_id),
ray.experimental.state.TASK_STATUS_LOST, NIL_ID)
if ok != b"OK":
log.warn("Failed to update lost task for dead scheduler.")
@ -238,7 +249,7 @@ class Monitor(object):
log.debug("Subscribed to {}, data was {}".format(channel, data))
self.subscribed[channel] = True
def db_client_notification_handler(self, channel, data):
def db_client_notification_handler(self, unused_channel, data):
"""Handle a notification from the db_client table from Redis.
This handler processes notifications from the db_client table.
@ -247,9 +258,8 @@ class Monitor(object):
the associated state in the state tables should be handled by the
caller.
"""
notification_object = (SubscribeToDBClientTableReply
.GetRootAsSubscribeToDBClientTableReply(data,
0))
notification_object = (SubscribeToDBClientTableReply.
GetRootAsSubscribeToDBClientTableReply(data, 0))
db_client_id = binary_to_hex(notification_object.DbClientId())
client_type = notification_object.ClientType()
is_insertion = notification_object.IsInsertion()
@ -271,7 +281,7 @@ class Monitor(object):
# already dead.
del self.live_plasma_managers[db_client_id]
def plasma_manager_heartbeat_handler(self, channel, data):
def plasma_manager_heartbeat_handler(self, unused_channel, data):
"""Handle a plasma manager heartbeat from Redis.
This resets the number of heartbeats that we've missed from this plasma
@ -283,7 +293,134 @@ class Monitor(object):
# manager.
self.live_plasma_managers[db_client_id] = 0
def driver_removed_handler(self, channel, data):
def _entries_for_driver_in_shard(self, driver_id, redis_shard_index):
"""Collect IDs of control-state entries for a driver from a shard.
Args:
driver_id: The ID of the driver.
redis_shard_index: The index of the Redis shard to query.
Returns:
Lists of IDs: (returned_object_ids, task_ids, put_objects). The
first two are relevant to the driver and are safe to delete.
The last contains all "put" objects in this redis shard; each
element is an (object_id, corresponding task_id) pair.
"""
# TODO(zongheng): consider adding save & restore functionalities.
redis = self.state.redis_clients[redis_shard_index]
task_table_infos = {} # task id -> TaskInfo messages
# Scan the task table & filter to get the list of tasks belonging to this
# driver. Use a cursor in order not to block the redis shards.
for key in redis.scan_iter(match=TASK_TABLE_PREFIX + b"*"):
entry = redis.hgetall(key)
task_info = TaskInfo.GetRootAsTaskInfo(entry[b"TaskSpec"], 0)
if driver_id != task_info.DriverId():
# Ignore tasks that aren't from this driver.
continue
task_table_infos[task_info.TaskId()] = task_info
# Get the list of objects returned by these tasks. Note these might
# not belong to this redis shard.
returned_object_ids = []
for task_info in task_table_infos.values():
returned_object_ids.extend([
task_info.Returns(i) for i in range(task_info.ReturnsLength())
])
# Also record all the ray.put()'d objects.
put_objects = []
for key in redis.scan_iter(match=OBJECT_INFO_PREFIX + b"*"):
entry = redis.hgetall(key)
if entry[b"is_put"] == "0":
continue
object_id = key.split(OBJECT_INFO_PREFIX)[1]
task_id = entry[b"task"]
put_objects.append((object_id, task_id))
return returned_object_ids, task_table_infos.keys(), put_objects
def _clean_up_entries_from_shard(self, object_ids, task_ids, shard_index):
redis = self.state.redis_clients[shard_index]
# Clean up (in the future, save) entries for non-empty objects.
object_ids_locs = set()
object_ids_infos = set()
for object_id in object_ids:
# OL.
obj_loc = redis.zrange(OBJECT_LOCATION_PREFIX + object_id, 0, -1)
if obj_loc:
object_ids_locs.add(object_id)
# OI.
obj_info = redis.hgetall(OBJECT_INFO_PREFIX + object_id)
if obj_info:
object_ids_infos.add(object_id)
# Form the redis keys to delete.
keys = [TASK_TABLE_PREFIX + k for k in task_ids]
keys.extend([OBJECT_LOCATION_PREFIX + k for k in object_ids_locs])
keys.extend([OBJECT_INFO_PREFIX + k for k in object_ids_infos])
if not keys:
return
# Remove with best effort.
num_deleted = redis.delete(*keys)
log.info(
"Removed {} dead redis entries of the driver from redis shard {}.".
format(num_deleted, shard_index))
if num_deleted != len(keys):
log.warning(
"Failed to remove {} relevant redis entries"
" from redis shard {}.".format(len(keys) - num_deleted))
def _clean_up_entries_for_driver(self, driver_id):
"""Remove this driver's object/task entries from all redis shards.
Specifically, removes control-state entries of:
* all objects (OI and OL entries) created by `ray.put()` from the
driver
* all tasks belonging to the driver.
"""
# TODO(zongheng): handle function_table, client_table, log_files --
# these are in the metadata redis server, not in the shards.
driver_object_ids = []
driver_task_ids = []
all_put_objects = []
# Collect relevant ids.
# TODO(zongheng): consider parallelizing this loop.
for shard_index in range(len(self.state.redis_clients)):
returned_object_ids, task_ids, put_objects = \
self._entries_for_driver_in_shard(driver_id, shard_index)
driver_object_ids.extend(returned_object_ids)
driver_task_ids.extend(task_ids)
all_put_objects.extend(put_objects)
# For the put objects, keep those from relevant tasks.
driver_task_ids_set = set(driver_task_ids)
for object_id, task_id in all_put_objects:
if task_id in driver_task_ids_set:
driver_object_ids.append(object_id)
# Partition IDs and distribute to shards.
object_ids_per_shard = defaultdict(list)
task_ids_per_shard = defaultdict(list)
def ToShardIndex(index):
return binary_to_object_id(index).redis_shard_hash() % len(
self.state.redis_clients)
for object_id in driver_object_ids:
object_ids_per_shard[ToShardIndex(object_id)].append(object_id)
for task_id in driver_task_ids:
task_ids_per_shard[ToShardIndex(task_id)].append(task_id)
# TODO(zongheng): consider parallelizing this loop.
for shard_index in range(len(self.state.redis_clients)):
self._clean_up_entries_from_shard(
object_ids_per_shard[shard_index],
task_ids_per_shard[shard_index], shard_index)
def driver_removed_handler(self, unused_channel, data):
"""Handle a notification that a driver has been removed.
This releases any GPU resources that were reserved for that driver in
@ -291,8 +428,8 @@ class Monitor(object):
"""
message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0)
driver_id = message.DriverId()
log.info("Driver {} has been removed."
.format(binary_to_hex(driver_id)))
log.info(
"Driver {} has been removed.".format(binary_to_hex(driver_id)))
# Get a list of the local schedulers.
client_table = ray.global_state.client_table()
@ -302,6 +439,8 @@ class Monitor(object):
if client["ClientType"] == "local_scheduler":
local_schedulers.append(client)
self._clean_up_entries_for_driver(driver_id)
# Release any GPU resources that have been reserved for this driver in
# Redis.
for local_scheduler in local_schedulers:
@ -321,8 +460,8 @@ class Monitor(object):
result = pipe.hget(local_scheduler_id,
"gpus_in_use")
gpus_in_use = (dict() if result is None
else json.loads(result))
gpus_in_use = (dict() if result is None else
json.loads(result))
driver_id_hex = binary_to_hex(driver_id)
if driver_id_hex in gpus_in_use:
@ -345,9 +484,9 @@ class Monitor(object):
continue
log.info("Driver {} is returning GPU IDs {} to local "
"scheduler {}.".format(binary_to_hex(driver_id),
num_gpus_returned,
local_scheduler_id))
"scheduler {}.".format(
binary_to_hex(driver_id), num_gpus_returned,
local_scheduler_id))
def process_messages(self):
"""Process all messages ready in the subscription channels.
@ -371,22 +510,23 @@ class Monitor(object):
# to an initial subscription request.
message_handler = self.subscribe_handler
elif channel == PLASMA_MANAGER_HEARTBEAT_CHANNEL:
assert(self.subscribed[channel])
assert self.subscribed[channel]
# The message was a heartbeat from a plasma manager.
message_handler = self.plasma_manager_heartbeat_handler
elif channel == DB_CLIENT_TABLE_NAME:
assert(self.subscribed[channel])
assert self.subscribed[channel]
# The message was a notification from the db_client table.
message_handler = self.db_client_notification_handler
elif channel == DRIVER_DEATH_CHANNEL:
assert(self.subscribed[channel])
assert self.subscribed[channel]
# The message was a notification that a driver was removed.
log.info("message-handler: driver_removed_handler")
message_handler = self.driver_removed_handler
else:
raise Exception("This code should be unreachable.")
# Call the handler.
assert(message_handler is not None)
assert (message_handler is not None)
message_handler(channel, data)
def run(self):
@ -439,8 +579,8 @@ class Monitor(object):
# Handle plasma managers that timed out during this round.
plasma_manager_ids = list(self.live_plasma_managers.keys())
for plasma_manager_id in plasma_manager_ids:
if ((self.live_plasma_managers
[plasma_manager_id]) >= NUM_HEARTBEATS_TIMEOUT):
if ((self.live_plasma_managers[plasma_manager_id]) >=
NUM_HEARTBEATS_TIMEOUT):
log.warn("Timed out {}".format(PLASMA_MANAGER_CLIENT_TYPE))
# Remove the plasma manager from the managers whose
# heartbeats we're tracking.
@ -465,8 +605,11 @@ class Monitor(object):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=("Parse Redis server for the "
"monitor to connect to."))
parser.add_argument("--redis-address", required=True, type=str,
help="the address to use for Redis")
parser.add_argument(
"--redis-address",
required=True,
type=str,
help="the address to use for Redis")
args = parser.parse_args()
redis_ip_address = get_ip_address(args.redis_address)


@ -8,22 +8,25 @@
#include "common_protocol.h"
/**
* Various tables are maintained in redis:
*
* == OBJECT TABLE ==
*
* This consists of two parts:
* - The object location table, indexed by OL:object_id, which is the set of
* plasma manager indices that have access to the object.
* - The object info table, indexed by OI:object_id, which is a hashmap with key
* "hash" for the hash of the object and key "data_size" for the size of the
* object in bytes.
*
* == TASK TABLE ==
*
* TODO(pcm): Fill this out.
*/
// Various tables are maintained in redis:
//
// == OBJECT TABLE ==
//
// This consists of two parts:
// - The object location table, indexed by OL:object_id, which is the set of
// plasma manager indices that have access to the object.
// (In redis this is represented by a zset (sorted set).)
//
// - The object info table, indexed by OI:object_id, which is a hashmap of:
// "hash" -> the hash of the object,
// "data_size" -> the size of the object in bytes,
// "task" -> the task ID that generated this object.
// "is_put" -> 0 or 1.
//
// == TASK TABLE ==
//
// TODO(pcm): Fill this out.
//
#define OBJECT_INFO_PREFIX "OI:"
#define OBJECT_LOCATION_PREFIX "OL:"
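The monitor-side cleanup leans directly on this layout: before deleting an object's entries it checks whether the OL: sorted set and the OI: hash are non-empty on the owning shard. A small, hedged sketch of inspecting one object's entries with redis-py follows; the shard port and the object ID are placeholders.

import redis

OBJECT_INFO_PREFIX = b"OI:"
OBJECT_LOCATION_PREFIX = b"OL:"

shard = redis.StrictRedis(host="127.0.0.1", port=6380)  # placeholder shard
object_id = b"\x00" * 20                                # placeholder ID

# OL:<object_id> is a zset of plasma manager indices holding the object.
locations = shard.zrange(OBJECT_LOCATION_PREFIX + object_id, 0, -1)
# OI:<object_id> is a hash with "hash", "data_size", "task" and "is_put".
info = shard.hgetall(OBJECT_INFO_PREFIX + object_id)

print("managers with a copy:", locations)
print("created by ray.put():", info.get(b"is_put") == b"1")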

test/monitor_test.py (new file, 93 lines)

@@ -0,0 +1,93 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import multiprocessing
import subprocess
import time
import unittest
import ray
class MonitorTest(unittest.TestCase):
def _testCleanupOnDriverExit(self, num_redis_shards):
stdout = subprocess.check_output([
"ray",
"start",
"--head",
"--num-redis-shards",
str(num_redis_shards),
]).decode("ascii")
lines = [m.strip() for m in stdout.split("\n")]
init_cmd = [m for m in lines if m.startswith("ray.init")]
self.assertEqual(1, len(init_cmd))
redis_address = init_cmd[0].split("redis_address=\"")[-1][:-2]
def StateSummary():
obj_tbl_len = len(ray.global_state.object_table())
task_tbl_len = len(ray.global_state.task_table())
func_tbl_len = len(ray.global_state.function_table())
return obj_tbl_len, task_tbl_len, func_tbl_len
def Driver(success):
success.value = True
# Start driver.
ray.init(redis_address=redis_address)
summary_start = StateSummary()
if (0, 1) != summary_start[:2]:
success.value = False
# Two new objects.
ray.get(ray.put(1111))
ray.get(ray.put(1111))
if (2, 1, summary_start[2]) != StateSummary():
success.value = False
@ray.remote
def f():
ray.put(1111) # Yet another object.
return 1111 # A returned object as well.
# 1 new function.
if (2, 1, summary_start[2] + 1) != StateSummary():
success.value = False
ray.get(f.remote())
if (4, 2, summary_start[2] + 1) != StateSummary():
success.value = False
ray.worker.cleanup()
success = multiprocessing.Value('b', False)
driver = multiprocessing.Process(target=Driver, args=(success, ))
driver.start()
# Wait for client to exit.
driver.join()
time.sleep(5)
# Just make sure Driver() ran and succeeded. Note(rkn): if the below
# assertion starts failing, then the issue may be that the summary
# values computed in the Driver function are being updated slowly and
# so the call to StateSummary() is getting outdated values. This could
# be fixed by looping until StateSummary() returns the desired values.
self.assertTrue(success.value)
# Check that objects, tasks, and functions are cleaned up.
ray.init(redis_address=redis_address)
# The assertion below can fail if the monitor is too slow to clean up
# the global state.
self.assertEqual((0, 1), StateSummary()[:2])
ray.worker.cleanup()
subprocess.Popen(["ray", "stop"]).wait()
def testCleanupOnDriverExitSingleRedisShard(self):
self._testCleanupOnDriverExit(num_redis_shards=1)
def testCleanupOnDriverExitManyRedisShards(self):
self._testCleanupOnDriverExit(num_redis_shards=5)
self._testCleanupOnDriverExit(num_redis_shards=31)
if __name__ == "__main__":
unittest.main(verbosity=2)


@ -1,18 +1,17 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import, division, print_function
from collections import defaultdict, namedtuple
import numpy as np
import os
import ray
import re
import shutil
import string
import sys
import time
import unittest
from collections import defaultdict, namedtuple
import numpy as np
import ray
import ray.test.test_functions as test_functions
import ray.test.test_utils
@ -21,11 +20,11 @@ if sys.version_info >= (3, 0):
def assert_equal(obj1, obj2):
module_numpy = (type(obj1).__module__ == np.__name__ or
type(obj2).__module__ == np.__name__)
module_numpy = (type(obj1).__module__ == np.__name__
or type(obj2).__module__ == np.__name__)
if module_numpy:
empty_shape = ((hasattr(obj1, "shape") and obj1.shape == ()) or
(hasattr(obj2, "shape") and obj2.shape == ()))
empty_shape = ((hasattr(obj1, "shape") and obj1.shape == ())
or (hasattr(obj2, "shape") and obj2.shape == ()))
if empty_shape:
# This is a special case because currently np.testing.assert_equal
# fails because we do not properly handle different numerical
@ -36,13 +35,11 @@ def assert_equal(obj1, obj2):
np.testing.assert_equal(obj1, obj2)
elif hasattr(obj1, "__dict__") and hasattr(obj2, "__dict__"):
special_keys = ["_pytype_"]
assert (set(list(obj1.__dict__.keys()) + special_keys) ==
set(list(obj2.__dict__.keys()) + special_keys)), ("Objects {} "
"and {} are "
"different."
.format(
obj1,
obj2))
assert (set(list(obj1.__dict__.keys()) + special_keys) == set(
list(obj2.__dict__.keys()) + special_keys)), ("Objects {} "
"and {} are "
"different.".format(
obj1, obj2))
for key in obj1.__dict__.keys():
if key not in special_keys:
assert_equal(obj1.__dict__[key], obj2.__dict__[key])
@ -52,49 +49,76 @@ def assert_equal(obj1, obj2):
assert_equal(obj1[key], obj2[key])
elif type(obj1) is list or type(obj2) is list:
assert len(obj1) == len(obj2), ("Objects {} and {} are lists with "
"different lengths."
.format(obj1, obj2))
"different lengths.".format(
obj1, obj2))
for i in range(len(obj1)):
assert_equal(obj1[i], obj2[i])
elif type(obj1) is tuple or type(obj2) is tuple:
assert len(obj1) == len(obj2), ("Objects {} and {} are tuples with "
"different lengths."
.format(obj1, obj2))
"different lengths.".format(
obj1, obj2))
for i in range(len(obj1)):
assert_equal(obj1[i], obj2[i])
elif (ray.serialization.is_named_tuple(type(obj1)) or
ray.serialization.is_named_tuple(type(obj2))):
elif (ray.serialization.is_named_tuple(type(obj1))
or ray.serialization.is_named_tuple(type(obj2))):
assert len(obj1) == len(obj2), ("Objects {} and {} are named tuples "
"with different lengths."
.format(obj1, obj2))
"with different lengths.".format(
obj1, obj2))
for i in range(len(obj1)):
assert_equal(obj1[i], obj2[i])
else:
assert obj1 == obj2, "Objects {} and {} are different.".format(obj1,
obj2)
assert obj1 == obj2, "Objects {} and {} are different.".format(
obj1, obj2)
if sys.version_info >= (3, 0):
long_extras = [0, np.array([["hi", u"hi"], [1.3, 1]])]
else:
long_extras = [long(0), np.array([["hi", u"hi"], [1.3, long(1)]])] # noqa: E501,F821
PRIMITIVE_OBJECTS = [0, 0.0, 0.9, 1 << 62, 1 << 100, 1 << 999,
[1 << 100, [1 << 100]], "a", string.printable, "\u262F",
u"hello world", u"\xff\xfe\x9c\x001\x000\x00", None, True,
False, [], (), {}, np.int8(3), np.int32(4), np.int64(5),
np.uint8(3), np.uint32(4), np.uint64(5), np.float32(1.9),
np.float64(1.9), np.zeros([100, 100]),
np.random.normal(size=[100, 100]), np.array(["hi", 3]),
np.array(["hi", 3], dtype=object)] + long_extras
long_extras = [
long(0), # noqa: E501,F821
np.array([
["hi", u"hi"],
[1.3, long(1)] # noqa: E501,F821
])
]
PRIMITIVE_OBJECTS = [
0, 0.0, 0.9, 1 << 62, 1 << 100, 1 << 999, [1 << 100, [1 << 100]], "a",
string.printable, "\u262F", u"hello world", u"\xff\xfe\x9c\x001\x000\x00",
None, True, False, [], (), {},
np.int8(3),
np.int32(4),
np.int64(5),
np.uint8(3),
np.uint32(4),
np.uint64(5),
np.float32(1.9),
np.float64(1.9),
np.zeros([100, 100]),
np.random.normal(size=[100, 100]),
np.array(["hi", 3]),
np.array(["hi", 3], dtype=object)
] + long_extras
COMPLEX_OBJECTS = [
[[[[[[[[[[[[]]]]]]]]]]]],
{"obj{}".format(i): np.random.normal(size=[100, 100]) for i in range(10)},
{"obj{}".format(i): np.random.normal(size=[100, 100])
for i in range(10)},
# {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {
# (): {(): {}}}}}}}}}}}}},
((((((((((),),),),),),),),),),
{"a": {"b": {"c": {"d": {}}}}}]
(
(((((((((), ), ), ), ), ), ), ), ), ),
{
"a": {
"b": {
"c": {
"d": {}
}
}
}
}
]
class Foo(object):
@ -141,21 +165,32 @@ Point = namedtuple("Point", ["x", "y"])
NamedTupleExample = namedtuple("Example",
"field1, field2, field3, field4, field5")
CUSTOM_OBJECTS = [Exception("Test object."), CustomError(), Point(11, y=22),
Foo(), Bar(), Baz(), # Qux(), SubQux(),
NamedTupleExample(1, 1.0, "hi", np.zeros([3, 5]), [1, 2, 3])]
CUSTOM_OBJECTS = [
Exception("Test object."),
CustomError(),
Point(11, y=22),
Foo(),
Bar(),
Baz(), # Qux(), SubQux(),
NamedTupleExample(1, 1.0, "hi", np.zeros([3, 5]), [1, 2, 3])
]
BASE_OBJECTS = PRIMITIVE_OBJECTS + COMPLEX_OBJECTS + CUSTOM_OBJECTS
LIST_OBJECTS = [[obj] for obj in BASE_OBJECTS]
TUPLE_OBJECTS = [(obj,) for obj in BASE_OBJECTS]
TUPLE_OBJECTS = [(obj, ) for obj in BASE_OBJECTS]
# The check that type(obj).__module__ != "numpy" should be unnecessary, but
# otherwise this seems to fail on Mac OS X on Travis.
DICT_OBJECTS = ([{obj: obj} for obj in PRIMITIVE_OBJECTS
if (obj.__hash__ is not None and
type(obj).__module__ != "numpy")] +
[{0: obj} for obj in BASE_OBJECTS] +
[{Foo(123): Foo(456)}])
DICT_OBJECTS = (
[{
obj: obj
} for obj in PRIMITIVE_OBJECTS
if (obj.__hash__ is not None and type(obj).__module__ != "numpy")] + [{
0:
obj
} for obj in BASE_OBJECTS] + [{
Foo(123): Foo(456)
}])
RAY_TEST_OBJECTS = BASE_OBJECTS + LIST_OBJECTS + TUPLE_OBJECTS + DICT_OBJECTS
@ -171,7 +206,6 @@ except AttributeError:
class SerializationTest(unittest.TestCase):
def testRecursiveObjects(self):
ray.init(num_workers=0)
@ -253,15 +287,15 @@ class SerializationTest(unittest.TestCase):
class WorkerTest(unittest.TestCase):
def testPythonWorkers(self):
# Test the codepath for starting workers from the Python script,
# instead of the local scheduler. This codepath is for debugging
# purposes only.
num_workers = 4
ray.worker._init(num_workers=num_workers,
start_workers_from_local_scheduler=False,
start_ray_local=True)
ray.worker._init(
num_workers=num_workers,
start_workers_from_local_scheduler=False,
start_ray_local=True)
@ray.remote
def f(x):
@ -275,13 +309,13 @@ class WorkerTest(unittest.TestCase):
ray.init(num_workers=0)
for i in range(100):
value_before = i * 10 ** 6
value_before = i * 10**6
objectid = ray.put(value_before)
value_after = ray.get(objectid)
self.assertEqual(value_before, value_after)
for i in range(100):
value_before = i * 10 ** 6 * 1.0
value_before = i * 10**6 * 1.0
objectid = ray.put(value_before)
value_after = ray.get(objectid)
self.assertEqual(value_before, value_after)
@ -302,7 +336,6 @@ class WorkerTest(unittest.TestCase):
class APITest(unittest.TestCase):
def init_ray(self, kwargs=None):
if kwargs is None:
kwargs = {}
@ -318,6 +351,7 @@ class APITest(unittest.TestCase):
# throws an exception.
class TempClass(object):
pass
ray.get(ray.put(TempClass()))
# Note that the below actually returns a dictionary and not a
@ -525,14 +559,14 @@ class APITest(unittest.TestCase):
return x, y, args
self.assertEqual(ray.get(f1.remote()), ())
self.assertEqual(ray.get(f1.remote(1)), (1,))
self.assertEqual(ray.get(f1.remote(1)), (1, ))
self.assertEqual(ray.get(f1.remote(1, 2, 3)), (1, 2, 3))
with self.assertRaises(Exception):
f2.remote()
with self.assertRaises(Exception):
f2.remote(1)
self.assertEqual(ray.get(f2.remote(1, 2)), (1, 2, ()))
self.assertEqual(ray.get(f2.remote(1, 2, 3)), (1, 2, (3,)))
self.assertEqual(ray.get(f2.remote(1, 2, 3)), (1, 2, (3, )))
self.assertEqual(ray.get(f2.remote(1, 2, 3, 4)), (1, 2, (3, 4)))
def testNoArgs(self):
@ -548,12 +582,14 @@ class APITest(unittest.TestCase):
@ray.remote
def f(x):
return x + 1
self.assertEqual(ray.get(f.remote(0)), 1)
# Test that we can redefine the remote function.
@ray.remote
def f(x):
return x + 10
while True:
val = ray.get(f.remote(0))
self.assertTrue(val in [1, 10])
@ -563,23 +599,29 @@ class APITest(unittest.TestCase):
print("Still using old definition of f, trying again.")
# Test that we can close over plain old data.
data = [np.zeros([3, 5]), (1, 2, "a"), [0.0, 1.0, 1 << 62], 1 << 60,
{"a": np.zeros(3)}]
data = [
np.zeros([3, 5]), (1, 2, "a"), [0.0, 1.0, 1 << 62], 1 << 60, {
"a": np.zeros(3)
}
]
@ray.remote
def g():
return data
ray.get(g.remote())
# Test that we can close over modules.
@ray.remote
def h():
return np.zeros([3, 5])
assert_equal(ray.get(h.remote()), np.zeros([3, 5]))
@ray.remote
def j():
return time.time()
ray.get(j.remote())
# Test that we can define remote functions that call other remote
@ -595,6 +637,7 @@ class APITest(unittest.TestCase):
@ray.remote
def m(x):
return ray.get(l.remote(x))
self.assertEqual(ray.get(k.remote(1)), 2)
self.assertEqual(ray.get(l.remote(1)), 2)
self.assertEqual(ray.get(m.remote(1)), 2)
@ -618,8 +661,12 @@ class APITest(unittest.TestCase):
time.sleep(delay)
return 1
objectids = [f.remote(1.0), f.remote(0.5), f.remote(0.5),
f.remote(0.5)]
objectids = [
f.remote(1.0),
f.remote(0.5),
f.remote(0.5),
f.remote(0.5)
]
ready_ids, remaining_ids = ray.wait(objectids)
self.assertEqual(len(ready_ids), 1)
self.assertEqual(len(remaining_ids), 3)
@ -627,17 +674,25 @@ class APITest(unittest.TestCase):
self.assertEqual(set(ready_ids), set(objectids))
self.assertEqual(remaining_ids, [])
objectids = [f.remote(0.5), f.remote(0.5), f.remote(0.5),
f.remote(0.5)]
objectids = [
f.remote(0.5),
f.remote(0.5),
f.remote(0.5),
f.remote(0.5)
]
start_time = time.time()
ready_ids, remaining_ids = ray.wait(objectids, timeout=1750,
num_returns=4)
ready_ids, remaining_ids = ray.wait(
objectids, timeout=1750, num_returns=4)
self.assertLess(time.time() - start_time, 2)
self.assertEqual(len(ready_ids), 3)
self.assertEqual(len(remaining_ids), 1)
ray.wait(objectids)
objectids = [f.remote(1.0), f.remote(0.5), f.remote(0.5),
f.remote(0.5)]
objectids = [
f.remote(1.0),
f.remote(0.5),
f.remote(0.5),
f.remote(0.5)
]
start_time = time.time()
ready_ids, remaining_ids = ray.wait(objectids, timeout=5000)
self.assertTrue(time.time() - start_time < 5)
@ -684,18 +739,22 @@ class APITest(unittest.TestCase):
# is connected.
def f(worker_info):
sys.path.append(1)
ray.worker.global_worker.run_function_on_all_workers(f)
def f(worker_info):
sys.path.append(2)
ray.worker.global_worker.run_function_on_all_workers(f)
def g(worker_info):
sys.path.append(3)
ray.worker.global_worker.run_function_on_all_workers(g)
def f(worker_info):
sys.path.append(4)
ray.worker.global_worker.run_function_on_all_workers(f)
self.init_ray()
@ -716,6 +775,7 @@ class APITest(unittest.TestCase):
sys.path.pop()
sys.path.pop()
sys.path.pop()
ray.worker.global_worker.run_function_on_all_workers(f)
def testRunningFunctionOnAllWorkers(self):
@ -723,15 +783,18 @@ class APITest(unittest.TestCase):
def f(worker_info):
sys.path.append("fake_directory")
ray.worker.global_worker.run_function_on_all_workers(f)
@ray.remote
def get_path1():
return sys.path
self.assertEqual("fake_directory", ray.get(get_path1.remote())[-1])
def f(worker_info):
sys.path.pop(-1)
ray.worker.global_worker.run_function_on_all_workers(f)
# Create a second remote function to guarantee that when we call
@ -740,6 +803,7 @@ class APITest(unittest.TestCase):
@ray.remote
def get_path2():
return sys.path
self.assertTrue("fake_directory" not in ray.get(get_path2.remote()))
def testLoggingAPI(self):
@ -751,8 +815,8 @@ class APITest(unittest.TestCase):
keys = ray.worker.global_worker.redis_client.keys("event_log:*")
res = []
for key in keys:
res.extend(ray.worker.global_worker.redis_client.zrange(key, 0,
-1))
res.extend(
ray.worker.global_worker.redis_client.zrange(key, 0, -1))
return res
def wait_for_num_events(num_events, timeout=10):
@ -806,26 +870,31 @@ class APITest(unittest.TestCase):
@ray.remote
def f():
return 1
results1 = [f.remote() for _ in range(num_calls)]
@ray.remote
def f():
return 2
results2 = [f.remote() for _ in range(num_calls)]
@ray.remote
def f():
return 3
results3 = [f.remote() for _ in range(num_calls)]
@ray.remote
def f():
return 4
results4 = [f.remote() for _ in range(num_calls)]
@ray.remote
def f():
return 5
results5 = [f.remote() for _ in range(num_calls)]
self.assertEqual(ray.get(results1), num_calls * [1])
@ -870,7 +939,6 @@ class APITest(unittest.TestCase):
class APITestSharded(APITest):
def init_ray(self, kwargs=None):
if kwargs is None:
kwargs = {}
@ -881,7 +949,6 @@ class APITestSharded(APITest):
class PythonModeTest(unittest.TestCase):
def testPythonMode(self):
reload(test_functions)
ray.init(driver_mode=ray.PYTHON_MODE)
@ -889,6 +956,7 @@ class PythonModeTest(unittest.TestCase):
@ray.remote
def f():
return np.ones([3, 4, 5])
xref = f.remote()
# Remote functions should return by value.
assert_equal(xref, np.ones([3, 4, 5]))
@ -911,8 +979,8 @@ class PythonModeTest(unittest.TestCase):
# first list and the remaining values as the second list
num_returns = 5
object_ids = [ray.put(i) for i in range(20)]
ready, remaining = ray.wait(object_ids, num_returns=num_returns,
timeout=None)
ready, remaining = ray.wait(
object_ids, num_returns=num_returns, timeout=None)
assert_equal(ready, object_ids[:num_returns])
assert_equal(remaining, object_ids[num_returns:])
@ -949,7 +1017,6 @@ class PythonModeTest(unittest.TestCase):
class UtilsTest(unittest.TestCase):
def testCopyingDirectory(self):
# The functionality being tested here is really multi-node
# functionality, but this test just uses a single node.
@ -999,7 +1066,6 @@ class UtilsTest(unittest.TestCase):
class ResourcesTest(unittest.TestCase):
def testResourceConstraints(self):
num_workers = 20
ray.init(num_workers=num_workers, num_cpus=10, num_gpus=2)
@ -1012,9 +1078,13 @@ class ResourcesTest(unittest.TestCase):
def get_worker_id():
time.sleep(1)
return sys.path[-1]
while True:
if len(set(ray.get([get_worker_id.remote()
for _ in range(num_workers)]))) == num_workers:
if len(
set(
ray.get([
get_worker_id.remote() for _ in range(num_workers)
]))) == num_workers:
break
time_buffer = 0.3
@ -1088,9 +1158,13 @@ class ResourcesTest(unittest.TestCase):
def get_worker_id():
time.sleep(1)
return sys.path[-1]
while True:
if len(set(ray.get([get_worker_id.remote()
for _ in range(num_workers)]))) == num_workers:
if len(
set(
ray.get([
get_worker_id.remote() for _ in range(num_workers)
]))) == num_workers:
break
@ray.remote(num_cpus=1, num_gpus=9)
@ -1192,7 +1266,7 @@ class ResourcesTest(unittest.TestCase):
list_of_ids = ray.get([f1.remote() for _ in range(10)])
set_of_ids = set([tuple(gpu_ids) for gpu_ids in list_of_ids])
self.assertEqual(set_of_ids, set([(i,) for i in range(10)]))
self.assertEqual(set_of_ids, set([(i, ) for i in range(10)]))
list_of_ids = ray.get([f2.remote(), f4.remote(), f4.remote()])
all_ids = [gpu_id for gpu_ids in list_of_ids for gpu_id in gpu_ids]
@ -1218,11 +1292,12 @@ class ResourcesTest(unittest.TestCase):
# This test will define a bunch of tasks that can only be assigned to
# specific local schedulers, and we will check that they are assigned
# to the correct local schedulers.
address_info = ray.worker._init(start_ray_local=True,
num_local_schedulers=3,
num_workers=1,
num_cpus=[100, 5, 10],
num_gpus=[0, 5, 1])
address_info = ray.worker._init(
start_ray_local=True,
num_local_schedulers=3,
num_workers=1,
num_cpus=[100, 5, 10],
num_gpus=[0, 5, 1])
# Define a bunch of remote functions that all return the socket name of
# the plasma store. Since there is a one-to-one correspondence between
@ -1284,8 +1359,10 @@ class ResourcesTest(unittest.TestCase):
results.append(run_on_0_2.remote())
return names, results
store_names = [object_store_address.name for object_store_address
in address_info["object_store_addresses"]]
store_names = [
object_store_address.name
for object_store_address in address_info["object_store_addresses"]
]
def validate_names_and_results(names, results):
for name, result in zip(names, ray.get(results)):
@ -1296,8 +1373,9 @@ class ResourcesTest(unittest.TestCase):
elif name == "run_on_2":
self.assertIn(result, [store_names[2]])
elif name == "run_on_0_1_2":
self.assertIn(result, [store_names[0], store_names[1],
store_names[2]])
self.assertIn(result, [
store_names[0], store_names[1], store_names[2]
])
elif name == "run_on_1_2":
self.assertIn(result, [store_names[1], store_names[2]])
elif name == "run_on_0_2":
@ -1327,8 +1405,11 @@ class ResourcesTest(unittest.TestCase):
ray.worker.cleanup()
def testCustomResources(self):
ray.worker._init(start_ray_local=True, num_local_schedulers=2,
num_cpus=3, num_custom_resource=[0, 1])
ray.worker._init(
start_ray_local=True,
num_local_schedulers=2,
num_cpus=3,
num_custom_resource=[0, 1])
@ray.remote
def f():
@ -1373,13 +1454,12 @@ class ResourcesTest(unittest.TestCase):
ray.get(ray.remote(num_custom_resource=2)(f).remote())
ray.get(ray.remote(num_custom_resource=4)(f).remote())
ray.get(ray.remote(num_custom_resource=8)(f).remote())
ray.get(ray.remote(num_custom_resource=(10 ** 10))(f).remote())
ray.get(ray.remote(num_custom_resource=(10**10))(f).remote())
ray.worker.cleanup()
class WorkerPoolTests(unittest.TestCase):
def tearDown(self):
ray.worker.cleanup()
@ -1450,19 +1530,22 @@ class WorkerPoolTests(unittest.TestCase):
class SchedulingAlgorithm(unittest.TestCase):
def attempt_to_load_balance(self, remote_function, args, total_tasks,
num_local_schedulers, minimum_count,
def attempt_to_load_balance(self,
remote_function,
args,
total_tasks,
num_local_schedulers,
minimum_count,
num_attempts=20):
attempts = 0
while attempts < num_attempts:
locations = ray.get([remote_function.remote(*args)
for _ in range(total_tasks)])
locations = ray.get(
[remote_function.remote(*args) for _ in range(total_tasks)])
names = set(locations)
counts = [locations.count(name) for name in names]
print("Counts are {}.".format(counts))
if (len(names) == num_local_schedulers and
all([count >= minimum_count for count in counts])):
if (len(names) == num_local_schedulers
and all([count >= minimum_count for count in counts])):
break
attempts += 1
self.assertLess(attempts, num_attempts)
@ -1472,9 +1555,10 @@ class SchedulingAlgorithm(unittest.TestCase):
# schedulers in a roughly equal manner.
num_local_schedulers = 3
num_cpus = 7
ray.worker._init(start_ray_local=True,
num_local_schedulers=num_local_schedulers,
num_cpus=num_cpus)
ray.worker._init(
start_ray_local=True,
num_local_schedulers=num_local_schedulers,
num_cpus=num_cpus)
@ray.remote
def f():
@ -1492,8 +1576,10 @@ class SchedulingAlgorithm(unittest.TestCase):
# dependencies.
num_workers = 3
num_local_schedulers = 3
ray.worker._init(start_ray_local=True, num_workers=num_workers,
num_local_schedulers=num_local_schedulers)
ray.worker._init(
start_ray_local=True,
num_workers=num_workers,
num_local_schedulers=num_local_schedulers)
@ray.remote
def f(x):
@ -1528,7 +1614,6 @@ def wait_for_num_objects(num_objects, timeout=10):
class GlobalStateAPI(unittest.TestCase):
def testGlobalStateAPI(self):
with self.assertRaises(Exception):
ray.global_state.object_table()
@ -1572,15 +1657,16 @@ class GlobalStateAPI(unittest.TestCase):
driver_id)
self.assertEqual(task_table[driver_task_id]["TaskSpec"]["FunctionID"],
ID_SIZE * "ff")
self.assertEqual((task_table[driver_task_id]["TaskSpec"]
["ReturnObjectIDs"]),
[])
self.assertEqual(
(task_table[driver_task_id]["TaskSpec"]["ReturnObjectIDs"]), [])
client_table = ray.global_state.client_table()
node_ip_address = ray.worker.global_worker.node_ip_address
self.assertEqual(len(client_table[node_ip_address]), 3)
manager_client = [c for c in client_table[node_ip_address]
if c["ClientType"] == "plasma_manager"][0]
manager_client = [
c for c in client_table[node_ip_address]
if c["ClientType"] == "plasma_manager"
][0]
@ray.remote
def f(*xs):
@ -1624,8 +1710,8 @@ class GlobalStateAPI(unittest.TestCase):
while time.time() - start_time < timeout:
object_table = ray.global_state.object_table()
tables_ready = (
object_table[x_id]["ManagerIDs"] is not None and
object_table[result_id]["ManagerIDs"] is not None)
object_table[x_id]["ManagerIDs"] is not None
and object_table[result_id]["ManagerIDs"] is not None)
if tables_ready:
return
time.sleep(0.1)
@ -1701,8 +1787,8 @@ class GlobalStateAPI(unittest.TestCase):
while time.time() - start_time < 10:
profiles = ray.global_state.task_profiles(
100, start=0, end=time.time())
limited_profiles = ray.global_state.task_profiles(1, start=0,
end=time.time())
limited_profiles = ray.global_state.task_profiles(
1, start=0, end=time.time())
if len(profiles) == num_calls and len(limited_profiles) == 1:
break
time.sleep(0.1)
@ -1722,8 +1808,10 @@ class GlobalStateAPI(unittest.TestCase):
def testWorkers(self):
num_workers = 3
ray.init(redirect_output=True, num_cpus=num_workers,
num_workers=num_workers)
ray.init(
redirect_output=True,
num_cpus=num_workers,
num_workers=num_workers)
@ray.remote
def f():