Lint Python files with Yapf (#1872)

Authored by Philipp Moritz, 2018-04-11 10:11:35 -07:00; committed by Robert Nishihara
parent a3ddde398c
commit 74162d1492
97 changed files with 3927 additions and 3139 deletions


@ -42,7 +42,7 @@ blank_line_before_nested_class_or_def=False
# 'key1': 'value1',
# 'key2': 'value2',
# })
coalesce_brackets=False
coalesce_brackets=True
# The column limit.
column_limit=79
@ -90,7 +90,7 @@ i18n_function_call=
# 'key2': value1 +
# value2,
# }
indent_dictionary_value=False
indent_dictionary_value=True
# The number of columns to use for indentation.
indent_width=4
@ -187,4 +187,3 @@ split_penalty_logical_operator=300
# Use the Tab character for indentation.
use_tabs=False
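
For reference, the two options flipped above change output roughly as follows; a minimal sketch under the new style (names hypothetical), mirroring the examples in the config comments:

# coalesce_brackets=True keeps consecutive opening brackets together:
call_func_that_takes_a_dict({
    'key1': 'value1',
    'key2': 'value2',
})

# indent_dictionary_value=True moves a value that cannot share the
# key's line onto its own line:
settings = {
    'key2':
    value1 + value2,
}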


@ -38,10 +38,12 @@ matrix:
- export PATH="$HOME/miniconda/bin:$PATH"
- cd doc
- pip install -q -r requirements-doc.txt
- pip install yapf
- sphinx-build -W -b html -d _build/doctrees source _build/html
- cd ..
# Run Python linting.
- flake8 --exclude=python/ray/core/src/common/flatbuffers_ep-prefix/,python/ray/core/generated/,src/common/format/,doc/source/conf.py,python/ray/cloudpickle/
- .travis/yapf.sh
- os: linux
dist: trusty

.travis/yapf.sh (new executable file, 27 lines added)

@ -0,0 +1,27 @@
#!/usr/bin/env bash
# Cause the script to exit if a single command fails
set -e
ROOT_DIR=$(cd "$(dirname "${BASH_SOURCE:-$0}")"; pwd)
pushd $ROOT_DIR/../test
find . -name '*.py' -type f -exec yapf --style=pep8 -i -r {} \;
popd
pushd $ROOT_DIR/../python
find . -name '*.py' -type f -not -path './ray/dataframe/*' -not -path './ray/rllib/*' -not -path './ray/cloudpickle/*' -exec yapf --style=pep8 -i -r {} \;
popd
CHANGED_FILES=(`git diff --name-only`)
if [ "$CHANGED_FILES" ]; then
echo 'Reformatted staged files. Please review and stage the changes.'
echo
echo 'Files updated:'
for file in ${CHANGED_FILES[@]}; do
echo " $file"
done
exit 1
else
exit 0
fi
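
The script drives the yapf CLI; the same check can be sketched programmatically with yapf's FormatCode API (a minimal sketch, assuming yapf is installed; the sample source is hypothetical):

from yapf.yapflib.yapf_api import FormatCode

source = "x = { 'a':1 }\n"
formatted, changed = FormatCode(source, style_config="pep8")
if changed:
    # Mirror the script: surface the reformatted code and treat it as a failure.
    print(formatted)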


@ -12,8 +12,8 @@ if "pyarrow" in sys.modules:
# Add the directory containing pyarrow to the Python path so that we find the
# pyarrow version packaged with ray and not a pre-existing pyarrow.
pyarrow_path = os.path.join(os.path.abspath(os.path.dirname(__file__)),
"pyarrow_files")
pyarrow_path = os.path.join(
os.path.abspath(os.path.dirname(__file__)), "pyarrow_files")
sys.path.insert(0, pyarrow_path)
# See https://github.com/ray-project/ray/issues/131.
@ -27,29 +27,29 @@ If you are using Anaconda, try fixing this problem by running:
try:
import pyarrow # noqa: F401
except ImportError as e:
if ((hasattr(e, "msg") and isinstance(e.msg, str) and
("libstdc++" in e.msg or "CXX" in e.msg))):
if ((hasattr(e, "msg") and isinstance(e.msg, str)
and ("libstdc++" in e.msg or "CXX" in e.msg))):
# This code path should be taken with Python 3.
e.msg += helpful_message
elif (hasattr(e, "message") and isinstance(e.message, str) and
("libstdc++" in e.message or "CXX" in e.message)):
elif (hasattr(e, "message") and isinstance(e.message, str)
and ("libstdc++" in e.message or "CXX" in e.message)):
# This code path should be taken with Python 2.
condition = (hasattr(e, "args") and isinstance(e.args, tuple) and
len(e.args) == 1 and isinstance(e.args[0], str))
condition = (hasattr(e, "args") and isinstance(e.args, tuple)
and len(e.args) == 1 and isinstance(e.args[0], str))
if condition:
e.args = (e.args[0] + helpful_message,)
e.args = (e.args[0] + helpful_message, )
else:
if not hasattr(e, "args"):
e.args = ()
elif not isinstance(e.args, tuple):
e.args = (e.args,)
e.args += (helpful_message,)
e.args = (e.args, )
e.args += (helpful_message, )
raise
from ray.local_scheduler import _config # noqa: E402
from ray.worker import (error_info, init, connect, disconnect,
get, put, wait, remote, log_event, log_span,
flush_log, get_gpu_ids, get_webui_url,
from ray.worker import (error_info, init, connect, disconnect, get, put, wait,
remote, log_event, log_span, flush_log, get_gpu_ids,
get_webui_url,
register_custom_serializer) # noqa: E402
from ray.worker import (SCRIPT_MODE, WORKER_MODE, PYTHON_MODE,
SILENT_MODE) # noqa: E402
@ -63,11 +63,13 @@ from ray.actor import method # noqa: E402
# Fix this.
__version__ = "0.4.0"
__all__ = ["error_info", "init", "connect", "disconnect", "get", "put", "wait",
"remote", "log_event", "log_span", "flush_log", "actor", "method",
"get_gpu_ids", "get_webui_url", "register_custom_serializer",
"SCRIPT_MODE", "WORKER_MODE", "PYTHON_MODE", "SILENT_MODE",
"global_state", "_config", "__version__"]
__all__ = [
"error_info", "init", "connect", "disconnect", "get", "put", "wait",
"remote", "log_event", "log_span", "flush_log", "actor", "method",
"get_gpu_ids", "get_webui_url", "register_custom_serializer",
"SCRIPT_MODE", "WORKER_MODE", "PYTHON_MODE", "SILENT_MODE", "global_state",
"_config", "__version__"
]
import ctypes # noqa: E402
# Windows only


@ -121,16 +121,17 @@ def save_and_log_checkpoint(worker, actor):
try:
actor.__ray_checkpoint__()
except Exception:
traceback_str = ray.utils.format_error_message(
traceback.format_exc())
traceback_str = ray.utils.format_error_message(traceback.format_exc())
# Log the error message.
ray.utils.push_error_to_driver(
worker.redis_client,
"checkpoint",
traceback_str,
driver_id=worker.task_driver_id.id(),
data={"actor_class": actor.__class__.__name__,
"function_name": actor.__ray_checkpoint__.__name__})
data={
"actor_class": actor.__class__.__name__,
"function_name": actor.__ray_checkpoint__.__name__
})
def restore_and_log_checkpoint(worker, actor):
@ -144,8 +145,7 @@ def restore_and_log_checkpoint(worker, actor):
try:
checkpoint_resumed = actor.__ray_checkpoint_restore__()
except Exception:
traceback_str = ray.utils.format_error_message(
traceback.format_exc())
traceback_str = ray.utils.format_error_message(traceback.format_exc())
# Log the error message.
ray.utils.push_error_to_driver(
worker.redis_client,
@ -154,8 +154,8 @@ def restore_and_log_checkpoint(worker, actor):
driver_id=worker.task_driver_id.id(),
data={
"actor_class": actor.__class__.__name__,
"function_name":
actor.__ray_checkpoint_restore__.__name__})
"function_name": actor.__ray_checkpoint_restore__.__name__
})
return checkpoint_resumed
@ -197,15 +197,15 @@ def make_actor_method_executor(worker, method_name, method, actor_imported):
return
# Determine whether we should checkpoint the actor.
checkpointing_on = (actor_imported and
worker.actor_checkpoint_interval > 0)
checkpointing_on = (actor_imported
and worker.actor_checkpoint_interval > 0)
# We should checkpoint the actor if user checkpointing is on, we've
# executed checkpoint_interval tasks since the last checkpoint, and the
# method we're about to execute is not a checkpoint.
save_checkpoint = (checkpointing_on and
(worker.actor_task_counter %
worker.actor_checkpoint_interval == 0 and
method_name != "__ray_checkpoint__"))
save_checkpoint = (
checkpointing_on and
(worker.actor_task_counter % worker.actor_checkpoint_interval == 0
and method_name != "__ray_checkpoint__"))
# Execute the assigned method and save a checkpoint if necessary.
try:
@ -238,14 +238,14 @@ def fetch_and_register_actor(actor_class_key, resources, worker):
worker: The worker to use.
"""
actor_id_str = worker.actor_id
(driver_id, class_id, class_name,
module, pickled_class, checkpoint_interval,
actor_method_names,
(driver_id, class_id, class_name, module, pickled_class,
checkpoint_interval, actor_method_names,
actor_method_num_return_vals) = worker.redis_client.hmget(
actor_class_key, ["driver_id", "class_id", "class_name", "module",
"class", "checkpoint_interval",
"actor_method_names",
"actor_method_num_return_vals"])
actor_class_key, [
"driver_id", "class_id", "class_name", "module", "class",
"checkpoint_interval", "actor_method_names",
"actor_method_num_return_vals"
])
actor_name = class_name.decode("ascii")
module = module.decode("ascii")
@ -259,12 +259,14 @@ def fetch_and_register_actor(actor_class_key, resources, worker):
# error messages and to prevent the driver from hanging).
class TemporaryActor(object):
pass
worker.actors[actor_id_str] = TemporaryActor()
worker.actor_checkpoint_interval = checkpoint_interval
def temporary_actor_method(*xs):
raise Exception("The actor with name {} failed to be imported, and so "
"cannot execute this method".format(actor_name))
# Register the actor method signatures.
register_actor_signatures(worker, driver_id, class_id, class_name,
actor_method_names, actor_method_num_return_vals)
@ -272,10 +274,11 @@ def fetch_and_register_actor(actor_class_key, resources, worker):
for actor_method_name in actor_method_names:
function_id = compute_actor_method_function_id(class_name,
actor_method_name).id()
temporary_executor = make_actor_method_executor(worker,
actor_method_name,
temporary_actor_method,
actor_imported=False)
temporary_executor = make_actor_method_executor(
worker,
actor_method_name,
temporary_actor_method,
actor_imported=False)
worker.functions[driver_id][function_id] = (actor_method_name,
temporary_executor)
worker.num_task_executions[driver_id][function_id] = 0
@ -288,9 +291,12 @@ def fetch_and_register_actor(actor_class_key, resources, worker):
# traceback and notify the scheduler of the failure.
traceback_str = ray.utils.format_error_message(traceback.format_exc())
# Log the error message.
push_error_to_driver(worker.redis_client, "register_actor_signatures",
traceback_str, driver_id,
data={"actor_id": actor_id_str})
push_error_to_driver(
worker.redis_client,
"register_actor_signatures",
traceback_str,
driver_id,
data={"actor_id": actor_id_str})
# TODO(rkn): In the future, it might make sense to have the worker exit
# here. However, currently that would lead to hanging if someone calls
# ray.get on a method invoked on the actor.
@ -298,16 +304,17 @@ def fetch_and_register_actor(actor_class_key, resources, worker):
# TODO(pcm): Why is the below line necessary?
unpickled_class.__module__ = module
worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class)
actor_methods = inspect.getmembers(
unpickled_class, predicate=(lambda x: (inspect.isfunction(x) or
inspect.ismethod(x) or
is_cython(x))))
def pred(x):
return (inspect.isfunction(x) or inspect.ismethod(x)
or is_cython(x))
actor_methods = inspect.getmembers(unpickled_class, predicate=pred)
for actor_method_name, actor_method in actor_methods:
function_id = compute_actor_method_function_id(
class_name, actor_method_name).id()
executor = make_actor_method_executor(worker, actor_method_name,
actor_method,
actor_imported=True)
executor = make_actor_method_executor(
worker, actor_method_name, actor_method, actor_imported=True)
worker.functions[driver_id][function_id] = (actor_method_name,
executor)
# We do not set worker.function_properties[driver_id][function_id]
@ -315,7 +322,10 @@ def fetch_and_register_actor(actor_class_key, resources, worker):
# for the actor.
def register_actor_signatures(worker, driver_id, class_id, class_name,
def register_actor_signatures(worker,
driver_id,
class_id,
class_name,
actor_method_names,
actor_method_num_return_vals,
actor_creation_resources=None,
@ -346,18 +356,20 @@ def register_actor_signatures(worker, driver_id, class_id, class_name,
# The extra return value is an actor dummy object.
# In the cases where actor_method_cpus is None, that value should
# never be used.
FunctionProperties(num_return_vals=num_return_vals + 1,
resources={"CPU": actor_method_cpus},
max_calls=0))
FunctionProperties(
num_return_vals=num_return_vals + 1,
resources={"CPU": actor_method_cpus},
max_calls=0))
if actor_creation_resources is not None:
# Also register the actor creation task.
function_id = compute_actor_creation_function_id(class_id)
worker.function_properties[driver_id][function_id.id()] = (
# The extra return value is an actor dummy object.
FunctionProperties(num_return_vals=0 + 1,
resources=actor_creation_resources,
max_calls=0))
FunctionProperties(
num_return_vals=0 + 1,
resources=actor_creation_resources,
max_calls=0))
def publish_actor_class_to_key(key, actor_class_info, worker):
@ -380,8 +392,8 @@ def publish_actor_class_to_key(key, actor_class_info, worker):
def export_actor_class(class_id, Class, actor_method_names,
actor_method_num_return_vals,
checkpoint_interval, worker):
actor_method_num_return_vals, checkpoint_interval,
worker):
key = b"ActorClass:" + class_id
actor_class_info = {
"class_name": Class.__name__,
@ -389,8 +401,9 @@ def export_actor_class(class_id, Class, actor_method_names,
"class": pickle.dumps(Class),
"checkpoint_interval": checkpoint_interval,
"actor_method_names": json.dumps(list(actor_method_names)),
"actor_method_num_return_vals": json.dumps(
actor_method_num_return_vals)}
"actor_method_num_return_vals":
json.dumps(actor_method_num_return_vals)
}
if worker.mode is None:
# This means that 'ray.init()' has not been called yet and so we must
@ -433,7 +446,11 @@ def export_actor(actor_id, class_id, class_name, actor_method_names,
driver_id = worker.task_driver_id.id()
register_actor_signatures(
worker, driver_id, class_id, class_name, actor_method_names,
worker,
driver_id,
class_id,
class_name,
actor_method_names,
actor_method_num_return_vals,
actor_creation_resources=actor_creation_resources,
actor_method_cpus=actor_method_cpus)
@ -466,12 +483,14 @@ class ActorMethod(object):
def __call__(self, *args, **kwargs):
raise Exception("Actor methods cannot be called directly. Instead "
"of running 'object.{}()', try "
"'object.{}.remote()'."
.format(self._method_name, self._method_name))
"'object.{}.remote()'.".format(self._method_name,
self._method_name))
def remote(self, *args, **kwargs):
return self._actor._actor_method_call(
self._method_name, args=args, kwargs=kwargs,
self._method_name,
args=args,
kwargs=kwargs,
dependency=self._actor._ray_actor_cursor)
@ -481,12 +500,13 @@ class ActorHandleWrapper(object):
This is essentially just a dictionary, but it is used so that the recipient
can tell that an argument is an ActorHandle.
"""
def __init__(self, actor_id, class_id, actor_handle_id, actor_cursor,
actor_counter, actor_method_names,
actor_method_num_return_vals, method_signatures,
checkpoint_interval, class_name,
actor_creation_dummy_object_id,
actor_creation_resources, actor_method_cpus):
actor_creation_dummy_object_id, actor_creation_resources,
actor_method_cpus):
# TODO(rkn): Some of these fields are probably not necessary. We should
# strip out the unnecessary fields to keep actor handles lightweight.
self.actor_id = actor_id
@ -545,27 +565,20 @@ def unwrap_actor_handle(worker, wrapper):
The unwrapped ActorHandle instance.
"""
driver_id = worker.task_driver_id.id()
register_actor_signatures(worker, driver_id, wrapper.class_id,
wrapper.class_name, wrapper.actor_method_names,
wrapper.actor_method_num_return_vals,
wrapper.actor_creation_resources,
wrapper.actor_method_cpus)
register_actor_signatures(
worker, driver_id, wrapper.class_id, wrapper.class_name,
wrapper.actor_method_names, wrapper.actor_method_num_return_vals,
wrapper.actor_creation_resources, wrapper.actor_method_cpus)
actor_handle_class = make_actor_handle_class(wrapper.class_name)
actor_object = actor_handle_class.__new__(actor_handle_class)
actor_object._manual_init(
wrapper.actor_id,
wrapper.class_id,
wrapper.actor_handle_id,
wrapper.actor_cursor,
wrapper.actor_counter,
wrapper.actor_method_names,
wrapper.actor_method_num_return_vals,
wrapper.method_signatures,
wrapper.checkpoint_interval,
wrapper.actor_id, wrapper.class_id, wrapper.actor_handle_id,
wrapper.actor_cursor, wrapper.actor_counter,
wrapper.actor_method_names, wrapper.actor_method_num_return_vals,
wrapper.method_signatures, wrapper.checkpoint_interval,
wrapper.actor_creation_dummy_object_id,
wrapper.actor_creation_resources,
wrapper.actor_method_cpus)
wrapper.actor_creation_resources, wrapper.actor_method_cpus)
return actor_object
@ -612,7 +625,10 @@ def make_actor_handle_class(class_name):
self._ray_actor_creation_resources = actor_creation_resources
self._ray_actor_method_cpus = actor_method_cpus
def _actor_method_call(self, method_name, args=None, kwargs=None,
def _actor_method_call(self,
method_name,
args=None,
kwargs=None,
dependency=None):
"""Method execution stub for an actor handle.
@ -663,7 +679,9 @@ def make_actor_handle_class(class_name):
function_id = compute_actor_method_function_id(
self._ray_class_name, method_name)
object_ids = ray.worker.global_worker.submit_task(
function_id, args, actor_id=self._ray_actor_id,
function_id,
args,
actor_id=self._ray_actor_id,
actor_handle_id=self._ray_actor_handle_id,
actor_counter=self._ray_actor_counter,
is_actor_checkpoint_method=is_actor_checkpoint_method,
@ -722,8 +740,8 @@ def make_actor_handle_class(class_name):
self._ray_actor_handle_id.id() == ray.worker.NIL_ACTOR_ID):
# TODO(rkn): Should we be passing in the actor cursor as a
# dependency here?
self._actor_method_call("__ray_terminate__",
args=[self._ray_actor_id.id()])
self._actor_method_call(
"__ray_terminate__", args=[self._ray_actor_id.id()])
return ActorHandle
@ -735,7 +753,6 @@ def actor_handle_from_class(Class, class_id, actor_creation_resources,
exported = []
class ActorHandle(actor_handle_class):
@classmethod
def remote(cls, *args, **kwargs):
if ray.worker.global_worker.mode is None:
@ -754,11 +771,13 @@ def actor_handle_from_class(Class, class_id, actor_creation_resources,
actor_cursor = None
# The number of actor method invocations that we've called so far.
actor_counter = 0
# Get the actor methods of the given class.
actor_methods = inspect.getmembers(
Class, predicate=(lambda x: (inspect.isfunction(x) or
inspect.ismethod(x) or
is_cython(x))))
def pred(x):
return (inspect.isfunction(x) or inspect.ismethod(x)
or is_cython(x))
actor_methods = inspect.getmembers(Class, predicate=pred)
# Extract the signatures of each of the methods. This will be used
# to catch some errors if the methods are called with inappropriate
# arguments.
@ -773,8 +792,9 @@ def actor_handle_from_class(Class, class_id, actor_creation_resources,
method_signatures[k] = signature.extract_signature(
v, ignore_first=True)
actor_method_names = [method_name for method_name, _ in
actor_methods]
actor_method_names = [
method_name for method_name, _ in actor_methods
]
actor_method_num_return_vals = []
for _, method in actor_methods:
if hasattr(method, "__ray_num_return_vals__"):
@ -796,30 +816,29 @@ def actor_handle_from_class(Class, class_id, actor_creation_resources,
checkpoint_interval,
ray.worker.global_worker)
exported.append(0)
actor_cursor = export_actor(actor_id, class_id, class_name,
actor_method_names,
actor_method_num_return_vals,
actor_creation_resources,
actor_method_cpus,
ray.worker.global_worker)
actor_cursor = export_actor(
actor_id, class_id, class_name, actor_method_names,
actor_method_num_return_vals, actor_creation_resources,
actor_method_cpus, ray.worker.global_worker)
# Increment the actor counter to account for the creation task.
actor_counter += 1
# Instantiate the actor handle.
actor_object = cls.__new__(cls)
actor_object._manual_init(actor_id, class_id, actor_handle_id,
actor_cursor, actor_counter,
actor_method_names,
actor_method_num_return_vals,
method_signatures, checkpoint_interval,
actor_cursor, actor_creation_resources,
actor_method_cpus)
actor_object._manual_init(
actor_id, class_id, actor_handle_id, actor_cursor,
actor_counter, actor_method_names,
actor_method_num_return_vals, method_signatures,
checkpoint_interval, actor_cursor, actor_creation_resources,
actor_method_cpus)
# Call __init__ as a remote function.
if "__init__" in actor_object._ray_actor_method_names:
actor_object._actor_method_call("__init__", args=args,
kwargs=kwargs,
dependency=actor_cursor)
actor_object._actor_method_call(
"__init__",
args=args,
kwargs=kwargs,
dependency=actor_cursor)
else:
print("WARNING: this object has no __init__ method.")


@ -51,25 +51,32 @@ CLUSTER_CONFIG_SCHEMA = {
"idle_timeout_minutes": (int, OPTIONAL),
# Cloud-provider specific configuration.
"provider": ({
"type": (str, REQUIRED), # e.g. aws
"region": (str, OPTIONAL), # e.g. us-east-1
"availability_zone": (str, OPTIONAL), # e.g. us-east-1a
"module": (str, OPTIONAL), # module, if using external node provider
}, REQUIRED),
"provider": (
{
"type": (str, REQUIRED), # e.g. aws
"region": (str, OPTIONAL), # e.g. us-east-1
"availability_zone": (str, OPTIONAL), # e.g. us-east-1a
"module": (str,
OPTIONAL), # module, if using external node provider
},
REQUIRED),
# How Ray will authenticate with newly launched nodes.
"auth": ({
"ssh_user": (str, REQUIRED), # e.g. ubuntu
"ssh_private_key": (str, OPTIONAL),
}, REQUIRED),
"auth": (
{
"ssh_user": (str, REQUIRED), # e.g. ubuntu
"ssh_private_key": (str, OPTIONAL),
},
REQUIRED),
# Docker configuration. If this is specified, all setup and start commands
# will be executed in the container.
"docker": ({
"image": (str, OPTIONAL), # e.g. tensorflow/tensorflow:1.5.0-py3
"container_name": (str, OPTIONAL), # e.g., ray_docker
}, OPTIONAL),
"docker": (
{
"image": (str, OPTIONAL), # e.g. tensorflow/tensorflow:1.5.0-py3
"container_name": (str, OPTIONAL), # e.g., ray_docker
},
OPTIONAL),
# Provider-specific config for the head node, e.g. instance type.
"head_node": (dict, OPTIONAL),
@ -137,9 +144,9 @@ class LoadMetrics(object):
for unwanted_key in unwanted:
del mapping[unwanted_key]
if unwanted:
print(
"Removed {} stale ip mappings: {} not in {}".format(
len(unwanted), unwanted, active_ips))
print("Removed {} stale ip mappings: {} not in {}".format(
len(unwanted), unwanted, active_ips))
prune(self.last_used_time_by_ip)
prune(self.static_resources_by_ip)
prune(self.dynamic_resources_by_ip)
@ -148,10 +155,8 @@ class LoadMetrics(object):
return self._info()["NumNodesUsed"]
def debug_string(self):
return " - {}".format(
"\n - ".join(
["{}: {}".format(k, v)
for k, v in sorted(self._info().items())]))
return " - {}".format("\n - ".join(
["{}: {}".format(k, v) for k, v in sorted(self._info().items())]))
def _info(self):
nodes_used = 0.0
@ -176,14 +181,19 @@ class LoadMetrics(object):
nodes_used += max_frac
idle_times = [now - t for t in self.last_used_time_by_ip.values()]
return {
"ResourceUsage": ", ".join([
"ResourceUsage":
", ".join([
"{}/{} {}".format(
round(resources_used[rid], 2),
round(resources_total[rid], 2), rid)
for rid in sorted(resources_used)]),
"NumNodesConnected": len(self.static_resources_by_ip),
"NumNodesUsed": round(nodes_used, 2),
"NodeIdleSeconds": "Min={} Mean={} Max={}".format(
for rid in sorted(resources_used)
]),
"NumNodesConnected":
len(self.static_resources_by_ip),
"NumNodesUsed":
round(nodes_used, 2),
"NodeIdleSeconds":
"Min={} Mean={} Max={}".format(
int(np.min(idle_times)) if idle_times else -1,
int(np.mean(idle_times)) if idle_times else -1,
int(np.max(idle_times)) if idle_times else -1),
@ -208,18 +218,20 @@ class StandardAutoscaler(object):
until the target cluster size is met).
"""
def __init__(
self, config_path, load_metrics,
max_concurrent_launches=AUTOSCALER_MAX_CONCURRENT_LAUNCHES,
max_failures=AUTOSCALER_MAX_NUM_FAILURES,
process_runner=subprocess, verbose_updates=False,
node_updater_cls=NodeUpdaterProcess,
update_interval_s=AUTOSCALER_UPDATE_INTERVAL_S):
def __init__(self,
config_path,
load_metrics,
max_concurrent_launches=AUTOSCALER_MAX_CONCURRENT_LAUNCHES,
max_failures=AUTOSCALER_MAX_NUM_FAILURES,
process_runner=subprocess,
verbose_updates=False,
node_updater_cls=NodeUpdaterProcess,
update_interval_s=AUTOSCALER_UPDATE_INTERVAL_S):
self.config_path = config_path
self.reload_config(errors_fatal=True)
self.load_metrics = load_metrics
self.provider = get_node_provider(
self.config["provider"], self.config["cluster_name"])
self.provider = get_node_provider(self.config["provider"],
self.config["cluster_name"])
self.max_failures = max_failures
self.max_concurrent_launches = max_concurrent_launches
@ -245,9 +257,8 @@ class StandardAutoscaler(object):
self.reload_config(errors_fatal=False)
self._update()
except Exception as e:
print(
"StandardAutoscaler: Error during autoscaling: {}",
traceback.format_exc())
print("StandardAutoscaler: Error during autoscaling: {}",
traceback.format_exc())
self.num_failures += 1
if self.num_failures > self.max_failures:
print("*** StandardAutoscaler: Too many errors, abort. ***")
@ -274,15 +285,13 @@ class StandardAutoscaler(object):
if node_ip in last_used and last_used[node_ip] < horizon and \
len(nodes) - num_terminated > self.config["min_workers"]:
num_terminated += 1
print(
"StandardAutoscaler: Terminating idle node: "
"{}".format(node_id))
print("StandardAutoscaler: Terminating idle node: "
"{}".format(node_id))
self.provider.terminate_node(node_id)
elif not self.launch_config_ok(node_id):
num_terminated += 1
print(
"StandardAutoscaler: Terminating outdated node: "
"{}".format(node_id))
print("StandardAutoscaler: Terminating outdated node: "
"{}".format(node_id))
self.provider.terminate_node(node_id)
if num_terminated > 0:
nodes = self.workers()
@ -292,9 +301,8 @@ class StandardAutoscaler(object):
num_terminated = 0
while len(nodes) > self.config["max_workers"]:
num_terminated += 1
print(
"StandardAutoscaler: Terminating unneeded node: "
"{}".format(nodes[-1]))
print("StandardAutoscaler: Terminating unneeded node: "
"{}".format(nodes[-1]))
self.provider.terminate_node(nodes[-1])
nodes = nodes[:-1]
if num_terminated > 0:
@ -339,13 +347,13 @@ class StandardAutoscaler(object):
with open(self.config_path) as f:
new_config = yaml.load(f.read())
validate_config(new_config)
new_launch_hash = hash_launch_conf(
new_config["worker_nodes"], new_config["auth"])
new_runtime_hash = hash_runtime_conf(
new_config["file_mounts"],
[new_config["setup_commands"],
new_config["worker_setup_commands"],
new_config["worker_start_ray_commands"]])
new_launch_hash = hash_launch_conf(new_config["worker_nodes"],
new_config["auth"])
new_runtime_hash = hash_runtime_conf(new_config["file_mounts"], [
new_config["setup_commands"],
new_config["worker_setup_commands"],
new_config["worker_start_ray_commands"]
])
self.config = new_config
self.launch_hash = new_launch_hash
self.runtime_hash = new_runtime_hash
@ -353,17 +361,15 @@ class StandardAutoscaler(object):
if errors_fatal:
raise e
else:
print(
"StandardAutoscaler: Error parsing config: {}",
traceback.format_exc())
print("StandardAutoscaler: Error parsing config: {}",
traceback.format_exc())
def target_num_workers(self):
target_frac = self.config["target_utilization_fraction"]
cur_used = self.load_metrics.approx_workers_used()
ideal_num_workers = int(np.ceil(cur_used / float(target_frac)))
return min(
self.config["max_workers"],
max(self.config["min_workers"], ideal_num_workers))
return min(self.config["max_workers"],
max(self.config["min_workers"], ideal_num_workers))
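
A quick worked example of the sizing formula above (numbers hypothetical):

import numpy as np

target_frac = 0.8  # config["target_utilization_fraction"]
cur_used = 8.2     # load_metrics.approx_workers_used()
ideal_num_workers = int(np.ceil(cur_used / float(target_frac)))  # ceil(10.25) == 11
# The result is then clamped into [min_workers, max_workers].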
def launch_config_ok(self, node_id):
launch_conf = self.provider.node_tags(node_id).get(
@ -393,8 +399,7 @@ class StandardAutoscaler(object):
node_id,
self.config["provider"],
self.config["auth"],
self.config["cluster_name"],
{},
self.config["cluster_name"], {},
with_head_node_ip(self.config["worker_start_ray_commands"]),
self.runtime_hash,
redirect_output=not self.verbose_updates,
@ -409,14 +414,12 @@ class StandardAutoscaler(object):
return
if self.config.get("no_restart", False) and \
self.num_successful_updates.get(node_id, 0) > 0:
init_commands = (
self.config["setup_commands"] +
self.config["worker_setup_commands"])
init_commands = (self.config["setup_commands"] +
self.config["worker_setup_commands"])
else:
init_commands = (
self.config["setup_commands"] +
self.config["worker_setup_commands"] +
self.config["worker_start_ray_commands"])
init_commands = (self.config["setup_commands"] +
self.config["worker_setup_commands"] +
self.config["worker_start_ray_commands"])
updater = self.node_updater_cls(
node_id,
self.config["provider"],
@ -445,14 +448,12 @@ class StandardAutoscaler(object):
print("StandardAutoscaler: Launching {} new nodes".format(count))
num_before = len(self.workers())
self.provider.create_node(
self.config["worker_nodes"],
{
self.config["worker_nodes"], {
TAG_NAME: "ray-{}-worker".format(self.config["cluster_name"]),
TAG_RAY_NODE_TYPE: "Worker",
TAG_RAY_NODE_STATUS: "Uninitialized",
TAG_RAY_LAUNCH_CONFIG: self.launch_hash,
},
count)
}, count)
# TODO(ekl) be less conservative in this check
assert len(self.workers()) > num_before, \
"Num nodes failed to increase after creating a new node"
@ -472,8 +473,8 @@ class StandardAutoscaler(object):
suffix += " ({} failed to update)".format(
len(self.num_failed_updates))
return "StandardAutoscaler [{}]: {}/{} target nodes{}\n{}".format(
datetime.now(), len(nodes), self.target_num_workers(),
suffix, self.load_metrics.debug_string())
datetime.now(), len(nodes), self.target_num_workers(), suffix,
self.load_metrics.debug_string())
def typename(v):
@ -507,9 +508,8 @@ def check_extraneous(config, schema):
raise ValueError("Config {} is not a dictionary".format(config))
for k in config:
if k not in schema:
raise ValueError(
"Unexpected config key `{}` not in {}".format(
k, list(schema.keys())))
raise ValueError("Unexpected config key `{}` not in {}".format(
k, list(schema.keys())))
v, kreq = schema[k]
if v is None:
continue
@ -517,7 +517,8 @@ def check_extraneous(config, schema):
if not isinstance(config[k], v):
raise ValueError(
"Config key `{}` has wrong type {}, expected {}".format(
k, type(config[k]).__name__, v.__name__))
k,
type(config[k]).__name__, v.__name__))
else:
check_extraneous(config[k], v)
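
check_extraneous recurses into nested schema dicts, rejecting unknown keys and mistyped values; a minimal sketch of how a typo would surface (config values hypothetical):

bad_config = {
    "provider": {
        "type": "aws",
        "regoin": "us-east-1",  # typo: the schema only knows "region"
    },
}
check_extraneous(bad_config, CLUSTER_CONFIG_SCHEMA)
# raises ValueError: Unexpected config key `regoin` not in [...]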


@ -25,12 +25,10 @@ assert StrictVersion(boto3.__version__) >= StrictVersion("1.4.8"), \
def key_pair(i, region):
"""Returns the ith default (aws_key_pair_name, key_pair_path)."""
if i == 0:
return (
"{}_{}".format(RAY, region),
os.path.expanduser("~/.ssh/{}_{}.pem".format(RAY, region)))
return (
"{}_{}_{}".format(RAY, i, region),
os.path.expanduser("~/.ssh/{}_{}_{}.pem".format(RAY, i, region)))
return ("{}_{}".format(RAY, region),
os.path.expanduser("~/.ssh/{}_{}.pem".format(RAY, region)))
return ("{}_{}_{}".format(RAY, i, region),
os.path.expanduser("~/.ssh/{}_{}_{}.pem".format(RAY, i, region)))
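
key_pair derives deterministic names from the index and region; a worked example (assuming the module-level constant RAY == "ray"):

key_pair(0, "us-east-1")
# -> ("ray_us-east-1", "/home/<user>/.ssh/ray_us-east-1.pem")
key_pair(2, "us-east-1")
# -> ("ray_2_us-east-1", "/home/<user>/.ssh/ray_2_us-east-1.pem")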
# Suppress excessive connection dropped logs from boto
@ -83,7 +81,9 @@ def _configure_iam_role(config):
"Statement": [
{
"Effect": "Allow",
"Principal": {"Service": "ec2.amazonaws.com"},
"Principal": {
"Service": "ec2.amazonaws.com"
},
"Action": "sts:AssumeRole",
},
],
@ -97,8 +97,7 @@ def _configure_iam_role(config):
profile.add_role(RoleName=role.name)
time.sleep(15) # wait for propagation
print("Role not specified for head node, using {}".format(
profile.arn))
print("Role not specified for head node, using {}".format(profile.arn))
config["head_node"]["IamInstanceProfile"] = {"Arn": profile.arn}
return config
@ -146,8 +145,10 @@ def _configure_key_pair(config):
def _configure_subnet(config):
ec2 = _resource("ec2", config)
subnets = sorted(
[s for s in ec2.subnets.all()
if s.state == "available" and s.map_public_ip_on_launch],
[
s for s in ec2.subnets.all()
if s.state == "available" and s.map_public_ip_on_launch
],
reverse=True, # sort from Z-A
key=lambda subnet: subnet.availability_zone)
if not subnets:
@ -157,9 +158,9 @@ def _configure_subnet(config):
"and trying this again. Note that the subnet must map public IPs "
"on instance launch.")
if "availability_zone" in config["provider"]:
default_subnet = next((s for s in subnets
if s.availability_zone ==
config["provider"]["availability_zone"]),
default_subnet = next((
s for s in subnets
if s.availability_zone == config["provider"]["availability_zone"]),
None)
if not default_subnet:
raise Exception(
@ -209,11 +210,21 @@ def _configure_security_group(config):
if not security_group.ip_permissions:
security_group.authorize_ingress(
IpPermissions=[
{"FromPort": -1, "ToPort": -1, "IpProtocol": "-1",
"UserIdGroupPairs": [{"GroupId": security_group.id}]},
{"FromPort": 22, "ToPort": 22, "IpProtocol": "TCP",
"IpRanges": [{"CidrIp": "0.0.0.0/0"}]}])
IpPermissions=[{
"FromPort": -1,
"ToPort": -1,
"IpProtocol": "-1",
"UserIdGroupPairs": [{
"GroupId": security_group.id
}]
}, {
"FromPort": 22,
"ToPort": 22,
"IpProtocol": "TCP",
"IpRanges": [{
"CidrIp": "0.0.0.0/0"
}]
}])
if "SecurityGroupIds" not in config["head_node"]:
print("SecurityGroupIds not specified for head node, using {}".format(
@ -231,8 +242,10 @@ def _configure_security_group(config):
def _get_subnet_or_die(config, subnet_id):
ec2 = _resource("ec2", config)
subnet = list(
ec2.subnets.filter(Filters=[
{"Name": "subnet-id", "Values": [subnet_id]}]))
ec2.subnets.filter(Filters=[{
"Name": "subnet-id",
"Values": [subnet_id]
}]))
assert len(subnet) == 1, "Subnet not found"
subnet = subnet[0]
return subnet
@ -241,8 +254,10 @@ def _get_subnet_or_die(config, subnet_id):
def _get_security_group(config, vpc_id, group_name):
ec2 = _resource("ec2", config)
existing_groups = list(
ec2.security_groups.filter(Filters=[
{"Name": "vpc-id", "Values": [vpc_id]}]))
ec2.security_groups.filter(Filters=[{
"Name": "vpc-id",
"Values": [vpc_id]
}]))
for sg in existing_groups:
if sg.group_name == group_name:
return sg
@ -270,8 +285,10 @@ def _get_instance_profile(profile_name, config):
def _get_key(key_name, config):
ec2 = _resource("ec2", config)
for key in ec2.key_pairs.filter(
Filters=[{"Name": "key-name", "Values": [key_name]}]):
for key in ec2.key_pairs.filter(Filters=[{
"Name": "key-name",
"Values": [key_name]
}]):
if key.name == key_name:
return key


@ -84,7 +84,8 @@ class AWSNodeProvider(NodeProvider):
tag_pairs = []
for k, v in tags.items():
tag_pairs.append({
"Key": k, "Value": v,
"Key": k,
"Value": v,
})
node.create_tags(Tags=tag_pairs)
@ -95,20 +96,20 @@ class AWSNodeProvider(NodeProvider):
"Value": self.cluster_name,
}]
for k, v in tags.items():
tag_pairs.append(
{
"Key": k,
"Value": v,
})
tag_pairs.append({
"Key": k,
"Value": v,
})
conf.update({
"MinCount": 1,
"MaxCount": count,
"TagSpecifications": conf.get("TagSpecifications", []) + [
{
"ResourceType": "instance",
"Tags": tag_pairs,
}
]
"MinCount":
1,
"MaxCount":
count,
"TagSpecifications":
conf.get("TagSpecifications", []) + [{
"ResourceType": "instance",
"Tags": tag_pairs,
}]
})
self.ec2.create_instances(**conf)


@ -23,9 +23,8 @@ from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, TAG_RAY_LAUNCH_CONFIG, \
from ray.autoscaler.updater import NodeUpdaterProcess
def create_or_update_cluster(
config_file, override_min_workers, override_max_workers,
no_restart, yes):
def create_or_update_cluster(config_file, override_min_workers,
override_max_workers, no_restart, yes):
"""Create or updates an autoscaling Ray cluster from a config json."""
config = yaml.load(open(config_file).read())
@ -39,8 +38,8 @@ def create_or_update_cluster(
importer = NODE_PROVIDERS.get(config["provider"]["type"])
if not importer:
raise NotImplementedError(
"Unsupported provider {}".format(config["provider"]))
raise NotImplementedError("Unsupported provider {}".format(
config["provider"]))
bootstrap_config, _ = importer()
config = bootstrap_config(config)
@ -129,8 +128,10 @@ def get_or_create_head_node(config, no_restart, yes):
remote_config_file.write(json.dumps(remote_config))
remote_config_file.flush()
config["file_mounts"].update({
remote_key_path: config["auth"]["ssh_private_key"],
"~/ray_bootstrap_config.yaml": remote_config_file.name
remote_key_path:
config["auth"]["ssh_private_key"],
"~/ray_bootstrap_config.yaml":
remote_config_file.name
})
if no_restart:
@ -160,30 +161,24 @@ def get_or_create_head_node(config, no_restart, yes):
print("Error: updating {} failed".format(
provider.external_ip(head_node)))
sys.exit(1)
print(
"Head node up-to-date, IP address is: {}".format(
provider.external_ip(head_node)))
print("Head node up-to-date, IP address is: {}".format(
provider.external_ip(head_node)))
monitor_str = "tail -f /tmp/raylogs/monitor-*"
for s in init_commands:
if ("ray start" in s and "docker exec" in s and
"--autoscaling-config" in s):
if ("ray start" in s and "docker exec" in s
and "--autoscaling-config" in s):
monitor_str = "docker exec {} /bin/sh -c {}".format(
config["docker"]["container_name"],
quote(monitor_str))
print(
"To monitor auto-scaling activity, you can run:\n\n"
" ssh -i {} {}@{} {}\n".format(
config["auth"]["ssh_private_key"],
config["auth"]["ssh_user"],
provider.external_ip(head_node),
quote(monitor_str)))
print(
"To login to the cluster, run:\n\n"
" ssh -i {} {}@{}\n".format(
config["auth"]["ssh_private_key"],
config["auth"]["ssh_user"],
provider.external_ip(head_node)))
config["docker"]["container_name"], quote(monitor_str))
print("To monitor auto-scaling activity, you can run:\n\n"
" ssh -i {} {}@{} {}\n".format(config["auth"]["ssh_private_key"],
config["auth"]["ssh_user"],
provider.external_ip(head_node),
quote(monitor_str)))
print("To login to the cluster, run:\n\n"
" ssh -i {} {}@{}\n".format(config["auth"]["ssh_private_key"],
config["auth"]["ssh_user"],
provider.external_ip(head_node)))
def get_head_node_ip(config_file):


@ -22,24 +22,21 @@ def dockerize_if_needed(config):
assert cname, "Must provide container name!"
docker_mounts = {dst: dst for dst in config["file_mounts"]}
config["setup_commands"] = (
docker_install_cmds() +
docker_start_cmds(
config["auth"]["ssh_user"], docker_image,
docker_mounts, cname) +
with_docker_exec(
config["setup_commands"], container_name=cname))
docker_install_cmds() + docker_start_cmds(
config["auth"]["ssh_user"], docker_image, docker_mounts, cname) +
with_docker_exec(config["setup_commands"], container_name=cname))
config["head_setup_commands"] = with_docker_exec(
config["head_setup_commands"], container_name=cname)
config["head_start_ray_commands"] = (
docker_autoscaler_setup(cname) +
with_docker_exec(
docker_autoscaler_setup(cname) + with_docker_exec(
config["head_start_ray_commands"], container_name=cname))
config["worker_setup_commands"] = with_docker_exec(
config["worker_setup_commands"], container_name=cname)
config["worker_start_ray_commands"] = with_docker_exec(
config["worker_start_ray_commands"], container_name=cname,
config["worker_start_ray_commands"],
container_name=cname,
env_vars=["RAY_HEAD_IP"])
return config
@ -50,21 +47,24 @@ def with_docker_exec(cmds, container_name, env_vars=None):
if env_vars:
env_str = " ".join(
["-e {env}=${env}".format(env=env) for env in env_vars])
return ["docker exec {} {} /bin/sh -c {} ".format(
env_str, container_name, quote(cmd)) for cmd in cmds]
return [
"docker exec {} {} /bin/sh -c {} ".format(env_str, container_name,
quote(cmd)) for cmd in cmds
]
def docker_install_cmds():
return [aptwait_cmd() + " && sudo apt-get update",
aptwait_cmd() + " && sudo apt-get install -y docker.io"]
return [
aptwait_cmd() + " && sudo apt-get update",
aptwait_cmd() + " && sudo apt-get install -y docker.io"
]
def aptwait_cmd():
return (
"while sudo fuser"
" /var/{lib/{dpkg,apt/lists},cache/apt/archives}/lock"
" >/dev/null 2>&1; "
"do echo 'Waiting for release of dpkg/apt locks'; sleep 5; done")
return ("while sudo fuser"
" /var/{lib/{dpkg,apt/lists},cache/apt/archives}/lock"
" >/dev/null 2>&1; "
"do echo 'Waiting for release of dpkg/apt locks'; sleep 5; done")
def docker_start_cmds(user, image, mount, cname):
@ -77,10 +77,12 @@ def docker_start_cmds(user, image, mount, cname):
# create flags
# ports for the redis, object manager, and tune client
port_flags = " ".join(["-p {port}:{port}".format(port=port)
for port in ["6379", "8076", "4321"]])
mount_flags = " ".join(["-v {src}:{dest}".format(src=k, dest=v)
for k, v in mount.items()])
port_flags = " ".join([
"-p {port}:{port}".format(port=port)
for port in ["6379", "8076", "4321"]
])
mount_flags = " ".join(
["-v {src}:{dest}".format(src=k, dest=v) for k, v in mount.items()])
# for click, used in ray cli
env_vars = {"LC_ALL": "C.UTF-8", "LANG": "C.UTF-8"}
@ -88,9 +90,10 @@ def docker_start_cmds(user, image, mount, cname):
["-e {name}={val}".format(name=k, val=v) for k, v in env_vars.items()])
# docker run command
docker_run = ["docker", "run", "--rm", "--name {}".format(cname),
"-d", "-it", port_flags, mount_flags, env_flags,
"--net=host", image, "bash"]
docker_run = [
"docker", "run", "--rm", "--name {}".format(cname), "-d", "-it",
port_flags, mount_flags, env_flags, "--net=host", image, "bash"
]
cmds.append(" ".join(docker_run))
docker_update = []
docker_update.append("apt-get -y update")
@ -107,7 +110,8 @@ def docker_autoscaler_setup(cname):
base_path = os.path.basename(path)
cmds.append("docker cp {path} {cname}:{dpath}".format(
path=path, dpath=base_path, cname=cname))
cmds.extend(with_docker_exec(
["cp {} {}".format("/" + base_path, path)],
container_name=cname))
cmds.extend(
with_docker_exec(
["cp {} {}".format("/" + base_path, path)],
container_name=cname))
return cmds
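
For reference, with_docker_exec (defined above) wraps each command so it runs inside the container; a sketch of its output (container name hypothetical):

with_docker_exec(
    ["ray stop"], container_name="ray_docker", env_vars=["RAY_HEAD_IP"])
# -> ["docker exec -e RAY_HEAD_IP=$RAY_HEAD_IP ray_docker /bin/sh -c 'ray stop' "]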


@ -15,14 +15,15 @@ def import_aws():
def load_aws_config():
import ray.autoscaler.aws as ray_aws
return os.path.join(os.path.dirname(
ray_aws.__file__), "example-full.yaml")
return os.path.join(os.path.dirname(ray_aws.__file__), "example-full.yaml")
def import_external():
"""Mock a normal provider importer."""
def return_it_back(config):
return config
return return_it_back, None
@ -55,8 +56,7 @@ def load_class(path):
class_data = path.split(".")
if len(class_data) < 2:
raise ValueError(
"You need to pass a valid path like mymodule.provider_class"
)
"You need to pass a valid path like mymodule.provider_class")
module_path = ".".join(class_data[:-1])
class_str = class_data[-1]
module = importlib.import_module(module_path)
@ -71,8 +71,8 @@ def get_node_provider(provider_config, cluster_name):
importer = NODE_PROVIDERS.get(provider_config["type"])
if importer is None:
raise NotImplementedError(
"Unsupported node provider: {}".format(provider_config["type"]))
raise NotImplementedError("Unsupported node provider: {}".format(
provider_config["type"]))
_, provider_cls = importer()
return provider_cls(provider_config, cluster_name)
@ -82,8 +82,8 @@ def get_default_config(provider_config):
return {}
load_config = DEFAULT_CONFIGS.get(provider_config["type"])
if load_config is None:
raise NotImplementedError(
"Unsupported node provider: {}".format(provider_config["type"]))
raise NotImplementedError("Unsupported node provider: {}".format(
provider_config["type"]))
path_to_default = load_config()
with open(path_to_default) as f:
defaults = yaml.load(f)
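
load_class (above) resolves a dotted path to a class, which is how the provider config's optional "module" field plugs in an external node provider; a sketch (path hypothetical):

provider_cls = load_class("my_cloud.node_provider.MyNodeProvider")
provider = provider_cls(provider_config, cluster_name)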


@ -1,7 +1,6 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""The Ray autoscaler uses tags to associate metadata with instances."""
# Tag for the name of the node


@ -26,10 +26,16 @@ def pretty_cmd(cmd_str):
class NodeUpdater(object):
"""A process for syncing files and running init commands on a node."""
def __init__(
self, node_id, provider_config, auth_config, cluster_name,
file_mounts, setup_cmds, runtime_hash, redirect_output=True,
process_runner=subprocess):
def __init__(self,
node_id,
provider_config,
auth_config,
cluster_name,
file_mounts,
setup_cmds,
runtime_hash,
redirect_output=True,
process_runner=subprocess):
self.daemon = True
self.process_runner = process_runner
self.provider = get_node_provider(provider_config, cluster_name)
@ -66,13 +72,12 @@ class NodeUpdater(object):
"NodeUpdater: Error updating {}"
"See {} for remote logs.".format(error_str, self.output_name),
file=self.stdout)
self.provider.set_node_tags(
self.node_id, {TAG_RAY_NODE_STATUS: "UpdateFailed"})
self.provider.set_node_tags(self.node_id,
{TAG_RAY_NODE_STATUS: "UpdateFailed"})
if self.logfile is not None:
print(
"----- BEGIN REMOTE LOGS -----\n" +
open(self.logfile.name).read() +
"\n----- END REMOTE LOGS -----")
print("----- BEGIN REMOTE LOGS -----\n" + open(
self.logfile.name).read() + "\n----- END REMOTE LOGS -----"
)
raise e
self.provider.set_node_tags(
self.node_id, {
@ -85,8 +90,8 @@ class NodeUpdater(object):
file=self.stdout)
def do_update(self):
self.provider.set_node_tags(
self.node_id, {TAG_RAY_NODE_STATUS: "WaitingForSSH"})
self.provider.set_node_tags(self.node_id,
{TAG_RAY_NODE_STATUS: "WaitingForSSH"})
deadline = time.time() + NODE_START_WAIT_S
# Wait for external IP
@ -114,7 +119,8 @@ class NodeUpdater(object):
raise Exception("Node not running yet...")
self.ssh_cmd(
"uptime",
connect_timeout=5, redirect=open("/dev/null", "w"))
connect_timeout=5,
redirect=open("/dev/null", "w"))
ssh_ok = True
except Exception as e:
retry_str = str(e)
@ -130,8 +136,8 @@ class NodeUpdater(object):
assert ssh_ok, "Unable to SSH to node"
# Rsync file mounts
self.provider.set_node_tags(
self.node_id, {TAG_RAY_NODE_STATUS: "SyncingFiles"})
self.provider.set_node_tags(self.node_id,
{TAG_RAY_NODE_STATUS: "SyncingFiles"})
for remote_path, local_path in self.file_mounts.items():
print(
"NodeUpdater: Syncing {} to {}...".format(
@ -143,18 +149,20 @@ class NodeUpdater(object):
local_path += "/"
if not remote_path.endswith("/"):
remote_path += "/"
self.ssh_cmd(
"mkdir -p {}".format(os.path.dirname(remote_path)))
self.process_runner.check_call([
"rsync", "-e", "ssh -i {} ".format(self.ssh_private_key) +
"-o ConnectTimeout=120s -o StrictHostKeyChecking=no",
"--delete", "-avz", "{}".format(local_path),
"{}@{}:{}".format(self.ssh_user, self.ssh_ip, remote_path)
], stdout=self.stdout, stderr=self.stderr)
self.ssh_cmd("mkdir -p {}".format(os.path.dirname(remote_path)))
self.process_runner.check_call(
[
"rsync", "-e", "ssh -i {} ".format(self.ssh_private_key) +
"-o ConnectTimeout=120s -o StrictHostKeyChecking=no",
"--delete", "-avz", "{}".format(local_path),
"{}@{}:{}".format(self.ssh_user, self.ssh_ip, remote_path)
],
stdout=self.stdout,
stderr=self.stderr)
# Run init commands
self.provider.set_node_tags(
self.node_id, {TAG_RAY_NODE_STATUS: "SettingUp"})
self.provider.set_node_tags(self.node_id,
{TAG_RAY_NODE_STATUS: "SettingUp"})
for cmd in self.setup_cmds:
self.ssh_cmd(cmd, verbose=True)
@ -165,13 +173,16 @@ class NodeUpdater(object):
pretty_cmd(cmd), self.ssh_ip),
file=self.stdout)
force_interactive = "set -i && source ~/.bashrc && "
self.process_runner.check_call([
"ssh", "-o", "ConnectTimeout={}s".format(connect_timeout),
"-o", "StrictHostKeyChecking=no",
"-i", self.ssh_private_key,
"{}@{}".format(self.ssh_user, self.ssh_ip),
"bash --login -c {}".format(pipes.quote(force_interactive + cmd))
], stdout=redirect or self.stdout, stderr=redirect or self.stderr)
self.process_runner.check_call(
[
"ssh", "-o", "ConnectTimeout={}s".format(connect_timeout),
"-o", "StrictHostKeyChecking=no",
"-i", self.ssh_private_key, "{}@{}".format(
self.ssh_user, self.ssh_ip), "bash --login -c {}".format(
pipes.quote(force_interactive + cmd))
],
stdout=redirect or self.stdout,
stderr=redirect or self.stderr)
class NodeUpdaterProcess(NodeUpdater, Process):
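
NodeUpdaterProcess runs the update in a separate process; a sketch of its construction, mirroring the self.node_updater_cls(...) call in StandardAutoscaler above (arguments abbreviated):

updater = NodeUpdaterProcess(
    node_id,
    config["provider"],
    config["auth"],
    config["cluster_name"],
    config["file_mounts"],
    setup_cmds,
    runtime_hash,
    redirect_output=True)
updater.start()  # multiprocessing.Process.start(), which invokes run()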


@ -25,7 +25,7 @@ OBJECT_CHANNEL_PREFIX = "OC:"
def integerToAsciiHex(num, numbytes):
retstr = b""
# Support 32 and 64 bit architecture.
assert(numbytes == 4 or numbytes == 8)
assert (numbytes == 4 or numbytes == 8)
for i in range(numbytes):
curbyte = num & 0xff
if sys.version_info >= (3, 0):
@ -50,7 +50,6 @@ def get_next_message(pubsub_client, timeout_seconds=10):
class TestGlobalStateStore(unittest.TestCase):
def setUp(self):
redis_port, _ = ray.services.start_redis_instance()
self.redis = redis.StrictRedis(host="localhost", port=redis_port, db=0)
@ -192,16 +191,16 @@ class TestGlobalStateStore(unittest.TestCase):
# notifications.
def check_object_notification(notification_message, object_id,
object_size, manager_ids):
notification_object = (SubscribeToNotificationsReply
.GetRootAsSubscribeToNotificationsReply(
notification_object = (SubscribeToNotificationsReply.
GetRootAsSubscribeToNotificationsReply(
notification_message, 0))
self.assertEqual(notification_object.ObjectId(), object_id)
self.assertEqual(notification_object.ObjectSize(), object_size)
self.assertEqual(notification_object.ManagerIdsLength(),
len(manager_ids))
for i in range(len(manager_ids)):
self.assertEqual(notification_object.ManagerIds(i),
manager_ids[i])
self.assertEqual(
notification_object.ManagerIds(i), manager_ids[i])
data_size = 0xf1f0
p = self.redis.pubsub()
@ -215,10 +214,9 @@ class TestGlobalStateStore(unittest.TestCase):
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
"manager_id1", "object_id1")
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id1",
data_size,
[b"manager_id2"])
check_object_notification(
get_next_message(p)["data"], b"object_id1", data_size,
[b"manager_id2"])
# Request a notification for an object that isn't there. Then add the
# object and receive the data. Only the first call to
@ -232,26 +230,22 @@ class TestGlobalStateStore(unittest.TestCase):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3",
data_size, "hash1", "manager_id3")
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id3",
data_size,
[b"manager_id1"])
check_object_notification(
get_next_message(p)["data"], b"object_id3", data_size,
[b"manager_id1"])
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2",
data_size, "hash1", "manager_id3")
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id2",
data_size,
[b"manager_id3"])
check_object_notification(
get_next_message(p)["data"], b"object_id2", data_size,
[b"manager_id3"])
# Request notifications for object_id3 again.
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
"manager_id1", "object_id3")
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id3",
data_size,
[b"manager_id1", b"manager_id2",
b"manager_id3"])
check_object_notification(
get_next_message(p)["data"], b"object_id3", data_size,
[b"manager_id1", b"manager_id2", b"manager_id3"])
def testResultTableAddAndLookup(self):
def check_result_table_entry(message, task_id, is_put):
@ -349,8 +343,7 @@ class TestGlobalStateStore(unittest.TestCase):
# update happens, and the response is still the same task.
task_args = [task_args[0]] + task_args
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*task_args[:3])
"task_id", *task_args[:3])
check_task_reply(response, task_args[1:], updated=True)
# Check that the task entry is still the same.
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET",
@ -362,8 +355,7 @@ class TestGlobalStateStore(unittest.TestCase):
# task.
task_args[1] = TASK_STATUS_QUEUED
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*task_args[:3])
"task_id", *task_args[:3])
check_task_reply(response, task_args[1:], updated=True)
# Check that the update happened.
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET",
@ -375,8 +367,7 @@ class TestGlobalStateStore(unittest.TestCase):
new_task_args = task_args[:]
new_task_args[1] = TASK_STATUS_WAITING
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*new_task_args[:3])
"task_id", *new_task_args[:3])
check_task_reply(response, task_args[1:], updated=False)
# Check that the update did not happen.
get_response2 = self.redis.execute_command("RAY.TASK_TABLE_GET",
@ -388,8 +379,7 @@ class TestGlobalStateStore(unittest.TestCase):
task_args = new_task_args
task_args[0] = TASK_STATUS_SCHEDULED | TASK_STATUS_QUEUED
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*task_args[:3])
"task_id", *task_args[:3])
check_task_reply(response, task_args[1:], updated=True)
# If the test value is a bitmask that does not match the current value,
@ -399,8 +389,7 @@ class TestGlobalStateStore(unittest.TestCase):
new_task_args[0] = TASK_STATUS_SCHEDULED
old_response = response
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*new_task_args[:3])
"task_id", *new_task_args[:3])
check_task_reply(response, task_args[1:], updated=False)
# Check that the update did not happen.
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET",
@ -409,8 +398,10 @@ class TestGlobalStateStore(unittest.TestCase):
check_task_reply(get_response, task_args[1:])
def check_task_subscription(self, p, scheduling_state, local_scheduler_id):
task_args = [b"task_id", scheduling_state,
local_scheduler_id.encode("ascii"), b"", 0, b"task_spec"]
task_args = [
b"task_id", scheduling_state,
local_scheduler_id.encode("ascii"), b"", 0, b"task_spec"
]
self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args)
# Receive the data.
message = get_next_message(p)["data"]
@ -418,8 +409,7 @@ class TestGlobalStateStore(unittest.TestCase):
notification_object = TaskReply.GetRootAsTaskReply(message, 0)
self.assertEqual(notification_object.TaskId(), task_args[0])
self.assertEqual(notification_object.State(), task_args[1])
self.assertEqual(notification_object.LocalSchedulerId(),
task_args[2])
self.assertEqual(notification_object.LocalSchedulerId(), task_args[2])
self.assertEqual(notification_object.ExecutionDependencies(),
task_args[3])
self.assertEqual(notification_object.TaskSpec(), task_args[-1])


@ -30,19 +30,23 @@ def random_task_id():
BASE_SIMPLE_OBJECTS = [
0, 1, 100000, 0.0, 0.5, 0.9, 100000.1, (), [], {}, "", 990 * "h", u"",
990 * u"h"]
990 * u"h"
]
if sys.version_info < (3, 0):
BASE_SIMPLE_OBJECTS += [long(0), long(1), long(100000), long(1 << 100)] # noqa: E501,F821
BASE_SIMPLE_OBJECTS += [
long(0), # noqa: E501,F821
long(1), # noqa: E501,F821
long(100000), # noqa: E501,F821
long(1 << 100) # noqa: E501,F821
]
LIST_SIMPLE_OBJECTS = [[obj] for obj in BASE_SIMPLE_OBJECTS]
TUPLE_SIMPLE_OBJECTS = [(obj,) for obj in BASE_SIMPLE_OBJECTS]
TUPLE_SIMPLE_OBJECTS = [(obj, ) for obj in BASE_SIMPLE_OBJECTS]
DICT_SIMPLE_OBJECTS = [{(): obj} for obj in BASE_SIMPLE_OBJECTS]
SIMPLE_OBJECTS = (BASE_SIMPLE_OBJECTS +
LIST_SIMPLE_OBJECTS +
TUPLE_SIMPLE_OBJECTS +
DICT_SIMPLE_OBJECTS)
SIMPLE_OBJECTS = (BASE_SIMPLE_OBJECTS + LIST_SIMPLE_OBJECTS +
TUPLE_SIMPLE_OBJECTS + DICT_SIMPLE_OBJECTS)
# Create some complex objects that cannot be serialized by value in tasks.
@ -55,21 +59,20 @@ class Foo(object):
pass
BASE_COMPLEX_OBJECTS = [999 * "h", 999 * u"h", lst, Foo(),
10 * [10 * [10 * [1]]]]
BASE_COMPLEX_OBJECTS = [
999 * "h", 999 * u"h", lst,
Foo(), 10 * [10 * [10 * [1]]]
]
LIST_COMPLEX_OBJECTS = [[obj] for obj in BASE_COMPLEX_OBJECTS]
TUPLE_COMPLEX_OBJECTS = [(obj,) for obj in BASE_COMPLEX_OBJECTS]
TUPLE_COMPLEX_OBJECTS = [(obj, ) for obj in BASE_COMPLEX_OBJECTS]
DICT_COMPLEX_OBJECTS = [{(): obj} for obj in BASE_COMPLEX_OBJECTS]
COMPLEX_OBJECTS = (BASE_COMPLEX_OBJECTS +
LIST_COMPLEX_OBJECTS +
TUPLE_COMPLEX_OBJECTS +
DICT_COMPLEX_OBJECTS)
COMPLEX_OBJECTS = (BASE_COMPLEX_OBJECTS + LIST_COMPLEX_OBJECTS +
TUPLE_COMPLEX_OBJECTS + DICT_COMPLEX_OBJECTS)
class TestSerialization(unittest.TestCase):
def test_serialize_by_value(self):
for val in SIMPLE_OBJECTS:
@ -79,7 +82,6 @@ class TestSerialization(unittest.TestCase):
class TestObjectID(unittest.TestCase):
def test_create_object_id(self):
random_object_id()
@ -95,6 +97,7 @@ class TestObjectID(unittest.TestCase):
def h():
object_ids[0]
return 1
# Make sure that object IDs cannot be pickled (including functions that
# close over object IDs).
self.assertRaises(Exception, lambda: pickle.dumps(object_ids[0]))
@ -113,10 +116,12 @@ class TestObjectID(unittest.TestCase):
self.assertNotEqual(x1, y1)
random_strings = [np.random.bytes(ID_SIZE) for _ in range(256)]
object_ids1 = [local_scheduler.ObjectID(random_strings[i])
for i in range(256)]
object_ids2 = [local_scheduler.ObjectID(random_strings[i])
for i in range(256)]
object_ids1 = [
local_scheduler.ObjectID(random_strings[i]) for i in range(256)
]
object_ids2 = [
local_scheduler.ObjectID(random_strings[i]) for i in range(256)
]
self.assertEqual(len(set(object_ids1)), 256)
self.assertEqual(len(set(object_ids1 + object_ids2)), 256)
self.assertEqual(set(object_ids1), set(object_ids2))
@ -129,7 +134,6 @@ class TestObjectID(unittest.TestCase):
class TestTask(unittest.TestCase):
def check_task(self, task, function_id, num_return_vals, args):
self.assertEqual(function_id.id(), task.function_id().id())
retrieved_args = task.arguments()
@ -148,31 +152,17 @@ class TestTask(unittest.TestCase):
parent_id = random_task_id()
function_id = random_function_id()
object_ids = [random_object_id() for _ in range(256)]
args_list = [
[],
1 * [1],
10 * [1],
100 * [1],
1000 * [1],
1 * ["a"],
10 * ["a"],
100 * ["a"],
1000 * ["a"],
[1, 1.3, 2, 1 << 100, "hi", u"hi", [1, 2]],
object_ids[:1],
object_ids[:2],
object_ids[:3],
object_ids[:4],
object_ids[:5],
object_ids[:10],
object_ids[:100],
object_ids[:256],
[1, object_ids[0]],
[object_ids[0], "a"],
[1, object_ids[0], "a"],
[object_ids[0], 1, object_ids[1], "a"],
object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
object_ids + 100 * ["a"] + object_ids]
args_list = [[], 1 * [1], 10 * [1], 100 * [1], 1000 * [1], 1 * ["a"],
10 * ["a"], 100 * ["a"], 1000 * ["a"], [
1, 1.3, 2, 1 << 100, "hi", u"hi", [1, 2]
], object_ids[:1], object_ids[:2], object_ids[:3],
object_ids[:4], object_ids[:5], object_ids[:10],
object_ids[:100], object_ids[:256], [1, object_ids[0]], [
object_ids[0], "a"
], [1, object_ids[0], "a"], [
object_ids[0], 1, object_ids[1], "a"
], object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
object_ids + 100 * ["a"] + object_ids]
for args in args_list:
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
task = local_scheduler.Task(driver_id, function_id, args,

View file

@ -5,5 +5,7 @@ from __future__ import print_function
from .tfutils import TensorFlowVariables
from .features import flush_redis_unsafe, flush_task_and_object_metadata_unsafe
__all__ = ["TensorFlowVariables", "flush_redis_unsafe",
"flush_task_and_object_metadata_unsafe"]
__all__ = [
"TensorFlowVariables", "flush_redis_unsafe",
"flush_task_and_object_metadata_unsafe"
]
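
This pattern — an __all__ list reflowed so the elements sit on an indented block with a dedented closing bracket — recurs throughout the commit. The reflow is content-preserving, as a quick check illustrates:

old = ["TensorFlowVariables", "flush_redis_unsafe",
       "flush_task_and_object_metadata_unsafe"]
new = [
    "TensorFlowVariables", "flush_redis_unsafe",
    "flush_task_and_object_metadata_unsafe"
]
assert old == new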

View file

@ -8,6 +8,8 @@ from .core import (BLOCK_SIZE, DistArray, assemble, zeros, ones, copy, eye,
triu, tril, blockwise_dot, dot, transpose, add, subtract,
numpy_to_dist, subblocks)
__all__ = ["random", "linalg", "BLOCK_SIZE", "DistArray", "assemble", "zeros",
"ones", "copy", "eye", "triu", "tril", "blockwise_dot", "dot",
"transpose", "add", "subtract", "numpy_to_dist", "subblocks"]
__all__ = [
"random", "linalg", "BLOCK_SIZE", "DistArray", "assemble", "zeros", "ones",
"copy", "eye", "triu", "tril", "blockwise_dot", "dot", "transpose", "add",
"subtract", "numpy_to_dist", "subblocks"
]

View file

@ -13,8 +13,9 @@ class DistArray(object):
def __init__(self, shape, objectids=None):
self.shape = shape
self.ndim = len(shape)
self.num_blocks = [int(np.ceil(1.0 * a / BLOCK_SIZE))
for a in self.shape]
self.num_blocks = [
int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in self.shape
]
if objectids is not None:
self.objectids = objectids
else:
@ -56,7 +57,7 @@ class DistArray(object):
def assemble(self):
"""Assemble an array from a distributed array of object IDs."""
first_block = ray.get(self.objectids[(0,) * self.ndim])
first_block = ray.get(self.objectids[(0, ) * self.ndim])
dtype = first_block.dtype
result = np.zeros(self.shape, dtype=dtype)
for index in np.ndindex(*self.num_blocks):
@ -85,8 +86,8 @@ def numpy_to_dist(a):
for index in np.ndindex(*result.num_blocks):
lower = DistArray.compute_block_lower(index, a.shape)
upper = DistArray.compute_block_upper(index, a.shape)
result.objectids[index] = ray.put(a[[slice(l, u) for (l, u)
in zip(lower, upper)]])
result.objectids[index] = ray.put(
a[[slice(l, u) for (l, u) in zip(lower, upper)]])
return result
@ -126,12 +127,11 @@ def eye(dim1, dim2=-1, dtype_name="float"):
for (i, j) in np.ndindex(*result.num_blocks):
block_shape = DistArray.compute_block_shape([i, j], shape)
if i == j:
result.objectids[i, j] = ra.eye.remote(block_shape[0],
block_shape[1],
dtype_name=dtype_name)
result.objectids[i, j] = ra.eye.remote(
block_shape[0], block_shape[1], dtype_name=dtype_name)
else:
result.objectids[i, j] = ra.zeros.remote(block_shape,
dtype_name=dtype_name)
result.objectids[i, j] = ra.zeros.remote(
block_shape, dtype_name=dtype_name)
return result
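
The two rewrites above show the split yapf prefers when a call does not fit on one line: break immediately after the opening parenthesis and indent the whole argument list, rather than aligning arguments under the first one. Both spellings are the same call, as a minimal sketch (with a toy eye function standing in for ra.eye.remote) shows:

def eye(rows, cols, dtype_name="float"):
    # Toy stand-in that just echoes its arguments.
    return (rows, cols, dtype_name)

aligned = eye(3,
              4,
              dtype_name="float")
hanging = eye(
    3, 4, dtype_name="float")
assert aligned == hanging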
@ -190,8 +190,8 @@ def dot(a, b):
"b.ndim = {}.".format(b.ndim))
if a.shape[1] != b.shape[0]:
raise Exception("dot expects a.shape[1] to equal b.shape[0], but "
"a.shape = {} and b.shape = {}.".format(a.shape,
b.shape))
"a.shape = {} and b.shape = {}.".format(
a.shape, b.shape))
shape = [a.shape[0], b.shape[1]]
result = DistArray(shape)
for (i, j) in np.ndindex(*result.num_blocks):
@ -227,8 +227,8 @@ def subblocks(a, *ranges):
"the {}th range is {}.".format(i, ranges[i]))
if ranges[i][0] < 0:
raise Exception("Values in the ranges passed to sub_blocks must "
"be at least 0, but the {}th range is {}."
.format(i, ranges[i]))
"be at least 0, but the {}th range is {}.".format(
i, ranges[i]))
if ranges[i][-1] >= a.num_blocks[i]:
raise Exception("Values in the ranges passed to sub_blocks must "
"be less than the relevant number of blocks, but "
@ -240,8 +240,8 @@ def subblocks(a, *ranges):
for i in range(a.ndim)]
result = DistArray(shape)
for index in np.ndindex(*result.num_blocks):
result.objectids[index] = a.objectids[tuple([ranges[i][index[i]]
for i in range(a.ndim)])]
result.objectids[index] = a.objectids[tuple(
[ranges[i][index[i]] for i in range(a.ndim)])]
return result
@ -249,8 +249,8 @@ def subblocks(a, *ranges):
def transpose(a):
if a.ndim != 2:
raise Exception("transpose expects its argument to be 2-dimensional, "
"but a.ndim = {}, a.shape = {}.".format(a.ndim,
a.shape))
"but a.ndim = {}, a.shape = {}.".format(
a.ndim, a.shape))
result = DistArray([a.shape[1], a.shape[0]])
for i in range(result.num_blocks[0]):
for j in range(result.num_blocks[1]):
@ -263,8 +263,8 @@ def transpose(a):
def add(x1, x2):
if x1.shape != x2.shape:
raise Exception("add expects arguments `x1` and `x2` to have the same "
"shape, but x1.shape = {}, and x2.shape = {}."
.format(x1.shape, x2.shape))
"shape, but x1.shape = {}, and x2.shape = {}.".format(
x1.shape, x2.shape))
result = DistArray(x1.shape)
for index in np.ndindex(*result.num_blocks):
result.objectids[index] = ra.add.remote(x1.objectids[index],

View file

@ -76,9 +76,10 @@ def tsqr(a):
lower = [a.shape[1], 0]
upper = [2 * a.shape[1], core.BLOCK_SIZE]
ith_index //= 2
q_block_current = ra.dot.remote(
q_block_current, ra.subarray.remote(q_tree[ith_index, j],
lower, upper))
q_block_current = ra.dot.remote(q_block_current,
ra.subarray.remote(
q_tree[ith_index, j], lower,
upper))
q_result.objectids[i] = q_block_current
r = current_rs[0]
return q_result, ray.get(r)
@ -196,8 +197,8 @@ def qr(a):
if a.shape[0] > a.shape[1]:
# in this case, R needs to be square
R_shape = ray.get(ra.shape.remote(R))
eye_temp = ra.eye.remote(R_shape[1], R_shape[0],
dtype_name=result_dtype)
eye_temp = ra.eye.remote(
R_shape[1], R_shape[0], dtype_name=result_dtype)
r_res.objectids[i, i] = ra.dot.remote(eye_temp, R)
else:
r_res.objectids[i, i] = R
@ -220,10 +221,11 @@ def qr(a):
for i in range(len(Ts))[::-1]:
y_col_block = core.subblocks.remote(y_res, [], [i])
q = core.subtract.remote(
q, core.dot.remote(
y_col_block,
core.dot.remote(
Ts[i],
core.dot.remote(core.transpose.remote(y_col_block), q))))
q,
core.dot.remote(y_col_block,
core.dot.remote(
Ts[i],
core.dot.remote(
core.transpose.remote(y_col_block), q))))
return ray.get(q), r_res

View file

@ -8,6 +8,8 @@ from .core import (zeros, zeros_like, ones, eye, dot, vstack, hstack, subarray,
copy, tril, triu, diag, transpose, add, subtract, sum,
shape, sum_list)
__all__ = ["random", "linalg", "zeros", "zeros_like", "ones", "eye", "dot",
"vstack", "hstack", "subarray", "copy", "tril", "triu", "diag",
"transpose", "add", "subtract", "sum", "shape", "sum_list"]
__all__ = [
"random", "linalg", "zeros", "zeros_like", "ones", "eye", "dot", "vstack",
"hstack", "subarray", "copy", "tril", "triu", "diag", "transpose", "add",
"subtract", "sum", "shape", "sum_list"
]

View file

@ -5,10 +5,11 @@ from __future__ import print_function
import numpy as np
import ray
__all__ = ["matrix_power", "solve", "tensorsolve", "tensorinv", "inv",
"cholesky", "eigvals", "eigvalsh", "pinv", "slogdet", "det",
"svd", "eig", "eigh", "lstsq", "norm", "qr", "cond", "matrix_rank",
"multi_dot"]
__all__ = [
"matrix_power", "solve", "tensorsolve", "tensorinv", "inv", "cholesky",
"eigvals", "eigvalsh", "pinv", "slogdet", "det", "svd", "eig", "eigh",
"lstsq", "norm", "qr", "cond", "matrix_rank", "multi_dot"
]
@ray.remote

View file

@ -69,14 +69,14 @@ def flush_task_and_object_metadata_unsafe():
for key in redis_client.scan_iter(match=OBJECT_INFO_PREFIX + b"*"):
num_object_keys_deleted += redis_client.delete(key)
print("Deleted {} object info keys from Redis.".format(
num_object_keys_deleted))
num_object_keys_deleted))
# Flush the object locations.
num_object_location_keys_deleted = 0
for key in redis_client.scan_iter(match=OBJECT_LOCATION_PREFIX + b"*"):
num_object_location_keys_deleted += redis_client.delete(key)
print("Deleted {} object location keys from Redis.".format(
num_object_location_keys_deleted))
num_object_location_keys_deleted))
# Loop over the shards and flush all of them.
for redis_client in ray.worker.global_state.redis_clients:

View file

@ -59,6 +59,7 @@ class GlobalState(object):
Attributes:
redis_client: The redis client used to query the redis server.
"""
def __init__(self):
"""Create a GlobalState object."""
# The redis server storing metadata, such as function table, client
@ -82,7 +83,9 @@ class GlobalState(object):
raise Exception("The ray.global_state API cannot be used before "
"ray.init has been called.")
def _initialize_global_state(self, redis_ip_address, redis_port,
def _initialize_global_state(self,
redis_ip_address,
redis_port,
timeout=20):
"""Initialize the GlobalState object by connecting to Redis.
@ -97,8 +100,8 @@ class GlobalState(object):
timeout: The maximum amount of time (in seconds) that we should
wait for the keys in Redis to be populated.
"""
self.redis_client = redis.StrictRedis(host=redis_ip_address,
port=redis_port)
self.redis_client = redis.StrictRedis(
host=redis_ip_address, port=redis_port)
start_time = time.time()
@ -118,8 +121,8 @@ class GlobalState(object):
"{}.".format(num_redis_shards))
# Attempt to get all of the Redis shards.
ip_address_ports = self.redis_client.lrange("RedisShards", start=0,
end=-1)
ip_address_ports = self.redis_client.lrange(
"RedisShards", start=0, end=-1)
if len(ip_address_ports) != num_redis_shards:
print("Waiting longer for RedisShards to be populated.")
time.sleep(1)
@ -132,15 +135,15 @@ class GlobalState(object):
if time.time() - start_time >= timeout:
raise Exception("Timed out while attempting to initialize the "
"global state. num_redis_shards = {}, "
"ip_address_ports = {}"
.format(num_redis_shards, ip_address_ports))
"ip_address_ports = {}".format(
num_redis_shards, ip_address_ports))
# Get the rest of the information.
self.redis_clients = []
for ip_address_port in ip_address_ports:
shard_address, shard_port = ip_address_port.split(b":")
self.redis_clients.append(redis.StrictRedis(host=shard_address,
port=shard_port))
self.redis_clients.append(
redis.StrictRedis(host=shard_address, port=shard_port))
def _execute_command(self, key, *args):
"""Execute a Redis command on the appropriate Redis shard based on key.
@ -152,8 +155,8 @@ class GlobalState(object):
Returns:
The value returned by the Redis command.
"""
client = self.redis_clients[key.redis_shard_hash() %
len(self.redis_clients)]
client = self.redis_clients[key.redis_shard_hash() % len(
self.redis_clients)]
return client.execute_command(*args)
def _keys(self, pattern):
@ -189,8 +192,9 @@ class GlobalState(object):
"RAY.OBJECT_TABLE_LOOKUP",
object_id.id())
if object_locations is not None:
manager_ids = [binary_to_hex(manager_id)
for manager_id in object_locations]
manager_ids = [
binary_to_hex(manager_id) for manager_id in object_locations
]
else:
manager_ids = None
@ -199,11 +203,13 @@ class GlobalState(object):
result_table_message = ResultTableReply.GetRootAsResultTableReply(
result_table_response, 0)
result = {"ManagerIDs": manager_ids,
"TaskID": binary_to_hex(result_table_message.TaskId()),
"IsPut": bool(result_table_message.IsPut()),
"DataSize": result_table_message.DataSize(),
"Hash": binary_to_hex(result_table_message.Hash())}
result = {
"ManagerIDs": manager_ids,
"TaskID": binary_to_hex(result_table_message.TaskId()),
"IsPut": bool(result_table_message.IsPut()),
"DataSize": result_table_message.DataSize(),
"Hash": binary_to_hex(result_table_message.Hash())
}
return result
@ -227,9 +233,10 @@ class GlobalState(object):
object_info_keys = self._keys(OBJECT_INFO_PREFIX + "*")
object_location_keys = self._keys(OBJECT_LOCATION_PREFIX + "*")
object_ids_binary = set(
[key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys] +
[key[len(OBJECT_LOCATION_PREFIX):]
for key in object_location_keys])
[key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys] + [
key[len(OBJECT_LOCATION_PREFIX):]
for key in object_location_keys
])
results = {}
for object_id_binary in object_ids_binary:
results[binary_to_object_id(object_id_binary)] = (
@ -254,26 +261,37 @@ class GlobalState(object):
if task_table_response is None:
raise Exception("There is no entry for task ID {} in the task "
"table.".format(binary_to_hex(task_id.id())))
task_table_message = TaskReply.GetRootAsTaskReply(task_table_response,
0)
task_table_message = TaskReply.GetRootAsTaskReply(
task_table_response, 0)
task_spec = task_table_message.TaskSpec()
task_spec = ray.local_scheduler.task_from_string(task_spec)
task_spec_info = {
"DriverID": binary_to_hex(task_spec.driver_id().id()),
"TaskID": binary_to_hex(task_spec.task_id().id()),
"ParentTaskID": binary_to_hex(task_spec.parent_task_id().id()),
"ParentCounter": task_spec.parent_counter(),
"ActorID": binary_to_hex(task_spec.actor_id().id()),
"DriverID":
binary_to_hex(task_spec.driver_id().id()),
"TaskID":
binary_to_hex(task_spec.task_id().id()),
"ParentTaskID":
binary_to_hex(task_spec.parent_task_id().id()),
"ParentCounter":
task_spec.parent_counter(),
"ActorID":
binary_to_hex(task_spec.actor_id().id()),
"ActorCreationID":
binary_to_hex(task_spec.actor_creation_id().id()),
binary_to_hex(task_spec.actor_creation_id().id()),
"ActorCreationDummyObjectID":
binary_to_hex(task_spec.actor_creation_dummy_object_id().id()),
"ActorCounter": task_spec.actor_counter(),
"FunctionID": binary_to_hex(task_spec.function_id().id()),
"Args": task_spec.arguments(),
"ReturnObjectIDs": task_spec.returns(),
"RequiredResources": task_spec.required_resources()}
binary_to_hex(task_spec.actor_creation_dummy_object_id().id()),
"ActorCounter":
task_spec.actor_counter(),
"FunctionID":
binary_to_hex(task_spec.function_id().id()),
"Args":
task_spec.arguments(),
"ReturnObjectIDs":
task_spec.returns(),
"RequiredResources":
task_spec.required_resources()
}
execution_dependencies_message = (
TaskExecutionDependencies.GetRootAsTaskExecutionDependencies(
@ -282,21 +300,27 @@ class GlobalState(object):
ray.local_scheduler.ObjectID(
execution_dependencies_message.ExecutionDependencies(i))
for i in range(
execution_dependencies_message.ExecutionDependenciesLength())]
execution_dependencies_message.ExecutionDependenciesLength())
]
# TODO(rkn): The return fields ExecutionDependenciesString and
# ExecutionDependencies are redundant, so we should remove
# ExecutionDependencies. However, it is currently used in monitor.py.
return {"State": task_table_message.State(),
"LocalSchedulerID": binary_to_hex(
task_table_message.LocalSchedulerId()),
"ExecutionDependenciesString":
task_table_message.ExecutionDependencies(),
"ExecutionDependencies": execution_dependencies,
"SpillbackCount":
task_table_message.SpillbackCount(),
"TaskSpec": task_spec_info}
return {
"State":
task_table_message.State(),
"LocalSchedulerID":
binary_to_hex(task_table_message.LocalSchedulerId()),
"ExecutionDependenciesString":
task_table_message.ExecutionDependencies(),
"ExecutionDependencies":
execution_dependencies,
"SpillbackCount":
task_table_message.SpillbackCount(),
"TaskSpec":
task_spec_info
}
def task_table(self, task_id=None):
"""Fetch and parse the task table information for one or more task IDs.
@ -337,7 +361,8 @@ class GlobalState(object):
function_info_parsed = {
"DriverID": binary_to_hex(info[b"driver_id"]),
"Module": decode(info[b"module"]),
"Name": decode(info[b"name"])}
"Name": decode(info[b"name"])
}
results[binary_to_hex(info[b"function_id"])] = function_info_parsed
return results
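
In the large TaskSpec and task-table dicts above, each value now sits on its own line beneath its key. The layout is unusual, but it is still an ordinary dict literal; a reduced example:

d1 = {"State": 1, "SpillbackCount": 0}
d2 = {
    "State":
    1,
    "SpillbackCount":
    0,
}
assert d1 == d2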
@ -469,21 +494,17 @@ class GlobalState(object):
if start is None and end is None:
if fwd:
event_list = self.redis_client.zrange(
event_log_set,
**params)
event_log_set, **params)
else:
event_list = self.redis_client.zrevrange(
event_log_set,
**params)
event_log_set, **params)
else:
if fwd:
event_list = self.redis_client.zrangebyscore(
event_log_set,
**params)
event_log_set, **params)
else:
event_list = self.redis_client.zrevrangebyscore(
event_log_set,
**params)
event_log_set, **params)
for (event, score) in event_list:
event_dict = json.loads(event.decode())
@ -503,11 +524,11 @@ class GlobalState(object):
task_info[task_id]["get_task_start"] = event[0]
if event[1] == "ray:get_task" and event[2] == 2:
task_info[task_id]["get_task_end"] = event[0]
if (event[1] == "ray:import_remote_function" and
event[2] == 1):
if (event[1] == "ray:import_remote_function"
and event[2] == 1):
task_info[task_id]["import_remote_start"] = event[0]
if (event[1] == "ray:import_remote_function" and
event[2] == 2):
if (event[1] == "ray:import_remote_function"
and event[2] == 2):
task_info[task_id]["import_remote_end"] = event[0]
if event[1] == "ray:acquire_lock" and event[2] == 1:
task_info[task_id]["acquire_lock_start"] = event[0]
@ -547,7 +568,6 @@ class GlobalState(object):
breakdowns=True,
task_dep=True,
obj_dep=True):
"""Dump task profiling information to a file.
This information can be viewed as a timeline of profiling information
@ -604,72 +624,103 @@ class GlobalState(object):
# modify it in place since we will use the original values later.
total_info = copy.copy(task_table[task_id]["TaskSpec"])
total_info["Args"] = [
oid.hex() if isinstance(oid, ray.local_scheduler.ObjectID)
else oid for oid in task_t_info["TaskSpec"]["Args"]]
oid.hex()
if isinstance(oid, ray.local_scheduler.ObjectID) else oid
for oid in task_t_info["TaskSpec"]["Args"]
]
total_info["ReturnObjectIDs"] = [
oid.hex() for oid
in task_t_info["TaskSpec"]["ReturnObjectIDs"]]
oid.hex() for oid in task_t_info["TaskSpec"]["ReturnObjectIDs"]
]
total_info["LocalSchedulerID"] = task_t_info["LocalSchedulerID"]
total_info["get_arguments"] = (info["get_arguments_end"] -
info["get_arguments_start"])
total_info["execute"] = (info["execute_end"] -
info["execute_start"])
total_info["store_outputs"] = (info["store_outputs_end"] -
info["store_outputs_start"])
total_info["get_arguments"] = (
info["get_arguments_end"] - info["get_arguments_start"])
total_info["execute"] = (
info["execute_end"] - info["execute_start"])
total_info["store_outputs"] = (
info["store_outputs_end"] - info["store_outputs_start"])
total_info["function_name"] = info["function_name"]
total_info["worker_id"] = info["worker_id"]
parent_info = task_info.get(
task_table[task_id]["TaskSpec"]["ParentTaskID"])
task_table[task_id]["TaskSpec"]["ParentTaskID"])
worker = workers[info["worker_id"]]
# The catapult trace format documentation can be found here:
# https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview # noqa: E501
if breakdowns:
if "get_arguments_end" in info:
get_args_trace = {
"cat": "get_arguments",
"pid": "Node " + worker["node_ip_address"],
"tid": info["worker_id"],
"id": task_id,
"ts": micros_rel(info["get_arguments_start"]),
"ph": "X",
"name": info["function_name"] + ":get_arguments",
"args": total_info,
"dur": micros(info["get_arguments_end"] -
info["get_arguments_start"]),
"cname": "rail_idle"
"cat":
"get_arguments",
"pid":
"Node " + worker["node_ip_address"],
"tid":
info["worker_id"],
"id":
task_id,
"ts":
micros_rel(info["get_arguments_start"]),
"ph":
"X",
"name":
info["function_name"] + ":get_arguments",
"args":
total_info,
"dur":
micros(info["get_arguments_end"] -
info["get_arguments_start"]),
"cname":
"rail_idle"
}
full_trace.append(get_args_trace)
if "store_outputs_end" in info:
outputs_trace = {
"cat": "store_outputs",
"pid": "Node " + worker["node_ip_address"],
"tid": info["worker_id"],
"id": task_id,
"ts": micros_rel(info["store_outputs_start"]),
"ph": "X",
"name": info["function_name"] + ":store_outputs",
"args": total_info,
"dur": micros(info["store_outputs_end"] -
info["store_outputs_start"]),
"cname": "thread_state_runnable"
"cat":
"store_outputs",
"pid":
"Node " + worker["node_ip_address"],
"tid":
info["worker_id"],
"id":
task_id,
"ts":
micros_rel(info["store_outputs_start"]),
"ph":
"X",
"name":
info["function_name"] + ":store_outputs",
"args":
total_info,
"dur":
micros(info["store_outputs_end"] -
info["store_outputs_start"]),
"cname":
"thread_state_runnable"
}
full_trace.append(outputs_trace)
if "execute_end" in info:
execute_trace = {
"cat": "execute",
"pid": "Node " + worker["node_ip_address"],
"tid": info["worker_id"],
"id": task_id,
"ts": micros_rel(info["execute_start"]),
"ph": "X",
"name": info["function_name"] + ":execute",
"args": total_info,
"dur": micros(info["execute_end"] -
info["execute_start"]),
"cname": "rail_animation"
"cat":
"execute",
"pid":
"Node " + worker["node_ip_address"],
"tid":
info["worker_id"],
"id":
task_id,
"ts":
micros_rel(info["execute_start"]),
"ph":
"X",
"name":
info["function_name"] + ":execute",
"args":
total_info,
"dur":
micros(info["execute_end"] - info["execute_start"]),
"cname":
"rail_animation"
}
full_trace.append(execute_trace)
@ -680,15 +731,20 @@ class GlobalState(object):
parent_profile = task_info.get(
task_table[task_id]["TaskSpec"]["ParentTaskID"])
parent = {
"cat": "submit_task",
"pid": "Node " + parent_worker["node_ip_address"],
"tid": parent_info["worker_id"],
"ts": micros_rel(
parent_profile and
parent_profile["get_arguments_start"] or
start_time),
"ph": "s",
"name": "SubmitTask",
"cat":
"submit_task",
"pid":
"Node " + parent_worker["node_ip_address"],
"tid":
parent_info["worker_id"],
"ts":
micros_rel(parent_profile
and parent_profile["get_arguments_start"]
or start_time),
"ph":
"s",
"name":
"SubmitTask",
"args": {},
"id": (parent_info["worker_id"] +
str(micros(min(parent_times))))
@ -696,32 +752,50 @@ class GlobalState(object):
full_trace.append(parent)
task_trace = {
"cat": "submit_task",
"pid": "Node " + worker["node_ip_address"],
"tid": info["worker_id"],
"ts": micros_rel(info["get_arguments_start"]),
"ph": "f",
"name": "SubmitTask",
"cat":
"submit_task",
"pid":
"Node " + worker["node_ip_address"],
"tid":
info["worker_id"],
"ts":
micros_rel(info["get_arguments_start"]),
"ph":
"f",
"name":
"SubmitTask",
"args": {},
"id": (info["worker_id"] +
str(micros(min(parent_times)))),
"bp": "e",
"cname": "olive"
"id":
(info["worker_id"] + str(micros(min(parent_times)))),
"bp":
"e",
"cname":
"olive"
}
full_trace.append(task_trace)
task = {
"cat": "task",
"pid": "Node " + worker["node_ip_address"],
"tid": info["worker_id"],
"id": task_id,
"ts": micros_rel(info["get_arguments_start"]),
"ph": "X",
"name": info["function_name"],
"args": total_info,
"dur": micros(info["store_outputs_end"] -
info["get_arguments_start"]),
"cname": "thread_state_runnable"
"cat":
"task",
"pid":
"Node " + worker["node_ip_address"],
"tid":
info["worker_id"],
"id":
task_id,
"ts":
micros_rel(info["get_arguments_start"]),
"ph":
"X",
"name":
info["function_name"],
"args":
total_info,
"dur":
micros(info["store_outputs_end"] -
info["get_arguments_start"]),
"cname":
"thread_state_runnable"
}
full_trace.append(task)
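
The dicts appended to full_trace here are catapult trace events (see the format document linked above): "ph": "X" marks a complete event spanning a duration, with "ts" and "dur" in microseconds. For reference, a minimal event of that shape, with illustrative field values, serialized the way a trace consumer reads it:

import json

event = {
    "cat": "task",
    "pid": "Node 127.0.0.1",  # illustrative node address
    "tid": "worker-0",        # illustrative worker id
    "ts": 0,
    "ph": "X",                # complete event: has a start and a duration
    "name": "example:execute",
    "dur": 1000,              # microseconds
    "args": {},
}
print(json.dumps({"traceEvents": [event]}))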
@ -732,15 +806,20 @@ class GlobalState(object):
parent_profile = task_info.get(
task_table[task_id]["TaskSpec"]["ParentTaskID"])
parent = {
"cat": "submit_task",
"pid": "Node " + parent_worker["node_ip_address"],
"tid": parent_info["worker_id"],
"ts": micros_rel(
parent_profile and
parent_profile["get_arguments_start"] or
start_time),
"ph": "s",
"name": "SubmitTask",
"cat":
"submit_task",
"pid":
"Node " + parent_worker["node_ip_address"],
"tid":
parent_info["worker_id"],
"ts":
micros_rel(parent_profile
and parent_profile["get_arguments_start"]
or start_time),
"ph":
"s",
"name":
"SubmitTask",
"args": {},
"id": (parent_info["worker_id"] +
str(micros(min(parent_times))))
@ -748,16 +827,23 @@ class GlobalState(object):
full_trace.append(parent)
task_trace = {
"cat": "submit_task",
"pid": "Node " + worker["node_ip_address"],
"tid": info["worker_id"],
"ts": micros_rel(info["get_arguments_start"]),
"ph": "f",
"name": "SubmitTask",
"cat":
"submit_task",
"pid":
"Node " + worker["node_ip_address"],
"tid":
info["worker_id"],
"ts":
micros_rel(info["get_arguments_start"]),
"ph":
"f",
"name":
"SubmitTask",
"args": {},
"id": (info["worker_id"] +
str(micros(min(parent_times)))),
"bp": "e"
"id":
(info["worker_id"] + str(micros(min(parent_times)))),
"bp":
"e"
}
full_trace.append(task_trace)
@ -775,8 +861,8 @@ class GlobalState(object):
seen_obj[arg] += 1
owner_task = self._object_table(arg)["TaskID"]
if owner_task in task_info:
owner_worker = (workers[
task_info[owner_task]["worker_id"]])
owner_worker = (workers[task_info[owner_task][
"worker_id"]])
# Adding/subtracting 2 to the time associated
# with the beginning/ending of the flow event
# is necessary to make the flow events show up
@ -790,27 +876,35 @@ class GlobalState(object):
# duration event that it's associated with, and
# the flow event therefore always gets drawn.
owner = {
"cat": "obj_dependency",
"cat":
"obj_dependency",
"pid": ("Node " +
owner_worker["node_ip_address"]),
"tid": task_info[owner_task]["worker_id"],
"ts": micros_rel(task_info[
owner_task]["store_outputs_end"]) - 2,
"ph": "s",
"name": "ObjectDependency",
"tid":
task_info[owner_task]["worker_id"],
"ts":
micros_rel(task_info[owner_task]
["store_outputs_end"]) - 2,
"ph":
"s",
"name":
"ObjectDependency",
"args": {},
"bp": "e",
"cname": "cq_build_attempt_failed",
"id": "obj" + str(arg) + str(seen_obj[arg])
"bp":
"e",
"cname":
"cq_build_attempt_failed",
"id":
"obj" + str(arg) + str(seen_obj[arg])
}
full_trace.append(owner)
dependent = {
"cat": "obj_dependency",
"pid": "Node " + worker["node_ip_address"],
"pid": "Node " + worker["node_ip_address"],
"tid": info["worker_id"],
"ts": micros_rel(
info["get_arguments_start"]) + 2,
"ts":
micros_rel(info["get_arguments_start"]) + 2,
"ph": "f",
"name": "ObjectDependency",
"args": {},
@ -852,14 +946,10 @@ class GlobalState(object):
"""
keys = [
"acquire_lock_start",
"acquire_lock_end",
"get_arguments_start",
"get_arguments_end",
"execute_start",
"execute_end",
"store_outputs_start",
"store_outputs_end"]
"acquire_lock_start", "acquire_lock_end", "get_arguments_start",
"get_arguments_end", "execute_start", "execute_end",
"store_outputs_start", "store_outputs_end"
]
latest_timestamp = 0
for key in keys:
@ -877,8 +967,8 @@ class GlobalState(object):
local_schedulers = []
for ip_address, client_list in clients.items():
for client in client_list:
if (client["ClientType"] == "local_scheduler" and
not client["Deleted"]):
if (client["ClientType"] == "local_scheduler"
and not client["Deleted"]):
local_schedulers.append(client)
return local_schedulers
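
Several conditions in this file (here and in the event-parsing code above) move the boolean operator from the end of one line to the start of the next. Inside parentheses the two layouts parse identically:

a, b = True, False
old_style = (a and
             not b)
new_style = (a
             and not b)
assert old_style == new_style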
@ -893,8 +983,7 @@ class GlobalState(object):
workers_data[worker_id] = {
"local_scheduler_socket":
(worker_info[b"local_scheduler_socket"]
.decode("ascii")),
(worker_info[b"local_scheduler_socket"].decode("ascii")),
"node_ip_address": (worker_info[b"node_ip_address"]
.decode("ascii")),
"plasma_manager_socket": (worker_info[b"plasma_manager_socket"]
@ -921,9 +1010,10 @@ class GlobalState(object):
"class_id": binary_to_hex(info[b"class_id"]),
"driver_id": binary_to_hex(info[b"driver_id"]),
"local_scheduler_id":
binary_to_hex(info[b"local_scheduler_id"]),
binary_to_hex(info[b"local_scheduler_id"]),
"num_gpus": int(info[b"num_gpus"]),
"removed": decode(info[b"removed"]) == "True"}
"removed": decode(info[b"removed"]) == "True"
}
return actor_info
def _job_length(self):
@ -932,21 +1022,16 @@ class GlobalState(object):
overall_largest = 0
num_tasks = 0
for event_log_set in event_log_sets:
fwd_range = self.redis_client.zrange(event_log_set,
start=0,
end=0,
withscores=True)
fwd_range = self.redis_client.zrange(
event_log_set, start=0, end=0, withscores=True)
overall_smallest = min(overall_smallest, fwd_range[0][1])
rev_range = self.redis_client.zrevrange(event_log_set,
start=0,
end=0,
withscores=True)
rev_range = self.redis_client.zrevrange(
event_log_set, start=0, end=0, withscores=True)
overall_largest = max(overall_largest, rev_range[0][1])
num_tasks += self.redis_client.zcount(event_log_set,
min=0,
max=time.time())
num_tasks += self.redis_client.zcount(
event_log_set, min=0, max=time.time())
if num_tasks == 0:
return 0, 0, 0
return overall_smallest, overall_largest, num_tasks
@ -966,8 +1051,10 @@ class GlobalState(object):
for local_scheduler in local_schedulers:
for key, value in local_scheduler.items():
if key not in ["ClientType", "Deleted", "DBClientID",
"AuxAddress", "LocalSchedulerSocketName"]:
if key not in [
"ClientType", "Deleted", "DBClientID", "AuxAddress",
"LocalSchedulerSocketName"
]:
resources[key] += value
return dict(resources)

View file

@ -27,6 +27,7 @@ class TensorFlowVariables(object):
placeholders (Dict[str, tf.placeholders]): Placeholders for weights.
assignment_nodes (Dict[str, tf.Tensor]): Nodes that assign weights.
"""
def __init__(self, loss, sess=None, input_variables=None):
"""Creates TensorFlowVariables containing extracted variables.
@ -74,8 +75,10 @@ class TensorFlowVariables(object):
if "Variable" in tf_obj.node_def.op:
variable_names.append(tf_obj.node_def.name)
self.variables = OrderedDict()
variable_list = [v for v in tf.global_variables()
if v.op.node_def.name in variable_names]
variable_list = [
v for v in tf.global_variables()
if v.op.node_def.name in variable_names
]
if input_variables is not None:
variable_list += input_variables
for v in variable_list:
@ -86,9 +89,10 @@ class TensorFlowVariables(object):
# Create new placeholders to put in custom weights.
for k, var in self.variables.items():
self.placeholders[k] = tf.placeholder(var.value().dtype,
var.get_shape().as_list(),
name="Placeholder_" + k)
self.placeholders[k] = tf.placeholder(
var.value().dtype,
var.get_shape().as_list(),
name="Placeholder_" + k)
self.assignment_nodes[k] = var.assign(self.placeholders[k])
def set_session(self, sess):
@ -105,8 +109,9 @@ class TensorFlowVariables(object):
Returns:
The length of all flattened variables concatenated.
"""
return sum([np.prod(v.get_shape().as_list())
for v in self.variables.values()])
return sum([
np.prod(v.get_shape().as_list()) for v in self.variables.values()
])
def _check_sess(self):
"""Checks if the session is set, and if not throw an error message."""
@ -122,8 +127,10 @@ class TensorFlowVariables(object):
1D Array containing the flattened weights.
"""
self._check_sess()
return np.concatenate([v.eval(session=self.sess).flatten()
for v in self.variables.values()])
return np.concatenate([
v.eval(session=self.sess).flatten()
for v in self.variables.values()
])
def set_flat(self, new_weights):
"""Sets the weights to new_weights, converting from a flat array.
@ -138,10 +145,12 @@ class TensorFlowVariables(object):
self._check_sess()
shapes = [v.get_shape().as_list() for v in self.variables.values()]
arrays = unflatten(new_weights, shapes)
placeholders = [self.placeholders[k] for k, v
in self.variables.items()]
self.sess.run(list(self.assignment_nodes.values()),
feed_dict=dict(zip(placeholders, arrays)))
placeholders = [
self.placeholders[k] for k, v in self.variables.items()
]
self.sess.run(
list(self.assignment_nodes.values()),
feed_dict=dict(zip(placeholders, arrays)))
def get_weights(self):
"""Returns a dictionary containing the weights of the network.
@ -150,8 +159,10 @@ class TensorFlowVariables(object):
Dictionary mapping variable names to their weights.
"""
self._check_sess()
return {k: v.eval(session=self.sess) for k, v
in self.variables.items()}
return {
k: v.eval(session=self.sess)
for k, v in self.variables.items()
}
def set_weights(self, new_weights):
"""Sets the weights to new_weights.
@ -165,15 +176,19 @@ class TensorFlowVariables(object):
weights.
"""
self._check_sess()
assign_list = [self.assignment_nodes[name]
for name in new_weights.keys()
if name in self.assignment_nodes]
assign_list = [
self.assignment_nodes[name] for name in new_weights.keys()
if name in self.assignment_nodes
]
assert assign_list, ("No variables in the input matched those in the "
"network. Possible cause: Two networks were "
"defined in the same TensorFlow graph. To fix "
"this, place each network definition in its own "
"tf.Graph.")
self.sess.run(assign_list,
feed_dict={self.placeholders[name]: value
for (name, value) in new_weights.items()
if name in self.placeholders})
self.sess.run(
assign_list,
feed_dict={
self.placeholders[name]: value
for (name, value) in new_weights.items()
if name in self.placeholders
})
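
The reformatted feed_dict above is a filtered dict comprehension spread over three lines. A plain-Python analogue — the real code maps TensorFlow placeholders to weight arrays, so plain strings and lists stand in here:

placeholders = {"w1": "ph1", "w2": "ph2"}
new_weights = {"w1": [1.0], "w3": [2.0]}
feed_dict = {
    placeholders[name]: value
    for (name, value) in new_weights.items() if name in placeholders
}
assert feed_dict == {"ph1": [1.0]}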

View file

@ -29,9 +29,9 @@ class _EventRecursionContextManager(object):
total_time_value = "% total time"
total_tasks_value = "% total tasks"
# Function that returns instances of sliders and handles associated events.
def get_sliders(update):
# Start_box value indicates the desired start point of queried window.
start_box = widgets.FloatText(
@ -60,18 +60,14 @@ def get_sliders(update):
# Indicates the number of tasks that the user wants to be returned. Is
# disabled when the breakdown_opt value is set to total_time_value.
num_tasks_box = widgets.IntText(
description="Num Tasks:",
disabled=False
)
num_tasks_box = widgets.IntText(description="Num Tasks:", disabled=False)
# Dropdown bar that lets the user choose between modifying % of total
# time or total number of tasks.
breakdown_opt = widgets.Dropdown(
options=[total_time_value, total_tasks_value],
value=total_tasks_value,
description="Selection Options:"
)
description="Selection Options:")
# Display box for layout.
total_time_box = widgets.VBox([start_box, end_box])
@ -105,9 +101,9 @@ def get_sliders(update):
if event == INIT_EVENT:
if breakdown_opt.value == total_tasks_value:
num_tasks_box.value = -min(10000, num_tasks)
range_slider.value = (int(100 -
(100. * -num_tasks_box.value) /
num_tasks), 100)
range_slider.value = (int(
100 - (100. * -num_tasks_box.value) / num_tasks),
100)
else:
low, high = map(lambda x: x / 100., range_slider.value)
start_box.value = round(diff * low, 2)
@ -120,8 +116,8 @@ def get_sliders(update):
elif start_box.value < 0:
start_box.value = 0
low, high = range_slider.value
range_slider.value = (int((start_box.value * 100.) /
diff), high)
range_slider.value = (int((start_box.value * 100.) / diff),
high)
# Event was triggered by a change in the end_box value.
elif event["owner"] == end_box:
@ -130,8 +126,8 @@ def get_sliders(update):
elif end_box.value > diff:
end_box.value = diff
low, high = range_slider.value
range_slider.value = (low, int((end_box.value * 100.) /
diff))
range_slider.value = (low,
int((end_box.value * 100.) / diff))
# Event was triggered by a change in the breakdown options
# toggle.
@ -145,9 +141,9 @@ def get_sliders(update):
# Make CSS display go back to the default settings.
num_tasks_box.layout.display = None
num_tasks_box.value = min(10000, num_tasks)
range_slider.value = (int(100 -
(100. * num_tasks_box.value) /
num_tasks), 100)
range_slider.value = (int(
100 - (100. * num_tasks_box.value) / num_tasks),
100)
else:
start_box.disabled = False
end_box.disabled = False
@ -156,10 +152,9 @@ def get_sliders(update):
# Make CSS display go back to the default settings.
total_time_box.layout.display = None
num_tasks_box.layout.display = 'none'
range_slider.value = (int((start_box.value * 100.) /
diff),
int((end_box.value * 100.) /
diff))
range_slider.value = (
int((start_box.value * 100.) / diff),
int((end_box.value * 100.) / diff))
# Event was triggered by a change in the range_slider
# value.
@ -170,8 +165,8 @@ def get_sliders(update):
new_low, new_high = event["new"]
if old_low != new_low:
range_slider.value = (new_low, 100)
num_tasks_box.value = (-(100. - new_low) /
100. * num_tasks)
num_tasks_box.value = (
-(100. - new_low) / 100. * num_tasks)
else:
range_slider.value = (0, new_high)
num_tasks_box.value = new_high / 100. * num_tasks
@ -183,14 +178,12 @@ def get_sliders(update):
# value.
elif event["owner"] == num_tasks_box:
if num_tasks_box.value > 0:
range_slider.value = (0, int(100 *
float(num_tasks_box.value) /
num_tasks))
range_slider.value = (
0, int(
100 * float(num_tasks_box.value) / num_tasks))
elif num_tasks_box.value < 0:
range_slider.value = (100 +
int(100 *
float(num_tasks_box.value) /
num_tasks), 100)
range_slider.value = (100 + int(
100 * float(num_tasks_box.value) / num_tasks), 100)
if not update:
return
@ -205,23 +198,20 @@ def get_sliders(update):
# box values.
# (Querying based on the % total amount of time.)
if breakdown_opt.value == total_time_value:
tasks = _truncated_task_profiles(start=(smallest +
diff * low),
end=(smallest +
diff * high))
tasks = _truncated_task_profiles(
start=(smallest + diff * low),
end=(smallest + diff * high))
# (Querying based on % of total number of tasks that were
# run.)
elif breakdown_opt.value == total_tasks_value:
if range_slider.value[0] == 0:
tasks = _truncated_task_profiles(num_tasks=(int(
num_tasks * high)),
fwd=True)
tasks = _truncated_task_profiles(
num_tasks=(int(num_tasks * high)), fwd=True)
else:
tasks = _truncated_task_profiles(num_tasks=(int(
num_tasks *
(high - low))),
fwd=False)
tasks = _truncated_task_profiles(
num_tasks=(int(num_tasks * (high - low))),
fwd=False)
update(smallest, largest, num_tasks, tasks)
@ -237,8 +227,8 @@ def get_sliders(update):
update_wrapper(INIT_EVENT)
# Display sliders and search boxes
display(breakdown_opt, widgets.HBox([range_slider, total_time_box,
num_tasks_box]))
display(breakdown_opt,
widgets.HBox([range_slider, total_time_box, num_tasks_box]))
# Return the sliders and text boxes
return start_box, end_box, range_slider, breakdown_opt
@ -249,8 +239,7 @@ def object_search_bar():
value="",
placeholder="Object ID",
description="Search for an object:",
disabled=False
)
disabled=False)
display(object_search)
def handle_submit(sender):
@ -265,8 +254,7 @@ def task_search_bar():
value="",
placeholder="Task ID",
description="Search for a task:",
disabled=False
)
disabled=False)
display(task_search)
def handle_submit(sender):
@ -284,14 +272,12 @@ MAX_TASKS_TO_VISUALIZE = 10000
def _truncated_task_profiles(start=None, end=None, num_tasks=None, fwd=True):
if num_tasks is None:
num_tasks = MAX_TASKS_TO_VISUALIZE
print(
"Warning: at most {} tasks will be fetched within this "
"time range.".format(MAX_TASKS_TO_VISUALIZE))
print("Warning: at most {} tasks will be fetched within this "
"time range.".format(MAX_TASKS_TO_VISUALIZE))
elif num_tasks > MAX_TASKS_TO_VISUALIZE:
print(
"Warning: too many tasks to visualize, "
"fetching only the first {} of {}.".format(
MAX_TASKS_TO_VISUALIZE, num_tasks))
print("Warning: too many tasks to visualize, "
"fetching only the first {} of {}.".format(
MAX_TASKS_TO_VISUALIZE, num_tasks))
num_tasks = MAX_TASKS_TO_VISUALIZE
return ray.global_state.task_profiles(num_tasks, start, end, fwd)
@ -299,9 +285,8 @@ def _truncated_task_profiles(start=None, end=None, num_tasks=None, fwd=True):
# Helper function that guarantees unique and writeable temp files.
# Prevents clashes in task trace files when multiple notebooks are running.
def _get_temp_file_path(**kwargs):
temp_file = tempfile.NamedTemporaryFile(delete=False,
dir=os.getcwd(),
**kwargs)
temp_file = tempfile.NamedTemporaryFile(
delete=False, dir=os.getcwd(), **kwargs)
temp_file_path = temp_file.name
temp_file.close()
return os.path.relpath(temp_file_path)
@ -319,22 +304,16 @@ def task_timeline():
disabled=False,
)
obj_dep = widgets.Checkbox(
value=True,
disabled=False,
layout=widgets.Layout(width='20px')
)
value=True, disabled=False, layout=widgets.Layout(width='20px'))
task_dep = widgets.Checkbox(
value=True,
disabled=False,
layout=widgets.Layout(width='20px')
)
value=True, disabled=False, layout=widgets.Layout(width='20px'))
# Labels to bypass width limitation for descriptions.
label_tasks = widgets.Label(value='Task submissions',
layout=widgets.Layout(width='110px'))
label_objects = widgets.Label(value='Object dependencies',
layout=widgets.Layout(width='130px'))
label_options = widgets.Label(value='View options:',
layout=widgets.Layout(width='100px'))
label_tasks = widgets.Label(
value='Task submissions', layout=widgets.Layout(width='110px'))
label_objects = widgets.Label(
value='Object dependencies', layout=widgets.Layout(width='130px'))
label_options = widgets.Label(
value='View options:', layout=widgets.Layout(width='100px'))
start_box, end_box, range_slider, time_opt = get_sliders(False)
display(widgets.HBox([task_dep, label_tasks, obj_dep, label_objects]))
display(widgets.HBox([label_options, breakdown_opt]))
@ -344,8 +323,9 @@ def task_timeline():
# current working directory if it is not present.
if not os.path.exists("trace_viewer_full.html"):
shutil.copy(
os.path.join(os.path.dirname(os.path.abspath(__file__)),
"../core/src/catapult_files/trace_viewer_full.html"),
os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"../core/src/catapult_files/trace_viewer_full.html"),
"trace_viewer_full.html")
def handle_submit(sender):
@ -357,8 +337,8 @@ def task_timeline():
elif breakdown_opt.value == breakdown_task:
breakdown = True
else:
raise ValueError(
"Unexpected breakdown value '{}'".format(breakdown_opt.value))
raise ValueError("Unexpected breakdown value '{}'".format(
breakdown_opt.value))
low, high = map(lambda x: x / 100., range_slider.value)
@ -366,30 +346,28 @@ def task_timeline():
diff = largest - smallest
if time_opt.value == total_time_value:
tasks = _truncated_task_profiles(start=smallest + diff * low,
end=smallest + diff * high)
tasks = _truncated_task_profiles(
start=smallest + diff * low, end=smallest + diff * high)
elif time_opt.value == total_tasks_value:
if range_slider.value[0] == 0:
tasks = _truncated_task_profiles(num_tasks=int(
num_tasks * high),
fwd=True)
tasks = _truncated_task_profiles(
num_tasks=int(num_tasks * high), fwd=True)
else:
tasks = _truncated_task_profiles(num_tasks=int(
num_tasks * (high - low)),
fwd=False)
tasks = _truncated_task_profiles(
num_tasks=int(num_tasks * (high - low)), fwd=False)
else:
raise ValueError("Unexpected time value '{}'".format(
time_opt.value))
time_opt.value))
# Write trace to a JSON file
print("Collected profiles for {} tasks.".format(len(tasks)))
print(
"Dumping task profile data to {}, "
"this might take a while...".format(json_tmp))
ray.global_state.dump_catapult_trace(json_tmp,
tasks,
breakdowns=breakdown,
obj_dep=obj_dep.value,
task_dep=task_dep.value)
print("Dumping task profile data to {}, "
"this might take a while...".format(json_tmp))
ray.global_state.dump_catapult_trace(
json_tmp,
tasks,
breakdowns=breakdown,
obj_dep=obj_dep.value,
task_dep=task_dep.value)
print("Opening html file in browser...")
trace_viewer_path = os.path.join(
@ -415,9 +393,8 @@ def task_timeline():
# Display the task trace within the Jupyter notebook
clear_output(wait=True)
print(
"To view fullscreen, open chrome://tracing in Google Chrome "
"and load `{}`".format(json_tmp))
print("To view fullscreen, open chrome://tracing in Google Chrome "
"and load `{}`".format(json_tmp))
display(IFrame(html_file_path, 900, 800))
path_input.on_click(handle_submit)
@ -432,36 +409,41 @@ def task_completion_time_distribution():
output_notebook(resources=CDN)
# Create the Bokeh plot
p = figure(title="Task Completion Time Distribution",
tools=["save", "hover", "wheel_zoom", "box_zoom", "pan"],
background_fill_color="#FFFFFF",
x_range=(0, 1),
y_range=(0, 1))
p = figure(
title="Task Completion Time Distribution",
tools=["save", "hover", "wheel_zoom", "box_zoom", "pan"],
background_fill_color="#FFFFFF",
x_range=(0, 1),
y_range=(0, 1))
# Create the data source that the plot pulls from
source = ColumnDataSource(data={
"top": [],
"left": [],
"right": []
})
source = ColumnDataSource(data={"top": [], "left": [], "right": []})
# Plot the histogram rectangles
p.quad(top="top", bottom=0, left="left", right="right", source=source,
fill_color="#B3B3B3", line_color="#033649")
p.quad(
top="top",
bottom=0,
left="left",
right="right",
source=source,
fill_color="#B3B3B3",
line_color="#033649")
# Label the plot axes
p.xaxis.axis_label = "Duration in seconds"
p.yaxis.axis_label = "Number of tasks"
handle = show(gridplot(p, ncols=1,
plot_width=500,
plot_height=500,
toolbar_location="below"), notebook_handle=True)
handle = show(
gridplot(
p,
ncols=1,
plot_width=500,
plot_height=500,
toolbar_location="below"),
notebook_handle=True)
# Function to update the plot
def task_completion_time_update(abs_earliest,
abs_latest,
abs_num_tasks,
def task_completion_time_update(abs_earliest, abs_latest, abs_num_tasks,
tasks):
if len(tasks) == 0:
return
@ -469,8 +451,8 @@ def task_completion_time_distribution():
# Create the distribution to plot
distr = []
for task_id, data in tasks.items():
distr.append(data["store_outputs_end"] -
data["get_arguments_start"])
distr.append(
data["store_outputs_end"] - data["get_arguments_start"])
# Create a histogram from the distribution
top, bin_edges = np.histogram(distr, bins="auto")
@ -480,8 +462,8 @@ def task_completion_time_distribution():
source.data = {"top": top, "left": left, "right": right}
# Set the x and y ranges
x_range = (min(left) if len(left) else 0,
max(right) if len(right) else 1)
x_range = (min(left) if len(left) else 0, max(right)
if len(right) else 1)
y_range = (0, max(top) + 1 if len(top) else 1)
x_range = helpers._get_range(x_range)
@ -517,8 +499,7 @@ def compute_utilizations(abs_earliest,
latest_time = 0
for task_id, data in tasks.items():
latest_time = max((latest_time, data["store_outputs_end"]))
earliest_time = min((earliest_time,
data["get_arguments_start"]))
earliest_time = min((earliest_time, data["get_arguments_start"]))
# Add some epsilon to latest_time to ensure that the end time of the
# last task falls __within__ a bucket, and not on the edge
@ -533,37 +514,37 @@ def compute_utilizations(abs_earliest,
task_start_time = data["get_arguments_start"]
task_end_time = data["store_outputs_end"]
start_bucket = int((task_start_time - earliest_time) /
bucket_time_length)
end_bucket = int((task_end_time - earliest_time) /
bucket_time_length)
start_bucket = int(
(task_start_time - earliest_time) / bucket_time_length)
end_bucket = int((task_end_time - earliest_time) / bucket_time_length)
# Walk over each time bucket that this task intersects, adding the
# amount of time that the task intersects within each bucket
for bucket_idx in range(start_bucket, end_bucket + 1):
bucket_start_time = ((earliest_time + bucket_idx) *
bucket_time_length)
bucket_end_time = ((earliest_time + (bucket_idx + 1)) *
bucket_time_length)
bucket_start_time = ((
earliest_time + bucket_idx) * bucket_time_length)
bucket_end_time = ((earliest_time +
(bucket_idx + 1)) * bucket_time_length)
task_start_time_within_bucket = max(task_start_time,
bucket_start_time)
task_end_time_within_bucket = min(task_end_time,
bucket_end_time)
task_cpu_time_within_bucket = (task_end_time_within_bucket -
task_start_time_within_bucket)
task_end_time_within_bucket = min(task_end_time, bucket_end_time)
task_cpu_time_within_bucket = (
task_end_time_within_bucket - task_start_time_within_bucket)
if bucket_idx > -1 and bucket_idx < num_buckets:
cpu_time[bucket_idx] += task_cpu_time_within_bucket
# Cpu_utilization is the average cpu utilization of the bucket, which
# is just cpu_time divided by bucket_time_length.
cpu_utilization = list(map(lambda x: x / float(bucket_time_length),
cpu_time))
cpu_utilization = list(
map(lambda x: x / float(bucket_time_length), cpu_time))
# Generate histogram bucket edges. Subtract out abs_earliest to get
# relative time.
all_edges = [earliest_time - abs_earliest + i * bucket_time_length
for i in range(num_buckets + 1)]
all_edges = [
earliest_time - abs_earliest + i * bucket_time_length
for i in range(num_buckets + 1)
]
# Left edges are all but the rightmost edge, right edges are all but
# the leftmost edge.
left_edges = all_edges[:-1]
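
The loop above apportions each task's runtime across every time bucket it intersects. A self-contained sketch of that accounting, with illustrative values (note the sketch computes each bucket's start as earliest + idx * bucket_len):

def bucket_overlap(task_start, task_end, earliest, bucket_len, num_buckets):
    cpu_time = [0.0] * num_buckets
    start_b = int((task_start - earliest) / bucket_len)
    end_b = int((task_end - earliest) / bucket_len)
    for b in range(start_b, end_b + 1):
        bucket_start = earliest + b * bucket_len
        bucket_end = earliest + (b + 1) * bucket_len
        # Time this task spends inside bucket b.
        overlap = min(task_end, bucket_end) - max(task_start, bucket_start)
        if 0 <= b < num_buckets:
            cpu_time[b] += overlap
    return cpu_time

# A task running from 0.5s to 2.5s over three 1s buckets.
assert bucket_overlap(0.5, 2.5, 0.0, 1.0, 3) == [0.5, 1.0, 0.5]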
@ -591,54 +572,53 @@ def cpu_usage():
# Update the plot based on the sliders
def plot_utilization():
# Create the Bokeh plot
time_series_fig = figure(title="CPU Utilization",
tools=["save", "hover", "wheel_zoom",
"box_zoom", "pan"],
background_fill_color="#FFFFFF",
x_range=[0, 1],
y_range=[0, 1])
time_series_fig = figure(
title="CPU Utilization",
tools=["save", "hover", "wheel_zoom", "box_zoom", "pan"],
background_fill_color="#FFFFFF",
x_range=[0, 1],
y_range=[0, 1])
# Create the data source that the plot will pull from
time_series_source = ColumnDataSource(data=dict(
left=[],
right=[],
top=[]
))
time_series_source = ColumnDataSource(
data=dict(left=[], right=[], top=[]))
# Plot the rectangles representing the distribution
time_series_fig.quad(left="left",
right="right",
top="top",
bottom=0,
source=time_series_source,
fill_color="#B3B3B3",
line_color="#033649")
time_series_fig.quad(
left="left",
right="right",
top="top",
bottom=0,
source=time_series_source,
fill_color="#B3B3B3",
line_color="#033649")
# Label the plot axes
time_series_fig.xaxis.axis_label = "Time in seconds"
time_series_fig.yaxis.axis_label = "Number of CPUs used"
handle = show(gridplot(time_series_fig,
ncols=1,
plot_width=500,
plot_height=500,
toolbar_location="below"), notebook_handle=True)
handle = show(
gridplot(
time_series_fig,
ncols=1,
plot_width=500,
plot_height=500,
toolbar_location="below"),
notebook_handle=True)
def update_plot(abs_earliest, abs_latest, abs_num_tasks, tasks):
num_buckets = 100
left, right, top = compute_utilizations(abs_earliest,
abs_latest,
abs_num_tasks,
tasks,
num_buckets)
left, right, top = compute_utilizations(
abs_earliest, abs_latest, abs_num_tasks, tasks, num_buckets)
time_series_source.data = {"left": left,
"right": right,
"top": top}
time_series_source.data = {
"left": left,
"right": right,
"top": top
}
x_range = (max(0, min(left))
if len(left) else 0,
max(right) if len(right) else 1)
x_range = (max(0, min(left)) if len(left) else 0, max(right)
if len(right) else 1)
y_range = (0, max(top) + 1 if len(top) else 1)
# Define the axis ranges
@ -654,6 +634,7 @@ def cpu_usage():
push_notebook(handle=handle)
get_sliders(update_plot)
plot_utilization()
@ -672,33 +653,32 @@ def cluster_usage():
output_notebook(resources=CDN)
# Initial values
source = ColumnDataSource(data={"node_ip_address": ['127.0.0.1'],
"time": ['0.5'],
"num_tasks": ['1'],
"length": [1]})
source = ColumnDataSource(
data={
"node_ip_address": ['127.0.0.1'],
"time": ['0.5'],
"num_tasks": ['1'],
"length": [1]
})
# Define the color schema
colors = ["#75968f",
"#a5bab7",
"#c9d9d3",
"#e2e2e2",
"#dfccce",
"#ddb7b1",
"#cc7878",
"#933b41",
"#550b1d"]
colors = [
"#75968f", "#a5bab7", "#c9d9d3", "#e2e2e2", "#dfccce", "#ddb7b1",
"#cc7878", "#933b41", "#550b1d"
]
mapper = LinearColorMapper(palette=colors, low=0, high=2)
TOOLS = "hover, save, xpan, box_zoom, reset, xwheel_zoom"
# Create the plot
p = figure(title="Cluster Usage",
y_range=list(set(source.data['node_ip_address'])),
x_axis_location="above",
plot_width=900,
plot_height=500,
tools=TOOLS,
toolbar_location='below')
p = figure(
title="Cluster Usage",
y_range=list(set(source.data['node_ip_address'])),
x_axis_location="above",
plot_width=900,
plot_height=500,
tools=TOOLS,
toolbar_location='below')
# Format the plot axes
p.grid.grid_line_color = None
@ -709,26 +689,33 @@ def cluster_usage():
p.xaxis.major_label_orientation = np.pi / 3
# Plot rectangles
p.rect(x="time", y="node_ip_address", width="length", height=1,
source=source,
fill_color={"field": "num_tasks", "transform": mapper},
line_color=None)
p.rect(
x="time",
y="node_ip_address",
width="length",
height=1,
source=source,
fill_color={
"field": "num_tasks",
"transform": mapper
},
line_color=None)
# Add legend to the side of the plot
color_bar = ColorBar(color_mapper=mapper,
major_label_text_font_size="8pt",
ticker=BasicTicker(desired_num_ticks=len(colors)),
label_standoff=6,
border_line_color=None,
location=(0, 0))
color_bar = ColorBar(
color_mapper=mapper,
major_label_text_font_size="8pt",
ticker=BasicTicker(desired_num_ticks=len(colors)),
label_standoff=6,
border_line_color=None,
location=(0, 0))
p.add_layout(color_bar, "right")
# Define hover tool
p.select_one(HoverTool).tooltips = [
("Node IP Address", "@node_ip_address"),
("Number of tasks running", "@num_tasks"),
("Time", "@time")
]
p.select_one(HoverTool).tooltips = [("Node IP Address",
"@node_ip_address"),
("Number of tasks running",
"@num_tasks"), ("Time", "@time")]
# Define the axis labels
p.xaxis.axis_label = "Time in seconds"
@ -764,12 +751,8 @@ def cluster_usage():
num_tasks = []
for node_ip, task_dict in node_to_tasks.items():
left, right, top = compute_utilizations(earliest,
latest,
abs_num_tasks,
task_dict,
100,
True)
left, right, top = compute_utilizations(
earliest, latest, abs_num_tasks, task_dict, 100, True)
for (l, r, t) in zip(left, right, top):
nodes.append(node_ip)
times.append((l + r) / 2)
@ -783,10 +766,12 @@ def cluster_usage():
mapper.high = max(max(num_tasks), 1)
# Update plot with new data based on slider and text box values
source.data = {"node_ip_address": nodes,
"time": times,
"num_tasks": num_tasks,
"length": lengths}
source.data = {
"node_ip_address": nodes,
"time": times,
"num_tasks": num_tasks,
"length": lengths
}
push_notebook(handle=handle)

View file

@ -7,9 +7,12 @@ import subprocess
import time
def start_global_scheduler(redis_address, node_ip_address,
use_valgrind=False, use_profiler=False,
stdout_file=None, stderr_file=None):
def start_global_scheduler(redis_address,
node_ip_address,
use_valgrind=False,
use_profiler=False,
stdout_file=None,
stderr_file=None):
"""Start a global scheduler process.
Args:
@ -33,21 +36,24 @@ def start_global_scheduler(redis_address, node_ip_address,
global_scheduler_executable = os.path.join(
os.path.abspath(os.path.dirname(__file__)),
"../core/src/global_scheduler/global_scheduler")
command = [global_scheduler_executable,
"-r", redis_address,
"-h", node_ip_address]
command = [
global_scheduler_executable, "-r", redis_address, "-h", node_ip_address
]
if use_valgrind:
pid = subprocess.Popen(["valgrind",
"--track-origins=yes",
"--leak-check=full",
"--show-leak-kinds=all",
"--leak-check-heuristics=stdstring",
"--error-exitcode=1"] + command,
stdout=stdout_file, stderr=stderr_file)
pid = subprocess.Popen(
[
"valgrind", "--track-origins=yes", "--leak-check=full",
"--show-leak-kinds=all", "--leak-check-heuristics=stdstring",
"--error-exitcode=1"
] + command,
stdout=stdout_file,
stderr=stderr_file)
time.sleep(1.0)
elif use_profiler:
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
stdout=stdout_file, stderr=stderr_file)
pid = subprocess.Popen(
["valgrind", "--tool=callgrind"] + command,
stdout=stdout_file,
stderr=stderr_file)
time.sleep(1.0)
else:
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
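
In all three branches above the scheduler invocation is a flat argv list, with valgrind flags simply prepended when requested. A minimal illustration of that list arithmetic (the executable name and addresses are illustrative; the flags are copied from the code above):

command = ["global_scheduler", "-r", "127.0.0.1:6379", "-h", "127.0.0.1"]
wrapped = [
    "valgrind", "--track-origins=yes", "--leak-check=full",
    "--show-leak-kinds=all", "--leak-check-heuristics=stdstring",
    "--error-exitcode=1"
] + command
# Prepending leaves the original command intact at the tail.
assert wrapped[-len(command):] == command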

View file

@ -56,7 +56,6 @@ def new_port():
class TestGlobalScheduler(unittest.TestCase):
def setUp(self):
# Start one Redis server and N pairs of (plasma, local_scheduler)
self.node_ip_address = "127.0.0.1"
@ -164,17 +163,19 @@ class TestGlobalScheduler(unittest.TestCase):
return db_client_id
def test_task_default_resources(self):
task1 = local_scheduler.Task(random_driver_id(), random_function_id(),
[random_object_id()], 0, random_task_id(),
0)
task1 = local_scheduler.Task(
random_driver_id(), random_function_id(), [random_object_id()], 0,
random_task_id(), 0)
self.assertEqual(task1.required_resources(), {"CPU": 1})
task2 = local_scheduler.Task(random_driver_id(), random_function_id(),
[random_object_id()], 0, random_task_id(),
0, local_scheduler.ObjectID(NIL_ACTOR_ID),
local_scheduler.ObjectID(NIL_OBJECT_ID),
local_scheduler.ObjectID(NIL_ACTOR_ID),
local_scheduler.ObjectID(NIL_ACTOR_ID),
0, 0, [], {"CPU": 1, "GPU": 2})
task2 = local_scheduler.Task(
random_driver_id(), random_function_id(), [random_object_id()], 0,
random_task_id(), 0, local_scheduler.ObjectID(NIL_ACTOR_ID),
local_scheduler.ObjectID(NIL_OBJECT_ID),
local_scheduler.ObjectID(NIL_ACTOR_ID),
local_scheduler.ObjectID(NIL_ACTOR_ID), 0, 0, [], {
"CPU": 1,
"GPU": 2
})
self.assertEqual(task2.required_resources(), {"CPU": 1, "GPU": 2})
def test_redis_only_single_task(self):
@ -189,7 +190,7 @@ class TestGlobalScheduler(unittest.TestCase):
len(self.state.client_table()[self.node_ip_address]),
2 * NUM_CLUSTER_NODES + 1)
db_client_id = self.get_plasma_manager_id()
assert(db_client_id is not None)
assert (db_client_id is not None)
@unittest.skipIf(
os.environ.get('RAY_USE_NEW_GCS', False),
@ -227,9 +228,10 @@ class TestGlobalScheduler(unittest.TestCase):
if len(task_entries) == 1:
task_id, task = task_entries.popitem()
task_status = task["State"]
self.assertTrue(task_status in [state.TASK_STATUS_WAITING,
state.TASK_STATUS_SCHEDULED,
state.TASK_STATUS_QUEUED])
self.assertTrue(task_status in [
state.TASK_STATUS_WAITING, state.TASK_STATUS_SCHEDULED,
state.TASK_STATUS_QUEUED
])
if task_status == state.TASK_STATUS_QUEUED:
break
else:
@ -258,17 +260,14 @@ class TestGlobalScheduler(unittest.TestCase):
data_size = np.random.randint(1 << 12)
metadata_size = np.random.randint(1 << 9)
plasma_client = self.plasma_clients[0]
object_dep, memory_buffer, metadata = create_object(plasma_client,
data_size,
metadata_size,
seal=True)
object_dep, memory_buffer, metadata = create_object(
plasma_client, data_size, metadata_size, seal=True)
if timesync:
# Give 10ms for object info handler to fire (long enough to
# yield CPU).
time.sleep(0.010)
task = local_scheduler.Task(
random_driver_id(),
random_function_id(),
random_driver_id(), random_function_id(),
[local_scheduler.ObjectID(object_dep.binary())],
num_return_vals[0], random_task_id(), 0)
self.local_scheduler_clients[0].submit(task)
@ -281,12 +280,18 @@ class TestGlobalScheduler(unittest.TestCase):
self.assertLessEqual(len(task_entries), num_tasks)
# First, check if all tasks made it to Redis.
if len(task_entries) == num_tasks:
task_statuses = [task_entry["State"] for task_entry in
task_entries.values()]
self.assertTrue(all([status in [state.TASK_STATUS_WAITING,
state.TASK_STATUS_SCHEDULED,
state.TASK_STATUS_QUEUED]
for status in task_statuses]))
task_statuses = [
task_entry["State"]
for task_entry in task_entries.values()
]
self.assertTrue(
all([
status in [
state.TASK_STATUS_WAITING,
state.TASK_STATUS_SCHEDULED,
state.TASK_STATUS_QUEUED
] for status in task_statuses
]))
num_tasks_done = task_statuses.count(state.TASK_STATUS_QUEUED)
num_tasks_scheduled = task_statuses.count(
state.TASK_STATUS_SCHEDULED)
@ -294,12 +299,13 @@ class TestGlobalScheduler(unittest.TestCase):
state.TASK_STATUS_WAITING)
print("tasks in Redis = {}, tasks waiting = {}, "
"tasks scheduled = {}, "
"tasks queued = {}, retries left = {}"
.format(len(task_entries), num_tasks_waiting,
num_tasks_scheduled, num_tasks_done,
num_retries))
if all([status == state.TASK_STATUS_QUEUED for status in
task_statuses]):
"tasks queued = {}, retries left = {}".format(
len(task_entries), num_tasks_waiting,
num_tasks_scheduled, num_tasks_done, num_retries))
if all([
status == state.TASK_STATUS_QUEUED
for status in task_statuses
]):
# We're done, so pass.
break
num_retries -= 1
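Two related patterns show up in this test file: an overlong assertTrue gets its list argument split across indented lines, and a dict literal passed as a trailing argument is broken open with one key-value pair per line even though the short inline form ({"CPU": 1, "GPU": 2}) would fit on its own. A minimal sketch with a hypothetical stand-in for local_scheduler.Task:

def make_task(*args):
    """Hypothetical stand-in for local_scheduler.Task."""
    return args

# The trailing dict argument is exploded, one pair per line, with the
# closing brace aligned to the call's hanging indent.
task = make_task(
    "driver-id", "function-id", [], 0, "task-id", 0, {
        "CPU": 1,
        "GPU": 2
    })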

View file

@ -7,6 +7,8 @@ from ray.core.src.local_scheduler.liblocal_scheduler_library import (
task_to_string, _config, common_error)
from .local_scheduler_services import start_local_scheduler
__all__ = ["Task", "LocalSchedulerClient", "ObjectID", "check_simple_value",
"task_from_string", "task_to_string", "start_local_scheduler",
"_config", "common_error"]
__all__ = [
"Task", "LocalSchedulerClient", "ObjectID", "check_simple_value",
"task_from_string", "task_to_string", "start_local_scheduler", "_config",
"common_error"
]
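For module export lists the same rule produces a block form: as many names as fit per line, with the closing bracket dedented back to the indent of the assignment. A sketch with hypothetical names:

# Hypothetical exports; the real list above names the local scheduler
# bindings.
__all__ = [
    "first_export", "second_export", "third_export", "fourth_export",
    "fifth_export"
]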

View file

@ -68,15 +68,15 @@ def start_local_scheduler(plasma_store_name,
"provided.")
if use_valgrind and use_profiler:
raise Exception("Cannot use valgrind and profiler at the same time.")
local_scheduler_executable = os.path.join(os.path.dirname(
os.path.abspath(__file__)),
local_scheduler_executable = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"../core/src/local_scheduler/local_scheduler")
local_scheduler_name = "/tmp/scheduler{}".format(random_name())
command = [local_scheduler_executable,
"-s", local_scheduler_name,
"-p", plasma_store_name,
"-h", node_ip_address,
"-n", str(num_workers)]
command = [
local_scheduler_executable, "-s", local_scheduler_name, "-p",
plasma_store_name, "-h", node_ip_address, "-n",
str(num_workers)
]
if plasma_manager_name is not None:
command += ["-m", plasma_manager_name]
if worker_path is not None:
@ -88,14 +88,11 @@ def start_local_scheduler(plasma_store_name,
"--object-store-name={} "
"--object-store-manager-name={} "
"--local-scheduler-name={} "
"--redis-address={}"
.format(sys.executable,
worker_path,
node_ip_address,
plasma_store_name,
plasma_manager_name,
local_scheduler_name,
redis_address))
"--redis-address={}".format(
sys.executable, worker_path,
node_ip_address, plasma_store_name,
plasma_manager_name, local_scheduler_name,
redis_address))
command += ["-w", start_worker_command]
if redis_address is not None:
command += ["-r", redis_address]
@ -104,27 +101,31 @@ def start_local_scheduler(plasma_store_name,
if static_resources is not None:
resource_argument = ""
for resource_name, resource_quantity in static_resources.items():
assert (isinstance(resource_quantity, int) or
isinstance(resource_quantity, float))
resource_argument = ",".join(
[resource_name + "," + str(resource_quantity)
for resource_name, resource_quantity in static_resources.items()])
assert (isinstance(resource_quantity, int)
or isinstance(resource_quantity, float))
resource_argument = ",".join([
resource_name + "," + str(resource_quantity)
for resource_name, resource_quantity in static_resources.items()
])
else:
resource_argument = "CPU,{}".format(psutil.cpu_count())
command += ["-c", resource_argument]
if use_valgrind:
pid = subprocess.Popen(["valgrind",
"--track-origins=yes",
"--leak-check=full",
"--show-leak-kinds=all",
"--leak-check-heuristics=stdstring",
"--error-exitcode=1"] + command,
stdout=stdout_file, stderr=stderr_file)
pid = subprocess.Popen(
[
"valgrind", "--track-origins=yes", "--leak-check=full",
"--show-leak-kinds=all", "--leak-check-heuristics=stdstring",
"--error-exitcode=1"
] + command,
stdout=stdout_file,
stderr=stderr_file)
time.sleep(1.0)
elif use_profiler:
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
stdout=stdout_file, stderr=stderr_file)
pid = subprocess.Popen(
["valgrind", "--tool=callgrind"] + command,
stdout=stdout_file,
stderr=stderr_file)
time.sleep(1.0)
else:
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)

View file

@ -37,7 +37,6 @@ def random_function_id():
class TestLocalSchedulerClient(unittest.TestCase):
def setUp(self):
# Start Plasma store.
plasma_store_name, self.p1 = plasma.start_plasma_store()
@ -74,34 +73,17 @@ class TestLocalSchedulerClient(unittest.TestCase):
self.plasma_client.create(pa.plasma.ObjectID(object_id.id()), 0)
self.plasma_client.seal(pa.plasma.ObjectID(object_id.id()))
# Define some arguments to use for the tasks.
args_list = [
[],
[{}],
[()],
1 * [1],
10 * [1],
100 * [1],
1000 * [1],
1 * ["a"],
10 * ["a"],
100 * ["a"],
1000 * ["a"],
[1, 1.3, 1 << 100, "hi", u"hi", [1, 2]],
object_ids[:1],
object_ids[:2],
object_ids[:3],
object_ids[:4],
object_ids[:5],
object_ids[:10],
object_ids[:100],
object_ids[:256],
[1, object_ids[0]],
[object_ids[0], "a"],
[1, object_ids[0], "a"],
[object_ids[0], 1, object_ids[1], "a"],
object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
object_ids + 100 * ["a"] + object_ids
]
args_list = [[], [{}], [()], 1 * [1], 10 * [1], 100 * [1], 1000 * [1],
1 * ["a"], 10 * ["a"], 100 * ["a"], 1000 * ["a"], [
1, 1.3, 1 << 100, "hi", u"hi", [1, 2]
], object_ids[:1], object_ids[:2], object_ids[:3],
object_ids[:4], object_ids[:5], object_ids[:10],
object_ids[:100], object_ids[:256], [1, object_ids[0]], [
object_ids[0], "a"
], [1, object_ids[0], "a"], [
object_ids[0], 1, object_ids[1], "a"
], object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
object_ids + 100 * ["a"] + object_ids]
for args in args_list:
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
@ -146,6 +128,7 @@ class TestLocalSchedulerClient(unittest.TestCase):
# Launch a thread to get the task.
def get_task():
self.local_scheduler_client.get_task()
t = threading.Thread(target=get_task)
t.start()
# Sleep to give the thread time to call get_task.
@ -170,6 +153,7 @@ class TestLocalSchedulerClient(unittest.TestCase):
# Launch a thread to get the task.
def get_task():
self.local_scheduler_client.get_task()
t = threading.Thread(target=get_task)
t.start()
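The args_list hunk above is the least readable output in this commit: yapf bin-packs nested list literals up to the column limit, interleaving short and long elements. Judging from the plasma manager command list later in this diff (which ends in a trailing comma after redis_address), leaving a trailing comma after the final element appears to make yapf keep one element per line instead, which is a common way to opt a literal out of the packing. A sketch of the two layouts:

# Without a trailing comma, elements are packed to the column limit:
packed = [[], [{}], [()], 1 * [1], 10 * [1], 100 * [1], 1 * ["a"]]

# With a trailing comma after the last element, each element seems to
# stay on its own line:
exploded = [
    [],
    [{}],
    [()],
    1 * [1],
]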

View file

@ -26,11 +26,12 @@ class LogMonitor(object):
log_file_handles: A dictionary mapping the name of a log file to a file
handle for that file.
"""
def __init__(self, redis_ip_address, redis_port, node_ip_address):
"""Initialize the log monitor object."""
self.node_ip_address = node_ip_address
self.redis_client = redis.StrictRedis(host=redis_ip_address,
port=redis_port)
self.redis_client = redis.StrictRedis(
host=redis_ip_address, port=redis_port)
self.log_files = {}
self.log_file_handles = {}
self.files_to_ignore = set()
@ -38,9 +39,8 @@ class LogMonitor(object):
def update_log_filenames(self):
"""Get the most up-to-date list of log files to monitor from Redis."""
num_current_log_files = len(self.log_files)
new_log_filenames = self.redis_client.lrange(
"LOG_FILENAMES:{}".format(self.node_ip_address),
num_current_log_files, -1)
new_log_filenames = self.redis_client.lrange("LOG_FILENAMES:{}".format(
self.node_ip_address), num_current_log_files, -1)
for log_filename in new_log_filenames:
print("Beginning to track file {}".format(log_filename))
assert log_filename not in self.log_files
@ -78,8 +78,8 @@ class LogMonitor(object):
# Try to open this file for the first time.
else:
try:
self.log_file_handles[log_filename] = open(log_filename,
"r")
self.log_file_handles[log_filename] = open(
log_filename, "r")
except IOError as e:
if e.errno == os.errno.EMFILE:
print("Warning: Ignoring {} because there are too "
@ -106,13 +106,20 @@ class LogMonitor(object):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=("Parse Redis server for the "
"log monitor to connect "
"to."))
parser.add_argument("--redis-address", required=True, type=str,
help="The address to use for Redis.")
parser.add_argument("--node-ip-address", required=True, type=str,
help="The IP address of the node this process is on.")
parser = argparse.ArgumentParser(
description=("Parse Redis server for the "
"log monitor to connect "
"to."))
parser.add_argument(
"--redis-address",
required=True,
type=str,
help="The address to use for Redis.")
parser.add_argument(
"--node-ip-address",
required=True,
type=str,
help="The IP address of the node this process is on.")
args = parser.parse_args()
redis_ip_address = get_ip_address(args.redis_address)
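When a call's keyword arguments no longer fit together, yapf gives each one its own line, as in the add_argument hunks above. A minimal sketch with a hypothetical option:

import argparse

parser = argparse.ArgumentParser(description="layout sketch")
parser.add_argument(
    "--example-address",  # hypothetical flag, not a real Ray option
    required=False,
    type=str,
    help="illustrates the one-keyword-per-line wrapping")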

View file

@ -100,8 +100,8 @@ class Monitor(object):
self.local_scheduler_id_to_ip_map = dict()
self.load_metrics = LoadMetrics()
if autoscaling_config:
self.autoscaler = StandardAutoscaler(
autoscaling_config, self.load_metrics)
self.autoscaler = StandardAutoscaler(autoscaling_config,
self.load_metrics)
else:
self.autoscaler = None
@ -160,11 +160,9 @@ class Monitor(object):
# task as lost.
key = binary_to_object_id(hex_to_binary(task_id))
ok = self.state._execute_command(
key, "RAY.TASK_TABLE_UPDATE",
hex_to_binary(task_id),
key, "RAY.TASK_TABLE_UPDATE", hex_to_binary(task_id),
ray.experimental.state.TASK_STATUS_LOST, NIL_ID,
task["ExecutionDependenciesString"],
task["SpillbackCount"])
task["ExecutionDependenciesString"], task["SpillbackCount"])
if ok != b"OK":
log.warn("Failed to update lost task for dead scheduler.")
num_tasks_updated += 1
@ -428,8 +426,8 @@ class Monitor(object):
"""
message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0)
driver_id = message.DriverId()
log.info(
"Driver {} has been removed.".format(binary_to_hex(driver_id)))
log.info("Driver {} has been removed.".format(
binary_to_hex(driver_id)))
self._clean_up_entries_for_driver(driver_id)
@ -571,8 +569,9 @@ class Monitor(object):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=("Parse Redis server for the "
"monitor to connect to."))
parser = argparse.ArgumentParser(
description=("Parse Redis server for the "
"monitor to connect to."))
parser.add_argument(
"--redis-address",
required=True,

View file

@ -5,5 +5,6 @@ from __future__ import print_function
from ray.plasma.plasma import (start_plasma_store, start_plasma_manager,
DEFAULT_PLASMA_STORE_MEMORY)
__all__ = ["start_plasma_store", "start_plasma_manager",
"DEFAULT_PLASMA_STORE_MEMORY"]
__all__ = [
"start_plasma_store", "start_plasma_manager", "DEFAULT_PLASMA_STORE_MEMORY"
]

View file

@ -8,13 +8,13 @@ import subprocess
import sys
import time
__all__ = [
"start_plasma_store", "start_plasma_manager", "DEFAULT_PLASMA_STORE_MEMORY"
]
__all__ = ["start_plasma_store", "start_plasma_manager",
"DEFAULT_PLASMA_STORE_MEMORY"]
PLASMA_WAIT_TIMEOUT = 2**30
PLASMA_WAIT_TIMEOUT = 2 ** 30
DEFAULT_PLASMA_STORE_MEMORY = 10 ** 9
DEFAULT_PLASMA_STORE_MEMORY = 10**9
def random_name():
@ -22,9 +22,12 @@ def random_name():
def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
use_valgrind=False, use_profiler=False,
stdout_file=None, stderr_file=None,
plasma_directory=None, huge_pages=False):
use_valgrind=False,
use_profiler=False,
stdout_file=None,
stderr_file=None,
plasma_directory=None,
huge_pages=False):
"""Start a plasma store process.
Args:
@ -48,8 +51,8 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
if use_valgrind and use_profiler:
raise Exception("Cannot use valgrind and profiler at the same time.")
if huge_pages and not (sys.platform == "linux" or
sys.platform == "linux2"):
if huge_pages and not (sys.platform == "linux"
or sys.platform == "linux2"):
raise Exception("The huge_pages argument is only supported on "
"Linux.")
@ -57,29 +60,33 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
raise Exception("If huge_pages is True, then the "
"plasma_directory argument must be provided.")
plasma_store_executable = os.path.join(os.path.abspath(
os.path.dirname(__file__)),
plasma_store_executable = os.path.join(
os.path.abspath(os.path.dirname(__file__)),
"../core/src/plasma/plasma_store")
plasma_store_name = "/tmp/plasma_store{}".format(random_name())
command = [plasma_store_executable,
"-s", plasma_store_name,
"-m", str(plasma_store_memory)]
command = [
plasma_store_executable, "-s", plasma_store_name, "-m",
str(plasma_store_memory)
]
if plasma_directory is not None:
command += ["-d", plasma_directory]
if huge_pages:
command += ["-h"]
if use_valgrind:
pid = subprocess.Popen(["valgrind",
"--track-origins=yes",
"--leak-check=full",
"--show-leak-kinds=all",
"--leak-check-heuristics=stdstring",
"--error-exitcode=1"] + command,
stdout=stdout_file, stderr=stderr_file)
pid = subprocess.Popen(
[
"valgrind", "--track-origins=yes", "--leak-check=full",
"--show-leak-kinds=all", "--leak-check-heuristics=stdstring",
"--error-exitcode=1"
] + command,
stdout=stdout_file,
stderr=stderr_file)
time.sleep(1.0)
elif use_profiler:
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
stdout=stdout_file, stderr=stderr_file)
pid = subprocess.Popen(
["valgrind", "--tool=callgrind"] + command,
stdout=stdout_file,
stderr=stderr_file)
time.sleep(1.0)
else:
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
@ -91,10 +98,14 @@ def new_port():
return random.randint(10000, 65535)
def start_plasma_manager(store_name, redis_address,
node_ip_address="127.0.0.1", plasma_manager_port=None,
num_retries=20, use_valgrind=False,
run_profiler=False, stdout_file=None,
def start_plasma_manager(store_name,
redis_address,
node_ip_address="127.0.0.1",
plasma_manager_port=None,
num_retries=20,
use_valgrind=False,
run_profiler=False,
stdout_file=None,
stderr_file=None):
"""Start a plasma manager and return the ports it listens on.
@ -133,27 +144,35 @@ def start_plasma_manager(store_name, redis_address,
while counter < num_retries:
if counter > 0:
print("Plasma manager failed to start, retrying now.")
command = [plasma_manager_executable,
"-s", store_name,
"-m", plasma_manager_name,
"-h", node_ip_address,
"-p", str(plasma_manager_port),
"-r", redis_address,
]
command = [
plasma_manager_executable,
"-s",
store_name,
"-m",
plasma_manager_name,
"-h",
node_ip_address,
"-p",
str(plasma_manager_port),
"-r",
redis_address,
]
if use_valgrind:
process = subprocess.Popen(["valgrind",
"--track-origins=yes",
"--leak-check=full",
"--show-leak-kinds=all",
"--error-exitcode=1"] + command,
stdout=stdout_file, stderr=stderr_file)
process = subprocess.Popen(
[
"valgrind", "--track-origins=yes", "--leak-check=full",
"--show-leak-kinds=all", "--error-exitcode=1"
] + command,
stdout=stdout_file,
stderr=stderr_file)
elif run_profiler:
process = subprocess.Popen((["valgrind", "--tool=callgrind"] +
command),
stdout=stdout_file, stderr=stderr_file)
process = subprocess.Popen(
(["valgrind", "--tool=callgrind"] + command),
stdout=stdout_file,
stderr=stderr_file)
else:
process = subprocess.Popen(command, stdout=stdout_file,
stderr=stderr_file)
process = subprocess.Popen(
command, stdout=stdout_file, stderr=stderr_file)
# This sleep is critical. If the plasma_manager fails to start because
# the port is already in use, then we need it to fail within 0.1
# seconds.
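Function definitions get a different treatment from calls: when a signature with defaults overflows, each parameter lands on its own line aligned under the first parameter, as in start_plasma_manager above. A hypothetical signature in the same shape:

def start_service(store_name,
                  redis_address,
                  node_ip_address="127.0.0.1",
                  port=None,
                  num_retries=20,
                  use_valgrind=False):
    """Hypothetical stand-in; simply echoes its arguments."""
    return (store_name, redis_address, node_ip_address, port, num_retries,
            use_valgrind)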

View file

@ -16,8 +16,8 @@ import unittest
# The ray import must come before the pyarrow import because ray modifies the
# python path so that the right version of pyarrow is found.
import ray
from ray.plasma.utils import (random_object_id,
create_object_with_id, create_object)
from ray.plasma.utils import (random_object_id, create_object_with_id,
create_object)
from ray import services
import pyarrow as pa
import pyarrow.plasma as plasma
@ -30,8 +30,12 @@ def random_name():
return str(random.randint(0, 99999999))
def assert_get_object_equal(unit_test, client1, client2, object_id,
memory_buffer=None, metadata=None):
def assert_get_object_equal(unit_test,
client1,
client2,
object_id,
memory_buffer=None,
metadata=None):
client1_buff = client1.get_buffers([object_id])[0]
client2_buff = client2.get_buffers([object_id])[0]
client1_metadata = client1.get_metadata([object_id])[0]
@ -39,27 +43,33 @@ def assert_get_object_equal(unit_test, client1, client2, object_id,
unit_test.assertEqual(len(client1_buff), len(client2_buff))
unit_test.assertEqual(len(client1_metadata), len(client2_metadata))
# Check that the buffers from the two clients are the same.
assert_equal(np.frombuffer(client1_buff, dtype="uint8"),
np.frombuffer(client2_buff, dtype="uint8"))
assert_equal(
np.frombuffer(client1_buff, dtype="uint8"),
np.frombuffer(client2_buff, dtype="uint8"))
# Check that the metadata buffers from the two clients are the same.
assert_equal(np.frombuffer(client1_metadata, dtype="uint8"),
np.frombuffer(client2_metadata, dtype="uint8"))
assert_equal(
np.frombuffer(client1_metadata, dtype="uint8"),
np.frombuffer(client2_metadata, dtype="uint8"))
# If a reference buffer was provided, check that it is the same as well.
if memory_buffer is not None:
assert_equal(np.frombuffer(memory_buffer, dtype="uint8"),
np.frombuffer(client1_buff, dtype="uint8"))
assert_equal(
np.frombuffer(memory_buffer, dtype="uint8"),
np.frombuffer(client1_buff, dtype="uint8"))
# If reference metadata was provided, check that it is the same as well.
if metadata is not None:
assert_equal(np.frombuffer(metadata, dtype="uint8"),
np.frombuffer(client1_metadata, dtype="uint8"))
assert_equal(
np.frombuffer(metadata, dtype="uint8"),
np.frombuffer(client1_metadata, dtype="uint8"))
DEFAULT_PLASMA_STORE_MEMORY = 10 ** 9
DEFAULT_PLASMA_STORE_MEMORY = 10**9
def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
use_valgrind=False, use_profiler=False,
stdout_file=None, stderr_file=None):
use_valgrind=False,
use_profiler=False,
stdout_file=None,
stderr_file=None):
"""Start a plasma store process.
Args:
use_valgrind (bool): True if the plasma store should be started inside
@ -78,21 +88,25 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
raise Exception("Cannot use valgrind and profiler at the same time.")
plasma_store_executable = os.path.join(pa.__path__[0], "plasma_store")
plasma_store_name = "/tmp/plasma_store{}".format(random_name())
command = [plasma_store_executable,
"-s", plasma_store_name,
"-m", str(plasma_store_memory)]
command = [
plasma_store_executable, "-s", plasma_store_name, "-m",
str(plasma_store_memory)
]
if use_valgrind:
pid = subprocess.Popen(["valgrind",
"--track-origins=yes",
"--leak-check=full",
"--show-leak-kinds=all",
"--leak-check-heuristics=stdstring",
"--error-exitcode=1"] + command,
stdout=stdout_file, stderr=stderr_file)
pid = subprocess.Popen(
[
"valgrind", "--track-origins=yes", "--leak-check=full",
"--show-leak-kinds=all", "--leak-check-heuristics=stdstring",
"--error-exitcode=1"
] + command,
stdout=stdout_file,
stderr=stderr_file)
time.sleep(1.0)
elif use_profiler:
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
stdout=stdout_file, stderr=stderr_file)
pid = subprocess.Popen(
["valgrind", "--tool=callgrind"] + command,
stdout=stdout_file,
stderr=stderr_file)
time.sleep(1.0)
else:
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
@ -104,13 +118,10 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
class TestPlasmaManager(unittest.TestCase):
def setUp(self):
# Start two PlasmaStores.
store_name1, self.p2 = start_plasma_store(
use_valgrind=USE_VALGRIND)
store_name2, self.p3 = start_plasma_store(
use_valgrind=USE_VALGRIND)
store_name1, self.p2 = start_plasma_store(use_valgrind=USE_VALGRIND)
store_name2, self.p3 = start_plasma_store(use_valgrind=USE_VALGRIND)
# Start a Redis server.
redis_address, _ = services.start_redis("127.0.0.1")
# Start two PlasmaManagers.
@ -152,9 +163,8 @@ class TestPlasmaManager(unittest.TestCase):
def test_fetch(self):
for _ in range(10):
# Create an object.
object_id1, memory_buffer1, metadata1 = create_object(self.client1,
2000,
2000)
object_id1, memory_buffer1, metadata1 = create_object(
self.client1, 2000, 2000)
self.client1.fetch([object_id1])
self.assertEqual(self.client1.contains(object_id1), True)
self.assertEqual(self.client2.contains(object_id1), False)
@ -164,18 +174,20 @@ class TestPlasmaManager(unittest.TestCase):
while not self.client2.contains(object_id1):
self.client2.fetch([object_id1])
# Compare the two buffers.
assert_get_object_equal(self, self.client1, self.client2,
object_id1,
memory_buffer=memory_buffer1,
metadata=metadata1)
assert_get_object_equal(
self,
self.client1,
self.client2,
object_id1,
memory_buffer=memory_buffer1,
metadata=metadata1)
# Test that we can call fetch on object IDs that don't exist yet.
object_id2 = random_object_id()
self.client1.fetch([object_id2])
self.assertEqual(self.client1.contains(object_id2), False)
memory_buffer2, metadata2 = create_object_with_id(self.client2,
object_id2,
2000, 2000)
memory_buffer2, metadata2 = create_object_with_id(
self.client2, object_id2, 2000, 2000)
# # Check that the object has been fetched.
# self.assertEqual(self.client1.contains(object_id2), True)
# Compare the two buffers.
@ -190,68 +202,88 @@ class TestPlasmaManager(unittest.TestCase):
for _ in range(10):
self.client1.fetch([object_id3])
self.client2.fetch([object_id3])
memory_buffer3, metadata3 = create_object_with_id(self.client1,
object_id3,
2000, 2000)
memory_buffer3, metadata3 = create_object_with_id(
self.client1, object_id3, 2000, 2000)
for _ in range(10):
self.client1.fetch([object_id3])
self.client2.fetch([object_id3])
# TODO(rkn): Right now we must wait for the object table to be updated.
while not self.client2.contains(object_id3):
self.client2.fetch([object_id3])
assert_get_object_equal(self, self.client1, self.client2, object_id3,
memory_buffer=memory_buffer3,
metadata=metadata3)
assert_get_object_equal(
self,
self.client1,
self.client2,
object_id3,
memory_buffer=memory_buffer3,
metadata=metadata3)
def test_fetch_multiple(self):
for _ in range(20):
# Create two objects and a third fake one that doesn't exist.
object_id1, memory_buffer1, metadata1 = create_object(self.client1,
2000,
2000)
object_id1, memory_buffer1, metadata1 = create_object(
self.client1, 2000, 2000)
missing_object_id = random_object_id()
object_id2, memory_buffer2, metadata2 = create_object(self.client1,
2000,
2000)
object_id2, memory_buffer2, metadata2 = create_object(
self.client1, 2000, 2000)
object_ids = [object_id1, missing_object_id, object_id2]
# Fetch the objects from the other plasma store. The second object
# ID should timeout since it does not exist.
# TODO(rkn): Right now we must wait for the object table to be
# updated.
while ((not self.client2.contains(object_id1)) or
(not self.client2.contains(object_id2))):
while ((not self.client2.contains(object_id1))
or (not self.client2.contains(object_id2))):
self.client2.fetch(object_ids)
# Compare the buffers of the objects that do exist.
assert_get_object_equal(self, self.client1, self.client2,
object_id1, memory_buffer=memory_buffer1,
metadata=metadata1)
assert_get_object_equal(self, self.client1, self.client2,
object_id2, memory_buffer=memory_buffer2,
metadata=metadata2)
assert_get_object_equal(
self,
self.client1,
self.client2,
object_id1,
memory_buffer=memory_buffer1,
metadata=metadata1)
assert_get_object_equal(
self,
self.client1,
self.client2,
object_id2,
memory_buffer=memory_buffer2,
metadata=metadata2)
# Fetch in the other direction. The fake object still does not
# exist.
self.client1.fetch(object_ids)
assert_get_object_equal(self, self.client2, self.client1,
object_id1, memory_buffer=memory_buffer1,
metadata=metadata1)
assert_get_object_equal(self, self.client2, self.client1,
object_id2, memory_buffer=memory_buffer2,
metadata=metadata2)
assert_get_object_equal(
self,
self.client2,
self.client1,
object_id1,
memory_buffer=memory_buffer1,
metadata=metadata1)
assert_get_object_equal(
self,
self.client2,
self.client1,
object_id2,
memory_buffer=memory_buffer2,
metadata=metadata2)
# Check that we can call fetch with duplicated object IDs.
object_id3 = random_object_id()
self.client1.fetch([object_id3, object_id3])
object_id4, memory_buffer4, metadata4 = create_object(self.client1,
2000,
2000)
object_id4, memory_buffer4, metadata4 = create_object(
self.client1, 2000, 2000)
time.sleep(0.1)
# TODO(rkn): Right now we must wait for the object table to be updated.
while not self.client2.contains(object_id4):
self.client2.fetch([object_id3, object_id3, object_id4,
object_id4])
assert_get_object_equal(self, self.client2, self.client1, object_id4,
memory_buffer=memory_buffer4,
metadata=metadata4)
self.client2.fetch(
[object_id3, object_id3, object_id4, object_id4])
assert_get_object_equal(
self,
self.client2,
self.client1,
object_id4,
memory_buffer=memory_buffer4,
metadata=metadata4)
def test_wait(self):
# Test timeout.
@ -263,8 +295,8 @@ class TestPlasmaManager(unittest.TestCase):
obj_id1 = random_object_id()
self.client1.create(obj_id1, 1000)
self.client1.seal(obj_id1)
ready, waiting = self.client1.wait([obj_id1], timeout=100,
num_returns=1)
ready, waiting = self.client1.wait(
[obj_id1], timeout=100, num_returns=1)
self.assertEqual(set(ready), set([obj_id1]))
self.assertEqual(waiting, [])
@ -273,8 +305,8 @@ class TestPlasmaManager(unittest.TestCase):
obj_id2 = random_object_id()
self.client1.create(obj_id2, 1000)
# Don't seal.
ready, waiting = self.client1.wait([obj_id2, obj_id1], timeout=100,
num_returns=1)
ready, waiting = self.client1.wait(
[obj_id2, obj_id1], timeout=100, num_returns=1)
self.assertEqual(set(ready), set([obj_id1]))
self.assertEqual(set(waiting), set([obj_id2]))
@ -287,8 +319,8 @@ class TestPlasmaManager(unittest.TestCase):
t = threading.Timer(0.1, finish)
t.start()
ready, waiting = self.client1.wait([obj_id3, obj_id2, obj_id1],
timeout=1000, num_returns=2)
ready, waiting = self.client1.wait(
[obj_id3, obj_id2, obj_id1], timeout=1000, num_returns=2)
self.assertEqual(set(ready), set([obj_id1, obj_id3]))
self.assertEqual(set(waiting), set([obj_id2]))
@ -319,26 +351,26 @@ class TestPlasmaManager(unittest.TestCase):
waiting = object_ids
retrieved = []
for i in range(1, n + 1):
ready, waiting = self.client1.wait(waiting, timeout=1000,
num_returns=i)
ready, waiting = self.client1.wait(
waiting, timeout=1000, num_returns=i)
self.assertEqual(len(ready), i)
retrieved += ready
self.assertEqual(set(retrieved), set(object_ids))
ready, waiting = self.client1.wait(object_ids, timeout=1000,
num_returns=len(object_ids))
ready, waiting = self.client1.wait(
object_ids, timeout=1000, num_returns=len(object_ids))
self.assertEqual(set(ready), set(object_ids))
self.assertEqual(waiting, [])
# Try waiting for all of the object IDs on the second client.
waiting = object_ids
retrieved = []
for i in range(1, n + 1):
ready, waiting = self.client2.wait(waiting, timeout=1000,
num_returns=i)
ready, waiting = self.client2.wait(
waiting, timeout=1000, num_returns=i)
self.assertEqual(len(ready), i)
retrieved += ready
self.assertEqual(set(retrieved), set(object_ids))
ready, waiting = self.client2.wait(object_ids, timeout=1000,
num_returns=len(object_ids))
ready, waiting = self.client2.wait(
object_ids, timeout=1000, num_returns=len(object_ids))
self.assertEqual(set(ready), set(object_ids))
self.assertEqual(waiting, [])
@ -363,9 +395,8 @@ class TestPlasmaManager(unittest.TestCase):
num_attempts = 100
for _ in range(100):
# Create an object.
object_id1, memory_buffer1, metadata1 = create_object(self.client1,
2000,
2000)
object_id1, memory_buffer1, metadata1 = create_object(
self.client1, 2000, 2000)
# Transfer the buffer to the other Plasma store. There is a
# race condition on the create and transfer of the object, so keep
# trying until the object appears on the second Plasma store.
@ -379,9 +410,13 @@ class TestPlasmaManager(unittest.TestCase):
del buff
# Compare the two buffers.
assert_get_object_equal(self, self.client1, self.client2,
object_id1, memory_buffer=memory_buffer1,
metadata=metadata1)
assert_get_object_equal(
self,
self.client1,
self.client2,
object_id1,
memory_buffer=memory_buffer1,
metadata=metadata1)
# # Transfer the buffer again.
# self.client1.transfer("127.0.0.1", self.port2, object_id1)
# # Compare the two buffers.
@ -391,8 +426,8 @@ class TestPlasmaManager(unittest.TestCase):
# metadata=metadata1)
# Create an object.
object_id2, memory_buffer2, metadata2 = create_object(self.client2,
20000, 20000)
object_id2, memory_buffer2, metadata2 = create_object(
self.client2, 20000, 20000)
# Transfer the buffer to the other Plasma store. There is a
# race condition on the create and transfer of the object, so keep
# trying until the object appears on the second Plasma store.
@ -406,9 +441,13 @@ class TestPlasmaManager(unittest.TestCase):
del buff
# Compare the two buffers.
assert_get_object_equal(self, self.client1, self.client2,
object_id2, memory_buffer=memory_buffer2,
metadata=metadata2)
assert_get_object_equal(
self,
self.client1,
self.client2,
object_id2,
memory_buffer=memory_buffer2,
metadata=metadata2)
def test_illegal_functionality(self):
# Create an object id string.
@ -437,7 +476,6 @@ class TestPlasmaManager(unittest.TestCase):
class TestPlasmaManagerRecovery(unittest.TestCase):
def setUp(self):
# Start a Plasma store.
self.store_name, self.p2 = start_plasma_store(
@ -446,9 +484,7 @@ class TestPlasmaManagerRecovery(unittest.TestCase):
self.redis_address, _ = services.start_redis("127.0.0.1")
# Start a PlasmaManager.
manager_name, self.p3, self.port1 = ray.plasma.start_plasma_manager(
self.store_name,
self.redis_address,
use_valgrind=USE_VALGRIND)
self.store_name, self.redis_address, use_valgrind=USE_VALGRIND)
# Connect a PlasmaClient.
self.client = plasma.connect(self.store_name, manager_name, 64)
@ -501,8 +537,8 @@ class TestPlasmaManagerRecovery(unittest.TestCase):
client2 = plasma.connect(self.store_name, manager_name, 64)
ready, waiting = [], object_ids
while True:
ready, waiting = client2.wait(object_ids, num_returns=num_objects,
timeout=0)
ready, waiting = client2.wait(
object_ids, num_returns=num_objects, timeout=0)
if len(ready) == len(object_ids):
break

View file

@ -18,8 +18,8 @@ def generate_metadata(length):
metadata_buffer[0] = random.randint(0, 255)
metadata_buffer[-1] = random.randint(0, 255)
for _ in range(100):
metadata_buffer[random.randint(0, length - 1)] = (
random.randint(0, 255))
metadata_buffer[random.randint(0, length - 1)] = (random.randint(
0, 255))
return metadata_buffer
@ -32,7 +32,10 @@ def write_to_data_buffer(buff, length):
array[random.randint(0, length - 1)] = random.randint(0, 255)
def create_object_with_id(client, object_id, data_size, metadata_size,
def create_object_with_id(client,
object_id,
data_size,
metadata_size,
seal=True):
metadata = generate_metadata(metadata_size)
memory_buffer = client.create(object_id, data_size, metadata)
@ -44,7 +47,6 @@ def create_object_with_id(client, object_id, data_size, metadata_size,
def create_object(client, data_size, metadata_size, seal=True):
object_id = random_object_id()
memory_buffer, metadata = create_object_with_id(client, object_id,
data_size, metadata_size,
seal=seal)
memory_buffer, metadata = create_object_with_id(
client, object_id, data_size, metadata_size, seal=seal)
return object_id, memory_buffer, metadata

View file

@ -1,10 +1,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Ray constants used in the Python code."""
# Abort autoscaling if more than this number of errors are encountered. This
# is a safety feature to prevent e.g. runaway node launches.
AUTOSCALER_MAX_NUM_FAILURES = 5

View file

@ -7,8 +7,8 @@ import json
import subprocess
import ray.services as services
from ray.autoscaler.commands import (
create_or_update_cluster, teardown_cluster, get_head_node_ip)
from ray.autoscaler.commands import (create_or_update_cluster,
teardown_cluster, get_head_node_ip)
def check_no_existing_redis_clients(node_ip_address, redis_client):
@ -41,58 +41,116 @@ def cli():
@click.command()
@click.option("--node-ip-address", required=False, type=str,
help="the IP address of this node")
@click.option("--redis-address", required=False, type=str,
help="the address to use for connecting to Redis")
@click.option("--redis-port", required=False, type=str,
help="the port to use for starting Redis")
@click.option("--num-redis-shards", required=False, type=int,
help=("the number of additional Redis shards to use in "
"addition to the primary Redis shard"))
@click.option("--redis-max-clients", required=False, type=int,
help=("If provided, attempt to configure Redis with this "
"maximum number of clients."))
@click.option("--redis-shard-ports", required=False, type=str,
help="the port to use for the Redis shards other than the "
"primary Redis shard")
@click.option("--object-manager-port", required=False, type=int,
help="the port to use for starting the object manager")
@click.option("--object-store-memory", required=False, type=int,
help="the maximum amount of memory (in bytes) to allow the "
"object store to use")
@click.option("--num-workers", required=False, type=int,
help=("The initial number of workers to start on this node, "
"note that the local scheduler may start additional "
"workers. If you wish to control the total number of "
"concurent tasks, then use --resources instead and "
"specify the CPU field."))
@click.option("--num-cpus", required=False, type=int,
help="the number of CPUs on this node")
@click.option("--num-gpus", required=False, type=int,
help="the number of GPUs on this node")
@click.option("--resources", required=False, default="{}", type=str,
help="a JSON serialized dictionary mapping resource name to "
"resource quantity")
@click.option("--head", is_flag=True, default=False,
help="provide this argument for the head node")
@click.option("--no-ui", is_flag=True, default=False,
help="provide this argument if the UI should not be started")
@click.option("--block", is_flag=True, default=False,
help="provide this argument to block forever in this command")
@click.option("--plasma-directory", required=False, type=str,
help="object store directory for memory mapped files")
@click.option("--huge-pages", is_flag=True, default=False,
help="enable support for huge pages in the object store")
@click.option("--autoscaling-config", required=False, type=str,
help="the file that contains the autoscaling config")
@click.option("--use-raylet", is_flag=True, default=False,
help="use the raylet code path, this is not supported yet")
@click.option(
"--node-ip-address",
required=False,
type=str,
help="the IP address of this node")
@click.option(
"--redis-address",
required=False,
type=str,
help="the address to use for connecting to Redis")
@click.option(
"--redis-port",
required=False,
type=str,
help="the port to use for starting Redis")
@click.option(
"--num-redis-shards",
required=False,
type=int,
help=("the number of additional Redis shards to use in "
"addition to the primary Redis shard"))
@click.option(
"--redis-max-clients",
required=False,
type=int,
help=("If provided, attempt to configure Redis with this "
"maximum number of clients."))
@click.option(
"--redis-shard-ports",
required=False,
type=str,
help="the port to use for the Redis shards other than the "
"primary Redis shard")
@click.option(
"--object-manager-port",
required=False,
type=int,
help="the port to use for starting the object manager")
@click.option(
"--object-store-memory",
required=False,
type=int,
help="the maximum amount of memory (in bytes) to allow the "
"object store to use")
@click.option(
"--num-workers",
required=False,
type=int,
help=("The initial number of workers to start on this node, "
"note that the local scheduler may start additional "
"workers. If you wish to control the total number of "
"concurent tasks, then use --resources instead and "
"specify the CPU field."))
@click.option(
"--num-cpus",
required=False,
type=int,
help="the number of CPUs on this node")
@click.option(
"--num-gpus",
required=False,
type=int,
help="the number of GPUs on this node")
@click.option(
"--resources",
required=False,
default="{}",
type=str,
help="a JSON serialized dictionary mapping resource name to "
"resource quantity")
@click.option(
"--head",
is_flag=True,
default=False,
help="provide this argument for the head node")
@click.option(
"--no-ui",
is_flag=True,
default=False,
help="provide this argument if the UI should not be started")
@click.option(
"--block",
is_flag=True,
default=False,
help="provide this argument to block forever in this command")
@click.option(
"--plasma-directory",
required=False,
type=str,
help="object store directory for memory mapped files")
@click.option(
"--huge-pages",
is_flag=True,
default=False,
help="enable support for huge pages in the object store")
@click.option(
"--autoscaling-config",
required=False,
type=str,
help="the file that contains the autoscaling config")
@click.option(
"--use-raylet",
is_flag=True,
default=False,
help="use the raylet code path, this is not supported yet")
def start(node_ip_address, redis_address, redis_port, num_redis_shards,
redis_max_clients, redis_shard_ports, object_manager_port,
object_store_memory, num_workers, num_cpus, num_gpus, resources,
head, no_ui, block, plasma_directory, huge_pages,
autoscaling_config, use_raylet):
head, no_ui, block, plasma_directory, huge_pages, autoscaling_config,
use_raylet):
# Convert hostnames to numerical IP address.
if node_ip_address is not None:
node_ip_address = services.address_to_ip(node_ip_address)
@ -245,33 +303,54 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
@click.command()
def stop():
subprocess.call(["killall global_scheduler plasma_store plasma_manager "
"local_scheduler raylet"], shell=True)
subprocess.call(
[
"killall global_scheduler plasma_store plasma_manager "
"local_scheduler raylet"
],
shell=True)
# Find the PID of the monitor process and kill it.
subprocess.call(["kill $(ps aux | grep monitor.py | grep -v grep | "
"awk '{ print $2 }') 2> /dev/null"], shell=True)
subprocess.call(
[
"kill $(ps aux | grep monitor.py | grep -v grep | "
"awk '{ print $2 }') 2> /dev/null"
],
shell=True)
# Find the PID of the Redis process and kill it.
subprocess.call(["kill $(ps aux | grep redis-server | grep -v grep | "
"awk '{ print $2 }') 2> /dev/null"], shell=True)
subprocess.call(
[
"kill $(ps aux | grep redis-server | grep -v grep | "
"awk '{ print $2 }') 2> /dev/null"
],
shell=True)
# Find the PIDs of the worker processes and kill them.
subprocess.call(["kill -9 $(ps aux | grep default_worker.py | "
"grep -v grep | awk '{ print $2 }') 2> /dev/null"],
shell=True)
subprocess.call(
[
"kill -9 $(ps aux | grep default_worker.py | "
"grep -v grep | awk '{ print $2 }') 2> /dev/null"
],
shell=True)
# Find the PID of the Ray log monitor process and kill it.
subprocess.call(["kill $(ps aux | grep log_monitor.py | grep -v grep | "
"awk '{ print $2 }') 2> /dev/null"], shell=True)
subprocess.call(
[
"kill $(ps aux | grep log_monitor.py | grep -v grep | "
"awk '{ print $2 }') 2> /dev/null"
],
shell=True)
# Find the PID of the jupyter process and kill it.
try:
from notebook.notebookapp import list_running_servers
pids = [str(server["pid"]) for server in list_running_servers()
if "/tmp/raylogs" in server["notebook_dir"]]
subprocess.call(["kill {} 2> /dev/null".format(
" ".join(pids))], shell=True)
pids = [
str(server["pid"]) for server in list_running_servers()
if "/tmp/raylogs" in server["notebook_dir"]
]
subprocess.call(
["kill {} 2> /dev/null".format(" ".join(pids))], shell=True)
except ImportError:
pass
@ -279,29 +358,41 @@ def stop():
@click.command()
@click.argument("cluster_config_file", required=True, type=str)
@click.option(
"--no-restart", is_flag=True, default=False, help=(
"Whether to skip restarting Ray services during the update. "
"This avoids interrupting running jobs."))
"--no-restart",
is_flag=True,
default=False,
help=("Whether to skip restarting Ray services during the update. "
"This avoids interrupting running jobs."))
@click.option(
"--min-workers", required=False, type=int, help=(
"Override the configured min worker node count for the cluster."))
"--min-workers",
required=False,
type=int,
help=("Override the configured min worker node count for the cluster."))
@click.option(
"--max-workers", required=False, type=int, help=(
"Override the configured max worker node count for the cluster."))
"--max-workers",
required=False,
type=int,
help=("Override the configured max worker node count for the cluster."))
@click.option(
"--yes", "-y", is_flag=True, default=False, help=(
"Don't ask for confirmation."))
def create_or_update(
cluster_config_file, min_workers, max_workers, no_restart, yes):
create_or_update_cluster(
cluster_config_file, min_workers, max_workers, no_restart, yes)
"--yes",
"-y",
is_flag=True,
default=False,
help=("Don't ask for confirmation."))
def create_or_update(cluster_config_file, min_workers, max_workers, no_restart,
yes):
create_or_update_cluster(cluster_config_file, min_workers, max_workers,
no_restart, yes)
@click.command()
@click.argument("cluster_config_file", required=True, type=str)
@click.option(
"--yes", "-y", is_flag=True, default=False, help=(
"Don't ask for confirmation."))
"--yes",
"-y",
is_flag=True,
default=False,
help=("Don't ask for confirmation."))
def teardown(cluster_config_file, yes):
teardown_cluster(cluster_config_file, yes)
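The @click.option decorators get the same one-keyword-per-line layout as the argparse calls earlier in this diff; the decorated call is wrapped exactly like any other overlong call. A runnable sketch, assuming click is installed (it is a dependency of this CLI):

import click

@click.command()
@click.option(
    "--dry-run",  # hypothetical flag, not one of the real options above
    is_flag=True,
    default=False,
    help="illustrates the wrapped decorator layout")
def demo(dry_run):
    click.echo("dry_run={}".format(dry_run))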

View file

@ -41,16 +41,12 @@ PROCESS_TYPE_WEB_UI = "web_ui"
# important because it determines the order in which these processes will be
# terminated when Ray exits, and certain orders will cause errors to be logged
# to the screen.
all_processes = OrderedDict([(PROCESS_TYPE_MONITOR, []),
(PROCESS_TYPE_LOG_MONITOR, []),
(PROCESS_TYPE_WORKER, []),
(PROCESS_TYPE_RAYLET, []),
(PROCESS_TYPE_LOCAL_SCHEDULER, []),
(PROCESS_TYPE_PLASMA_MANAGER, []),
(PROCESS_TYPE_PLASMA_STORE, []),
(PROCESS_TYPE_GLOBAL_SCHEDULER, []),
(PROCESS_TYPE_REDIS_SERVER, []),
(PROCESS_TYPE_WEB_UI, [])],)
all_processes = OrderedDict(
[(PROCESS_TYPE_MONITOR, []), (PROCESS_TYPE_LOG_MONITOR, []),
(PROCESS_TYPE_WORKER, []), (PROCESS_TYPE_RAYLET, []),
(PROCESS_TYPE_LOCAL_SCHEDULER, []), (PROCESS_TYPE_PLASMA_MANAGER, []),
(PROCESS_TYPE_PLASMA_STORE, []), (PROCESS_TYPE_GLOBAL_SCHEDULER, []),
(PROCESS_TYPE_REDIS_SERVER, []), (PROCESS_TYPE_WEB_UI, [])], )
# True if processes are run in the valgrind profiler.
RUN_RAYLET_PROFILER = False
@ -82,17 +78,15 @@ RAYLET_MONITOR_EXECUTABLE = os.path.join(
os.path.abspath(os.path.dirname(__file__)),
"core/src/ray/raylet/raylet_monitor")
RAYLET_EXECUTABLE = os.path.join(
os.path.abspath(os.path.dirname(__file__)),
"core/src/ray/raylet/raylet")
os.path.abspath(os.path.dirname(__file__)), "core/src/ray/raylet/raylet")
# ObjectStoreAddress tuples contain all information necessary to connect to an
# object store. The fields are:
# - name: The socket name for the object store
# - manager_name: The socket name for the object store manager
# - manager_port: The Internet port that the object store manager listens on
ObjectStoreAddress = namedtuple("ObjectStoreAddress", ["name",
"manager_name",
"manager_port"])
ObjectStoreAddress = namedtuple("ObjectStoreAddress",
["name", "manager_name", "manager_port"])
def address(ip_address, port):
@ -133,8 +127,10 @@ def kill_process(p):
if p.poll() is not None:
# The process has already terminated.
return True
if any([RUN_RAYLET_PROFILER, RUN_LOCAL_SCHEDULER_PROFILER,
RUN_PLASMA_MANAGER_PROFILER, RUN_PLASMA_STORE_PROFILER]):
if any([
RUN_RAYLET_PROFILER, RUN_LOCAL_SCHEDULER_PROFILER,
RUN_PLASMA_MANAGER_PROFILER, RUN_PLASMA_STORE_PROFILER
]):
# Give process signal to write profiler data.
os.kill(p.pid, signal.SIGINT)
# Wait for profiling data to be written.
@ -260,8 +256,8 @@ def record_log_files_in_redis(redis_address, node_ip_address, log_files):
for log_file in log_files:
if log_file is not None:
redis_ip_address, redis_port = redis_address.split(":")
redis_client = redis.StrictRedis(host=redis_ip_address,
port=redis_port)
redis_client = redis.StrictRedis(
host=redis_ip_address, port=redis_port)
# The name of the key storing the list of log filenames for this IP
# address.
log_file_list_key = "LOG_FILENAMES:{}".format(node_ip_address)
@ -304,8 +300,8 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5):
while counter < num_retries:
try:
# Run some random command and see if it worked.
print("Waiting for redis server at {}:{} to respond..."
.format(redis_ip_address, redis_port))
print("Waiting for redis server at {}:{} to respond...".format(
redis_ip_address, redis_port))
redis_client.client_list()
except redis.ConnectionError as e:
# Wait a little bit.
@ -427,17 +423,19 @@ def start_credis(node_ip_address,
"""
components = ["credis_master", "credis_head", "credis_tail"]
modules = [CREDIS_MASTER_MODULE, CREDIS_MEMBER_MODULE,
CREDIS_MEMBER_MODULE]
modules = [
CREDIS_MASTER_MODULE, CREDIS_MEMBER_MODULE, CREDIS_MEMBER_MODULE
]
ports = []
for i, component in enumerate(components):
stdout_file, stderr_file = new_log_files(
component, redirect_output)
stdout_file, stderr_file = new_log_files(component, redirect_output)
new_port, _ = start_redis_instance(
node_ip_address=node_ip_address, port=port,
stdout_file=stdout_file, stderr_file=stderr_file,
node_ip_address=node_ip_address,
port=port,
stdout_file=stdout_file,
stderr_file=stderr_file,
cleanup=cleanup,
module=modules[i],
executable=CREDIS_EXECUTABLE)
@ -456,8 +454,7 @@ def start_credis(node_ip_address,
# Register credis master in redis
redis_ip_address, redis_port = redis_address.split(":")
redis_client = redis.StrictRedis(host=redis_ip_address,
port=redis_port)
redis_client = redis.StrictRedis(host=redis_ip_address, port=redis_port)
redis_client.set("credis_address", credis_address)
return credis_address
@ -509,9 +506,11 @@ def start_redis(node_ip_address,
"number of Redis shards.")
assigned_port, _ = start_redis_instance(
node_ip_address=node_ip_address, port=port,
node_ip_address=node_ip_address,
port=port,
redis_max_clients=redis_max_clients,
stdout_file=redis_stdout_file, stderr_file=redis_stderr_file,
stdout_file=redis_stdout_file,
stderr_file=redis_stderr_file,
cleanup=cleanup)
if port is not None:
assert assigned_port == port
@ -540,7 +539,8 @@ def start_redis(node_ip_address,
node_ip_address=node_ip_address,
port=redis_shard_ports[i],
redis_max_clients=redis_max_clients,
stdout_file=redis_stdout_file, stderr_file=redis_stderr_file,
stdout_file=redis_stdout_file,
stderr_file=redis_stderr_file,
cleanup=cleanup)
if redis_shard_ports[i] is not None:
assert redis_shard_port == redis_shard_ports[i]
@ -601,11 +601,13 @@ def start_redis_instance(node_ip_address="127.0.0.1",
while counter < num_retries:
if counter > 0:
print("Redis failed to start, retrying now.")
p = subprocess.Popen([executable,
"--port", str(port),
"--loglevel", "warning",
"--loadmodule", module],
stdout=stdout_file, stderr=stderr_file)
p = subprocess.Popen(
[
executable, "--port",
str(port), "--loglevel", "warning", "--loadmodule", module
],
stdout=stdout_file,
stderr=stderr_file)
time.sleep(0.1)
# Check if Redis successfully started (or at least if the executable
# did not exit within 0.1 seconds).
@ -652,8 +654,8 @@ def start_redis_instance(node_ip_address="127.0.0.1",
# Increase the hard and soft limits for the redis client pubsub buffer to
# 128MB. This is a hack to make it less likely for pubsub messages to be
# dropped and for pubsub connections to therefore be killed.
cur_config = (redis_client.config_get("client-output-buffer-limit")
["client-output-buffer-limit"])
cur_config = (redis_client.config_get("client-output-buffer-limit")[
"client-output-buffer-limit"])
cur_config_list = cur_config.split()
assert len(cur_config_list) == 12
cur_config_list[8:] = ["pubsub", "134217728", "134217728", "60"]
@ -662,13 +664,17 @@ def start_redis_instance(node_ip_address="127.0.0.1",
# Put a time stamp in Redis to indicate when it was started.
redis_client.set("redis_start_time", time.time())
# Record the log files in Redis.
record_log_files_in_redis(address(node_ip_address, port), node_ip_address,
[stdout_file, stderr_file])
record_log_files_in_redis(
address(node_ip_address, port), node_ip_address,
[stdout_file, stderr_file])
return port, p
def start_log_monitor(redis_address, node_ip_address, stdout_file=None,
stderr_file=None, cleanup=cleanup):
def start_log_monitor(redis_address,
node_ip_address,
stdout_file=None,
stderr_file=None,
cleanup=cleanup):
"""Start a log monitor process.
Args:
@ -684,20 +690,25 @@ def start_log_monitor(redis_address, node_ip_address, stdout_file=None,
Python process that imported services exits.
"""
log_monitor_filepath = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"log_monitor.py")
p = subprocess.Popen([sys.executable, "-u", log_monitor_filepath,
"--redis-address", redis_address,
"--node-ip-address", node_ip_address],
stdout=stdout_file, stderr=stderr_file)
os.path.dirname(os.path.abspath(__file__)), "log_monitor.py")
p = subprocess.Popen(
[
sys.executable, "-u", log_monitor_filepath, "--redis-address",
redis_address, "--node-ip-address", node_ip_address
],
stdout=stdout_file,
stderr=stderr_file)
if cleanup:
all_processes[PROCESS_TYPE_LOG_MONITOR].append(p)
record_log_files_in_redis(redis_address, node_ip_address,
[stdout_file, stderr_file])
def start_global_scheduler(redis_address, node_ip_address,
stdout_file=None, stderr_file=None, cleanup=True):
def start_global_scheduler(redis_address,
node_ip_address,
stdout_file=None,
stderr_file=None,
cleanup=True):
"""Start a global scheduler process.
Args:
@ -712,10 +723,11 @@ def start_global_scheduler(redis_address, node_ip_address,
then this process will be killed by services.cleanup() when the
Python process that imported services exits.
"""
p = global_scheduler.start_global_scheduler(redis_address,
node_ip_address,
stdout_file=stdout_file,
stderr_file=stderr_file)
p = global_scheduler.start_global_scheduler(
redis_address,
node_ip_address,
stdout_file=stdout_file,
stderr_file=stderr_file)
if cleanup:
all_processes[PROCESS_TYPE_GLOBAL_SCHEDULER].append(p)
record_log_files_in_redis(redis_address, node_ip_address,
@ -737,8 +749,7 @@ def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True):
"""
new_env = os.environ.copy()
notebook_filepath = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"WebUI.ipynb")
os.path.dirname(os.path.abspath(__file__)), "WebUI.ipynb")
# We copy the notebook file so that the original doesn't get modified by
# the user.
random_ui_id = random.randint(0, 100000)
@ -759,19 +770,23 @@ def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True):
# We generate the token used for authentication ourselves to avoid
# querying the jupyter server.
token = binascii.hexlify(os.urandom(24)).decode("ascii")
command = ["jupyter", "notebook", "--no-browser",
"--port={}".format(port),
"--NotebookApp.iopub_data_rate_limit=10000000000",
"--NotebookApp.open_browser=False",
"--NotebookApp.token={}".format(token)]
command = [
"jupyter", "notebook", "--no-browser", "--port={}".format(port),
"--NotebookApp.iopub_data_rate_limit=10000000000",
"--NotebookApp.open_browser=False",
"--NotebookApp.token={}".format(token)
]
# If the user is root, add the --allow-root flag.
if os.geteuid() == 0:
command.append("--allow-root")
try:
ui_process = subprocess.Popen(command, env=new_env,
cwd=new_notebook_directory,
stdout=stdout_file, stderr=stderr_file)
ui_process = subprocess.Popen(
command,
env=new_env,
cwd=new_notebook_directory,
stdout=stdout_file,
stderr=stderr_file)
except Exception:
print("Failed to start the UI, you may need to run "
"'pip install jupyter'.")
@ -836,8 +851,8 @@ def start_local_scheduler(redis_address,
# Check that the number of GPUs that the local scheduler wants doesn't
# exceed the amount allowed by CUDA_VISIBLE_DEVICES.
if ("GPU" in resources and gpu_ids is not None and
resources["GPU"] > len(gpu_ids)):
if ("GPU" in resources and gpu_ids is not None
and resources["GPU"] > len(gpu_ids)):
raise Exception("Attempting to start local scheduler with {} GPUs, "
"but CUDA_VISIBLE_DEVICES contains {}.".format(
resources["GPU"], gpu_ids))
@ -906,21 +921,14 @@ def start_raylet(redis_address,
"--node-ip-address={} "
"--object-store-name={} "
"--raylet-name={} "
"--redis-address={}"
.format(sys.executable,
worker_path,
node_ip_address,
plasma_store_name,
raylet_name,
redis_address))
"--redis-address={}".format(
sys.executable, worker_path, node_ip_address,
plasma_store_name, raylet_name, redis_address))
command = [RAYLET_EXECUTABLE,
raylet_name,
plasma_store_name,
node_ip_address,
gcs_ip_address,
gcs_port,
start_worker_command]
command = [
RAYLET_EXECUTABLE, raylet_name, plasma_store_name, node_ip_address,
gcs_ip_address, gcs_port, start_worker_command
]
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
if cleanup:
@ -931,12 +939,18 @@ def start_raylet(redis_address,
return raylet_name
def start_objstore(node_ip_address, redis_address,
object_manager_port=None, store_stdout_file=None,
store_stderr_file=None, manager_stdout_file=None,
manager_stderr_file=None, objstore_memory=None,
cleanup=True, plasma_directory=None,
huge_pages=False, use_raylet=False):
def start_objstore(node_ip_address,
redis_address,
object_manager_port=None,
store_stdout_file=None,
store_stderr_file=None,
manager_stdout_file=None,
manager_stderr_file=None,
objstore_memory=None,
cleanup=True,
plasma_directory=None,
huge_pages=False,
use_raylet=False):
"""This method starts an object store process.
Args:
@ -1013,24 +1027,24 @@ def start_objstore(node_ip_address, redis_address,
if object_manager_port is not None:
(plasma_manager_name, p2,
plasma_manager_port) = ray.plasma.start_plasma_manager(
plasma_store_name,
redis_address,
plasma_manager_port=object_manager_port,
node_ip_address=node_ip_address,
num_retries=1,
run_profiler=RUN_PLASMA_MANAGER_PROFILER,
stdout_file=manager_stdout_file,
stderr_file=manager_stderr_file)
plasma_store_name,
redis_address,
plasma_manager_port=object_manager_port,
node_ip_address=node_ip_address,
num_retries=1,
run_profiler=RUN_PLASMA_MANAGER_PROFILER,
stdout_file=manager_stdout_file,
stderr_file=manager_stderr_file)
assert plasma_manager_port == object_manager_port
else:
(plasma_manager_name, p2,
plasma_manager_port) = ray.plasma.start_plasma_manager(
plasma_store_name,
redis_address,
node_ip_address=node_ip_address,
run_profiler=RUN_PLASMA_MANAGER_PROFILER,
stdout_file=manager_stdout_file,
stderr_file=manager_stderr_file)
plasma_store_name,
redis_address,
node_ip_address=node_ip_address,
run_profiler=RUN_PLASMA_MANAGER_PROFILER,
stdout_file=manager_stdout_file,
stderr_file=manager_stderr_file)
else:
plasma_manager_port = None
plasma_manager_name = None
@ -1049,9 +1063,15 @@ def start_objstore(node_ip_address, redis_address,
plasma_manager_port)
def start_worker(node_ip_address, object_store_name, object_store_manager_name,
local_scheduler_name, redis_address, worker_path,
stdout_file=None, stderr_file=None, cleanup=True):
def start_worker(node_ip_address,
object_store_name,
object_store_manager_name,
local_scheduler_name,
redis_address,
worker_path,
stdout_file=None,
stderr_file=None,
cleanup=True):
"""This method starts a worker process.
Args:
@ -1072,14 +1092,14 @@ def start_worker(node_ip_address, object_store_name, object_store_manager_name,
Python process that imported services exits. This is True by
default.
"""
command = [sys.executable,
"-u",
worker_path,
"--node-ip-address=" + node_ip_address,
"--object-store-name=" + object_store_name,
"--object-store-manager-name=" + object_store_manager_name,
"--local-scheduler-name=" + local_scheduler_name,
"--redis-address=" + str(redis_address)]
command = [
sys.executable, "-u", worker_path,
"--node-ip-address=" + node_ip_address,
"--object-store-name=" + object_store_name,
"--object-store-manager-name=" + object_store_manager_name,
"--local-scheduler-name=" + local_scheduler_name,
"--redis-address=" + str(redis_address)
]
p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
if cleanup:
all_processes[PROCESS_TYPE_WORKER].append(p)
@ -1087,8 +1107,12 @@ def start_worker(node_ip_address, object_store_name, object_store_manager_name,
[stdout_file, stderr_file])
def start_monitor(redis_address, node_ip_address, stdout_file=None,
stderr_file=None, cleanup=True, autoscaling_config=None):
def start_monitor(redis_address,
node_ip_address,
stdout_file=None,
stderr_file=None,
cleanup=True,
autoscaling_config=None):
"""Run a process to monitor the other processes.
Args:
@ -1105,12 +1129,12 @@ def start_monitor(redis_address, node_ip_address, stdout_file=None,
default.
autoscaling_config: path to autoscaling config file.
"""
monitor_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"monitor.py")
command = [sys.executable,
"-u",
monitor_path,
"--redis-address=" + str(redis_address)]
monitor_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "monitor.py")
command = [
sys.executable, "-u", monitor_path,
"--redis-address=" + str(redis_address)
]
if autoscaling_config:
command.append("--autoscaling-config=" + str(autoscaling_config))
p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
@ -1120,8 +1144,10 @@ def start_monitor(redis_address, node_ip_address, stdout_file=None,
[stdout_file, stderr_file])
def start_raylet_monitor(redis_address, stdout_file=None,
stderr_file=None, cleanup=True):
def start_raylet_monitor(redis_address,
stdout_file=None,
stderr_file=None,
cleanup=True):
"""Run a process to monitor the other processes.
Args:
@ -1136,9 +1162,7 @@ def start_raylet_monitor(redis_address, stdout_file=None,
default.
"""
gcs_ip_address, gcs_port = redis_address.split(":")
command = [RAYLET_MONITOR_EXECUTABLE,
gcs_ip_address,
gcs_port]
command = [RAYLET_MONITOR_EXECUTABLE, gcs_ip_address, gcs_port]
p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
if cleanup:
all_processes[PROCESS_TYPE_MONITOR].append(p)
@ -1238,16 +1262,17 @@ def start_ray_processes(address_info=None,
workers_per_local_scheduler = []
for resource_dict in resources:
cpus = resource_dict.get("CPU")
workers_per_local_scheduler.append(cpus if cpus is not None
else psutil.cpu_count())
workers_per_local_scheduler.append(cpus if cpus is not None else
psutil.cpu_count())
if address_info is None:
address_info = {}
address_info["node_ip_address"] = node_ip_address
if worker_path is None:
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"workers/default_worker.py")
worker_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"workers/default_worker.py")
# Start Redis if there isn't already an instance running. TODO(rkn): We are
# suppressing the output of Redis because on Linux it prints a bunch of
@ -1257,7 +1282,8 @@ def start_ray_processes(address_info=None,
redis_shards = address_info.get("redis_shards", [])
if redis_address is None:
redis_address, redis_shards = start_redis(
node_ip_address, port=redis_port,
node_ip_address,
port=redis_port,
redis_shard_ports=redis_shard_ports,
num_redis_shards=num_redis_shards,
redis_max_clients=redis_max_clients,
@ -1274,23 +1300,25 @@ def start_ray_processes(address_info=None,
# Start monitoring the processes.
monitor_stdout_file, monitor_stderr_file = new_log_files(
"monitor", redirect_output)
start_monitor(redis_address,
node_ip_address,
stdout_file=monitor_stdout_file,
stderr_file=monitor_stderr_file,
cleanup=cleanup,
autoscaling_config=autoscaling_config)
start_monitor(
redis_address,
node_ip_address,
stdout_file=monitor_stdout_file,
stderr_file=monitor_stderr_file,
cleanup=cleanup,
autoscaling_config=autoscaling_config)
if use_raylet:
start_raylet_monitor(redis_address,
stdout_file=monitor_stdout_file,
stderr_file=monitor_stderr_file,
cleanup=cleanup)
start_raylet_monitor(
redis_address,
stdout_file=monitor_stdout_file,
stderr_file=monitor_stderr_file,
cleanup=cleanup)
if redis_shards == []:
# Get redis shards from primary redis instance.
redis_ip_address, redis_port = redis_address.split(":")
redis_client = redis.StrictRedis(host=redis_ip_address,
port=redis_port)
redis_client = redis.StrictRedis(
host=redis_ip_address, port=redis_port)
redis_shards = redis_client.lrange("RedisShards", start=0, end=-1)
redis_shards = [shard.decode("ascii") for shard in redis_shards]
address_info["redis_shards"] = redis_shards
@ -1299,21 +1327,23 @@ def start_ray_processes(address_info=None,
if include_log_monitor:
log_monitor_stdout_file, log_monitor_stderr_file = new_log_files(
"log_monitor", redirect_output=True)
start_log_monitor(redis_address,
node_ip_address,
stdout_file=log_monitor_stdout_file,
stderr_file=log_monitor_stderr_file,
cleanup=cleanup)
start_log_monitor(
redis_address,
node_ip_address,
stdout_file=log_monitor_stdout_file,
stderr_file=log_monitor_stderr_file,
cleanup=cleanup)
# Start the global scheduler, if necessary.
if include_global_scheduler and not use_raylet:
global_scheduler_stdout_file, global_scheduler_stderr_file = (
new_log_files("global_scheduler", redirect_output))
start_global_scheduler(redis_address,
node_ip_address,
stdout_file=global_scheduler_stdout_file,
stderr_file=global_scheduler_stderr_file,
cleanup=cleanup)
start_global_scheduler(
redis_address,
node_ip_address,
stdout_file=global_scheduler_stdout_file,
stderr_file=global_scheduler_stderr_file,
cleanup=cleanup)
# Initialize with existing services.
if "object_store_addresses" not in address_info:
@ -1324,9 +1354,8 @@ def start_ray_processes(address_info=None,
local_scheduler_socket_names = address_info["local_scheduler_socket_names"]
# Get the ports to use for the object managers if any are provided.
object_manager_ports = (address_info["object_manager_ports"]
if "object_manager_ports" in address_info
else None)
object_manager_ports = (address_info["object_manager_ports"] if
"object_manager_ports" in address_info else None)
if not isinstance(object_manager_ports, list):
object_manager_ports = num_local_schedulers * [object_manager_ports]
assert len(object_manager_ports) == num_local_schedulers
@ -1347,7 +1376,8 @@ def start_ray_processes(address_info=None,
manager_stdout_file=plasma_manager_stdout_file,
manager_stderr_file=plasma_manager_stderr_file,
objstore_memory=object_store_memory,
cleanup=cleanup, plasma_directory=plasma_directory,
cleanup=cleanup,
plasma_directory=plasma_directory,
huge_pages=huge_pages,
use_raylet=use_raylet)
object_store_addresses.append(object_store_address)
@ -1355,8 +1385,8 @@ def start_ray_processes(address_info=None,
# Start any local schedulers that do not yet exist.
if not use_raylet:
for i in range(len(local_scheduler_socket_names),
num_local_schedulers):
for i in range(
len(local_scheduler_socket_names), num_local_schedulers):
# Connect the local scheduler to the object store at the same
# index.
object_store_address = object_store_addresses[i]
@ -1374,8 +1404,9 @@ def start_ray_processes(address_info=None,
# redirect the worker output, then we cannot redirect the local
# scheduler output.
local_scheduler_stdout_file, local_scheduler_stderr_file = (
new_log_files("local_scheduler_{}".format(i),
redirect_output=redirect_worker_output))
new_log_files(
"local_scheduler_{}".format(i),
redirect_output=redirect_worker_output))
local_scheduler_name = start_local_scheduler(
redis_address,
node_ip_address,
@ -1398,17 +1429,18 @@ def start_ray_processes(address_info=None,
else:
# Start the raylet. TODO(rkn): Modify this to allow starting
# multiple raylets on the same machine.
raylet_stdout_file, raylet_stderr_file = (
new_log_files("raylet_{}".format(i),
redirect_output=redirect_output))
address_info["raylet_socket_names"] = [start_raylet(
redis_address,
node_ip_address,
object_store_addresses[i].name,
worker_path,
stdout_file=None,
stderr_file=None,
cleanup=cleanup)]
raylet_stdout_file, raylet_stderr_file = (new_log_files(
"raylet_{}".format(i), redirect_output=redirect_output))
address_info["raylet_socket_names"] = [
start_raylet(
redis_address,
node_ip_address,
object_store_addresses[i].name,
worker_path,
stdout_file=None,
stderr_file=None,
cleanup=cleanup)
]
if not use_raylet:
# Start any workers that the local scheduler has not already started.
@ -1419,28 +1451,30 @@ def start_ray_processes(address_info=None,
for j in range(num_local_scheduler_workers):
worker_stdout_file, worker_stderr_file = new_log_files(
"worker_{}_{}".format(i, j), redirect_output)
start_worker(node_ip_address,
object_store_address.name,
object_store_address.manager_name,
local_scheduler_name,
redis_address,
worker_path,
stdout_file=worker_stdout_file,
stderr_file=worker_stderr_file,
cleanup=cleanup)
start_worker(
node_ip_address,
object_store_address.name,
object_store_address.manager_name,
local_scheduler_name,
redis_address,
worker_path,
stdout_file=worker_stdout_file,
stderr_file=worker_stderr_file,
cleanup=cleanup)
workers_per_local_scheduler[i] -= 1
# Make sure that we've started all the workers.
assert(sum(workers_per_local_scheduler) == 0)
assert (sum(workers_per_local_scheduler) == 0)
# Try to start the web UI.
if include_webui:
ui_stdout_file, ui_stderr_file = new_log_files(
"webui", redirect_output=True)
address_info["webui_url"] = start_ui(redis_address,
stdout_file=ui_stdout_file,
stderr_file=ui_stderr_file,
cleanup=cleanup)
address_info["webui_url"] = start_ui(
redis_address,
stdout_file=ui_stdout_file,
stderr_file=ui_stderr_file,
cleanup=cleanup)
else:
address_info["webui_url"] = ""
# Return the addresses of the relevant processes.
@ -1500,21 +1534,24 @@ def start_ray_node(node_ip_address,
A dictionary of the address information for the processes that were
started.
"""
address_info = {"redis_address": redis_address,
"object_manager_ports": object_manager_ports}
return start_ray_processes(address_info=address_info,
node_ip_address=node_ip_address,
num_workers=num_workers,
num_local_schedulers=num_local_schedulers,
object_store_memory=object_store_memory,
worker_path=worker_path,
include_log_monitor=True,
cleanup=cleanup,
redirect_worker_output=redirect_worker_output,
redirect_output=redirect_output,
resources=resources,
plasma_directory=plasma_directory,
huge_pages=huge_pages)
address_info = {
"redis_address": redis_address,
"object_manager_ports": object_manager_ports
}
return start_ray_processes(
address_info=address_info,
node_ip_address=node_ip_address,
num_workers=num_workers,
num_local_schedulers=num_local_schedulers,
object_store_memory=object_store_memory,
worker_path=worker_path,
include_log_monitor=True,
cleanup=cleanup,
redirect_worker_output=redirect_worker_output,
redirect_output=redirect_output,
resources=resources,
plasma_directory=plasma_directory,
huge_pages=huge_pages)
def start_ray_head(address_info=None,
View file
@ -7,11 +7,10 @@ import funcsigs
from ray.utils import is_cython
FunctionSignature = namedtuple("FunctionSignature", ["arg_names",
"arg_defaults",
"arg_is_positionals",
"keyword_names",
"function_name"])
FunctionSignature = namedtuple("FunctionSignature", [
"arg_names", "arg_defaults", "arg_is_positionals", "keyword_names",
"function_name"
])
"""This class is used to represent a function signature.
Attributes:
@ -49,13 +48,16 @@ def get_signature_params(func):
# The first condition for Cython functions, the latter for Cython instance
# methods
if is_cython(func):
attrs = ["__code__", "__annotations__",
"__defaults__", "__kwdefaults__"]
attrs = [
"__code__", "__annotations__", "__defaults__", "__kwdefaults__"
]
if all([hasattr(func, attr) for attr in attrs]):
original_func = func
def func(): return
def func():
return
for attr in attrs:
setattr(func, attr, getattr(original_func, attr))
else:
@ -130,8 +132,8 @@ def extract_signature(func, ignore_first=False):
if ignore_first:
if len(sig_params) == 0:
raise Exception("Methods must take a 'self' argument, but the "
"method '{}' does not have one."
.format(func.__name__))
"method '{}' does not have one.".format(
func.__name__))
sig_params = sig_params[1:]
# Extract the names of the keyword arguments.
@ -183,8 +185,8 @@ def extend_args(function_signature, args, kwargs):
for keyword_name in kwargs:
if keyword_name not in keyword_names:
raise Exception("The name '{}' is not a valid keyword argument "
"for the function '{}'."
.format(keyword_name, function_name))
"for the function '{}'.".format(
keyword_name, function_name))
# Fill in the remaining arguments.
zipped_info = list(zip(arg_names, arg_defaults,
@ -201,12 +203,12 @@ def extend_args(function_signature, args, kwargs):
# can be omitted.
if not is_positional:
raise Exception("No value was provided for the argument "
"'{}' for the function '{}'."
.format(keyword_name, function_name))
"'{}' for the function '{}'.".format(
keyword_name, function_name))
too_many_arguments = (len(args) > len(arg_names) and
(len(arg_is_positionals) == 0 or
not arg_is_positionals[-1]))
too_many_arguments = (len(args) > len(arg_names)
and (len(arg_is_positionals) == 0
or not arg_is_positionals[-1]))
if too_many_arguments:
raise Exception("Too many arguments were passed to the function '{}'"
.format(function_name))
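To make the signature machinery in this hunk concrete, here is a minimal sketch of the same idea built on the standard-library inspect module instead of funcsigs; the helper name and its simplifications are illustrative, not part of this change:

import inspect
from collections import namedtuple

FunctionSignature = namedtuple("FunctionSignature", [
    "arg_names", "arg_defaults", "arg_is_positionals", "keyword_names",
    "function_name"
])


def sketch_signature(func):
    # Collect ordinary named parameters, recording each name and its
    # default (None when the parameter has no default).
    params = [
        p for p in inspect.signature(func).parameters.values()
        if p.kind in (p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY)
    ]
    return FunctionSignature(
        arg_names=[p.name for p in params],
        arg_defaults=[
            None if p.default is p.empty else p.default for p in params
        ],
        arg_is_positionals=[False] * len(params),  # simplified here
        keyword_names={p.name for p in params if p.default is not p.empty},
        function_name=func.__name__)

For def f(a, b=3), sketch_signature(f) yields arg_names == ['a', 'b'], arg_defaults == [None, 3] and keyword_names == {'b'}.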
View file
@ -13,6 +13,7 @@ import numpy as np
def handle_int(a, b):
return a + 1, b + 1
# Test timing
@ -25,6 +26,7 @@ def empty_function():
def trivial_function():
return 1
# Test keyword arguments
@ -42,6 +44,7 @@ def keyword_fct2(a="hello", b="world"):
def keyword_fct3(a, b, c="hello", d="world"):
return "{} {} {} {}".format(a, b, c, d)
# Test variable numbers of arguments
@ -56,17 +59,21 @@ def varargs_fct2(a, *b):
try:
@ray.remote
def kwargs_throw_exception(**c):
return ()
kwargs_exception_thrown = False
except Exception:
kwargs_exception_thrown = True
try:
@ray.remote
def varargs_and_kwargs_throw_exception(a, b="hi", *c):
return "{} {} {}".format(a, b, c)
varargs_and_kwargs_exception_thrown = False
except Exception:
varargs_and_kwargs_exception_thrown = True
@ -88,6 +95,7 @@ def throw_exception_fct2():
def throw_exception_fct3(x):
raise Exception("Test function 3 intentionally failed.")
# test Python mode
@ -101,6 +109,7 @@ def python_mode_g(x):
x[0] = 1
return x
# test no return values
View file
@ -48,8 +48,8 @@ def _wait_for_nodes_to_join(num_nodes, timeout=20):
if num_ready_nodes > num_nodes:
# Too many nodes have joined. Something must be wrong.
raise Exception("{} nodes have joined the cluster, but we were "
"expecting {} nodes.".format(num_ready_nodes,
num_nodes))
"expecting {} nodes.".format(
num_ready_nodes, num_nodes))
time.sleep(0.1)
# If we get here then we timed out.
View file
@ -9,14 +9,7 @@ from ray.tune.result import TrainingResult
from ray.tune.trainable import Trainable
from ray.tune.variant_generator import grid_search
__all__ = [
"Trainable",
"TrainingResult",
"TuneError",
"grid_search",
"register_env",
"register_trainable",
"run_experiments",
"Experiment"
"Trainable", "TrainingResult", "TuneError", "grid_search", "register_env",
"register_trainable", "run_experiments", "Experiment"
]
View file
@ -35,10 +35,13 @@ class AsyncHyperBandScheduler(FIFOScheduler):
halving rate, specified by the reduction factor.
"""
def __init__(
self, time_attr='training_iteration',
reward_attr='episode_reward_mean', max_t=100,
grace_period=10, reduction_factor=3, brackets=3):
def __init__(self,
time_attr='training_iteration',
reward_attr='episode_reward_mean',
max_t=100,
grace_period=10,
reduction_factor=3,
brackets=3):
assert max_t > 0, "Max (time_attr) not valid!"
assert max_t >= grace_period, "grace_period must be <= max_t!"
assert grace_period > 0, "grace_period must be positive!"
@ -51,8 +54,10 @@ class AsyncHyperBandScheduler(FIFOScheduler):
self._trial_info = {} # Stores Trial -> Bracket
# Tracks state for new trial add
self._brackets = [_Bracket(
grace_period, max_t, reduction_factor, s) for s in range(brackets)]
self._brackets = [
_Bracket(grace_period, max_t, reduction_factor, s)
for s in range(brackets)
]
self._counter = 0 # for
self._num_stopped = 0
self._reward_attr = reward_attr
@ -60,7 +65,7 @@ class AsyncHyperBandScheduler(FIFOScheduler):
def on_trial_add(self, trial_runner, trial):
sizes = np.array([len(b._rungs) for b in self._brackets])
probs = np.e ** (sizes - sizes.max())
probs = np.e**(sizes - sizes.max())
normalized = probs / probs.sum()
idx = np.random.choice(len(self._brackets), p=normalized)
self._trial_info[trial.trial_id] = self._brackets[idx]
@ -71,28 +76,23 @@ class AsyncHyperBandScheduler(FIFOScheduler):
action = TrialScheduler.STOP
else:
bracket = self._trial_info[trial.trial_id]
action = bracket.on_result(
trial,
getattr(result, self._time_attr),
getattr(result, self._reward_attr))
action = bracket.on_result(trial, getattr(result, self._time_attr),
getattr(result, self._reward_attr))
if action == TrialScheduler.STOP:
self._num_stopped += 1
return action
def on_trial_complete(self, trial_runner, trial, result):
bracket = self._trial_info[trial.trial_id]
bracket.on_result(
trial,
getattr(result, self._time_attr),
getattr(result, self._reward_attr))
bracket.on_result(trial, getattr(result, self._time_attr),
getattr(result, self._reward_attr))
del self._trial_info[trial.trial_id]
def on_trial_remove(self, trial_runner, trial):
del self._trial_info[trial.trial_id]
def debug_string(self):
out = "Using AsyncHyperBand: num_stopped={}".format(
self._num_stopped)
out = "Using AsyncHyperBand: num_stopped={}".format(self._num_stopped)
out += "\n" + "\n".join([b.debug_str() for b in self._brackets])
return out
@ -111,6 +111,7 @@ class _Bracket():
>>> b.on_result(trial3, 1, 1) # STOP
>>> b.cutoff(b._rungs[0][1]) == 2.0
"""
def __init__(self, min_t, max_t, reduction_factor, s):
self.rf = reduction_factor
MAX_RUNGS = int(np.log(max_t / min_t) / np.log(self.rf) - s + 1)
@ -140,9 +141,10 @@ class _Bracket():
return action
def debug_str(self):
iters = " | ".join(
["Iter {:.3f}: {}".format(milestone, self.cutoff(recorded))
for milestone, recorded in self._rungs])
iters = " | ".join([
"Iter {:.3f}: {}".format(milestone, self.cutoff(recorded))
for milestone, recorded in self._rungs
])
return "Bracket: " + iters
View file
@ -2,7 +2,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import json
@ -24,8 +23,8 @@ def json_to_resources(data):
"Unknown resource type {}, must be one of {}".format(
k, Resources._fields))
return Resources(
data.get("cpu", 1), data.get("gpu", 0),
data.get("extra_cpu", 0), data.get("extra_gpu", 0))
data.get("cpu", 1), data.get("gpu", 0), data.get("extra_cpu", 0),
data.get("extra_gpu", 0))
def resources_to_json(resources):
@ -50,59 +49,85 @@ def make_parser(**kwargs):
# Note: keep this in sync with rllib/train.py
parser.add_argument(
"--run", default=None, type=str,
"--run",
default=None,
type=str,
help="The algorithm or model to train. This may refer to the name "
"of a built-on algorithm (e.g. RLLib's DQN or PPO), or a "
"user-defined trainable function or class registered in the "
"tune registry.")
parser.add_argument(
"--stop", default="{}", type=json.loads,
"--stop",
default="{}",
type=json.loads,
help="The stopping criteria, specified in JSON. The keys may be any "
"field in TrainingResult, e.g. "
"'{\"time_total_s\": 600, \"timesteps_total\": 100000}' to stop "
"after 600 seconds or 100k timesteps, whichever is reached first.")
parser.add_argument(
"--config", default="{}", type=json.loads,
"--config",
default="{}",
type=json.loads,
help="Algorithm-specific configuration (e.g. env, hyperparams), "
"specified in JSON.")
parser.add_argument(
"--resources", help="Deprecated, use --trial-resources.",
type=lambda v: _tune_error(
"The `resources` argument is no longer supported. "
"Use `trial_resources` or --trial-resources instead."))
"--resources",
help="Deprecated, use --trial-resources.",
type=lambda v: _tune_error("The `resources` argument is no longer "
"supported. Use `trial_resources` or "
"--trial-resources instead."))
parser.add_argument(
"--trial-resources", default='{"cpu": 1}', type=json_to_resources,
"--trial-resources",
default='{"cpu": 1}',
type=json_to_resources,
help="Machine resources to allocate per trial, e.g. "
"'{\"cpu\": 64, \"gpu\": 8}'. Note that GPUs will not be assigned "
"unless you specify them here.")
parser.add_argument(
"--repeat", default=1, type=int,
"--repeat",
default=1,
type=int,
help="Number of times to repeat each trial.")
parser.add_argument(
"--local-dir", default=DEFAULT_RESULTS_DIR, type=str,
"--local-dir",
default=DEFAULT_RESULTS_DIR,
type=str,
help="Local dir to save training results to. Defaults to '{}'.".format(
DEFAULT_RESULTS_DIR))
parser.add_argument(
"--upload-dir", default="", type=str,
"--upload-dir",
default="",
type=str,
help="Optional URI to sync training results to (e.g. s3://bucket).")
parser.add_argument(
"--checkpoint-freq", default=0, type=int,
"--checkpoint-freq",
default=0,
type=int,
help="How many training iterations between checkpoints. "
"A value of 0 (default) disables checkpointing.")
parser.add_argument(
"--max-failures", default=3, type=int,
"--max-failures",
default=3,
type=int,
help="Try to recover a trial from its last checkpoint at least this "
"many times. Only applies if checkpointing is enabled.")
parser.add_argument(
"--scheduler", default="FIFO", type=str,
"--scheduler",
default="FIFO",
type=str,
help="FIFO (default), MedianStopping, AsyncHyperBand,"
"HyperBand, or HyperOpt.")
"HyperBand, or HyperOpt.")
parser.add_argument(
"--scheduler-config", default="{}", type=json.loads,
"--scheduler-config",
default="{}",
type=json.loads,
help="Config options to pass to the scheduler.")
# Note: this currently only makes sense when running a single trial
parser.add_argument("--restore", default=None, type=str,
help="If specified, restore from this checkpoint.")
parser.add_argument(
"--restore",
default=None,
type=str,
help="If specified, restore from this checkpoint.")
return parser
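A hypothetical invocation of the parser assembled above, assuming make_parser() builds a plain argparse parser as the add_argument calls imply; the argument values are illustrative only:

parser = make_parser()
args = parser.parse_args([
    "--run", "PPO",
    "--stop", '{"timesteps_total": 100000}',
    "--trial-resources", '{"cpu": 4, "gpu": 1}',
])
# args.stop arrives as a dict (type=json.loads) and args.trial_resources
# as a Resources namedtuple (type=json_to_resources).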
View file
@ -60,18 +60,27 @@ if __name__ == "__main__":
# `episode_reward_mean` as the
# objective and `timesteps_total` as the time unit.
ahb = AsyncHyperBandScheduler(
time_attr="timesteps_total", reward_attr="episode_reward_mean",
grace_period=5, max_t=100)
time_attr="timesteps_total",
reward_attr="episode_reward_mean",
grace_period=5,
max_t=100)
run_experiments({
"asynchyperband_test": {
"run": "my_class",
"stop": {"training_iteration": 1 if args.smoke_test else 99999},
"repeat": 20,
"trial_resources": {"cpu": 1, "gpu": 0},
"config": {
"width": lambda spec: 10 + int(90 * random.random()),
"height": lambda spec: int(100 * random.random()),
},
}
}, scheduler=ahb)
run_experiments(
{
"asynchyperband_test": {
"run": "my_class",
"stop": {
"training_iteration": 1 if args.smoke_test else 99999
},
"repeat": 20,
"trial_resources": {
"cpu": 1,
"gpu": 0
},
"config": {
"width": lambda spec: 10 + int(90 * random.random()),
"height": lambda spec: int(100 * random.random()),
},
}
},
scheduler=ahb)
View file
@ -59,7 +59,8 @@ if __name__ == "__main__":
# Hyperband early stopping, configured with `episode_reward_mean` as the
# objective and `timesteps_total` as the time unit.
hyperband = HyperBandScheduler(
time_attr="timesteps_total", reward_attr="episode_reward_mean",
time_attr="timesteps_total",
reward_attr="episode_reward_mean",
max_t=100)
exp = Experiment(
View file
@ -12,8 +12,8 @@ def easy_objective(config, reporter):
time.sleep(0.2)
reporter(
timesteps_total=1,
episode_reward_mean=-((config["height"]-14) ** 2
+ abs(config["width"]-3)))
episode_reward_mean=-(
(config["height"] - 14)**2 + abs(config["width"] - 3)))
time.sleep(0.2)
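The objective reported above, -((height - 14)**2 + abs(width - 3)), peaks at 0 when height == 14 and width == 3, which is the optimum HyperOpt should converge toward in this example.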
@ -34,12 +34,18 @@ if __name__ == '__main__':
'height': hp.uniform('height', -100, 100),
}
config = {"my_exp": {
config = {
"my_exp": {
"run": "exp",
"repeat": 5 if args.smoke_test else 1000,
"stop": {"training_iteration": 1},
"stop": {
"training_iteration": 1
},
"config": {
"space": space}}}
"space": space
}
}
}
hpo_sched = HyperOptScheduler()
run_experiments(config, verbose=False, scheduler=hpo_sched)
View file
@ -42,8 +42,11 @@ class MyTrainableClass(Trainable):
def _save(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps(
{"timestep": self.timestep, "value": self.current_value}))
f.write(
json.dumps({
"timestep": self.timestep,
"value": self.current_value
}))
return path
def _restore(self, checkpoint_path):
@ -63,7 +66,8 @@ if __name__ == "__main__":
ray.init()
pbt = PopulationBasedTraining(
time_attr="training_iteration", reward_attr="episode_reward_mean",
time_attr="training_iteration",
reward_attr="episode_reward_mean",
perturbation_interval=10,
hyperparam_mutations={
# Allow for scaling-based perturbations, with a uniform backing
@ -74,15 +78,23 @@ if __name__ == "__main__":
})
# Try to find the best factor 1 and factor 2
run_experiments({
"pbt_test": {
"run": "my_class",
"stop": {"training_iteration": 2 if args.smoke_test else 99999},
"repeat": 10,
"trial_resources": {"cpu": 1, "gpu": 0},
"config": {
"factor_1": 4.0,
"factor_2": 1.0,
},
}
}, scheduler=pbt, verbose=False)
run_experiments(
{
"pbt_test": {
"run": "my_class",
"stop": {
"training_iteration": 2 if args.smoke_test else 99999
},
"repeat": 10,
"trial_resources": {
"cpu": 1,
"gpu": 0
},
"config": {
"factor_1": 4.0,
"factor_2": 1.0,
},
}
},
scheduler=pbt,
verbose=False)
View file
@ -1,5 +1,4 @@
#!/usr/bin/env python
"""Example of using PBT with RLlib.
Note that this requires a cluster with at least 8 GPUs in order for all trials
@ -30,7 +29,8 @@ if __name__ == "__main__":
return config
pbt = PopulationBasedTraining(
time_attr="time_total_s", reward_attr="episode_reward_mean",
time_attr="time_total_s",
reward_attr="episode_reward_mean",
perturbation_interval=120,
resample_probability=0.25,
# Specifies the mutations of these hyperparams
@ -45,26 +45,40 @@ if __name__ == "__main__":
custom_explore_fn=explore)
ray.init()
run_experiments({
"pbt_humanoid_test": {
"run": "PPO",
"env": "Humanoid-v1",
"repeat": 8,
"trial_resources": {"cpu": 4, "gpu": 1},
"config": {
"kl_coeff": 1.0,
"num_workers": 8,
"devices": ["/gpu:0"],
"model": {"free_log_std": True},
# These params are tuned from a fixed starting value.
"lambda": 0.95,
"clip_param": 0.2,
"sgd_stepsize": 1e-4,
# These params start off randomly drawn from a set.
"num_sgd_iter": lambda spec: random.choice([10, 20, 30]),
"sgd_batchsize": lambda spec: random.choice([128, 512, 2048]),
"timesteps_per_batch":
run_experiments(
{
"pbt_humanoid_test": {
"run": "PPO",
"env": "Humanoid-v1",
"repeat": 8,
"trial_resources": {
"cpu": 4,
"gpu": 1
},
"config": {
"kl_coeff":
1.0,
"num_workers":
8,
"devices": ["/gpu:0"],
"model": {
"free_log_std": True
},
# These params are tuned from a fixed starting value.
"lambda":
0.95,
"clip_param":
0.2,
"sgd_stepsize":
1e-4,
# These params start off randomly drawn from a set.
"num_sgd_iter":
lambda spec: random.choice([10, 20, 30]),
"sgd_batchsize":
lambda spec: random.choice([128, 512, 2048]),
"timesteps_per_batch":
lambda spec: random.choice([10000, 20000, 40000])
},
},
},
}, scheduler=pbt)
scheduler=pbt)
View file
@ -29,12 +29,10 @@ from ray.tune import Trainable
from ray.tune import TrainingResult
from ray.tune.pbt import PopulationBasedTraining
num_classes = 10
class Cifar10Model(Trainable):
def _read_data(self):
# The data, split between train and test sets:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
@ -54,27 +52,51 @@ class Cifar10Model(Trainable):
x = Input(shape=(32, 32, 3))
y = x
y = Convolution2D(
filters=64, kernel_size=3, strides=1, padding="same",
activation="relu", kernel_initializer="he_normal")(y)
filters=64,
kernel_size=3,
strides=1,
padding="same",
activation="relu",
kernel_initializer="he_normal")(y)
y = Convolution2D(
filters=64, kernel_size=3, strides=1, padding="same",
activation="relu", kernel_initializer="he_normal")(y)
filters=64,
kernel_size=3,
strides=1,
padding="same",
activation="relu",
kernel_initializer="he_normal")(y)
y = MaxPooling2D(pool_size=2, strides=2, padding="same")(y)
y = Convolution2D(
filters=128, kernel_size=3, strides=1, padding="same",
activation="relu", kernel_initializer="he_normal")(y)
filters=128,
kernel_size=3,
strides=1,
padding="same",
activation="relu",
kernel_initializer="he_normal")(y)
y = Convolution2D(
filters=128, kernel_size=3, strides=1, padding="same",
activation="relu", kernel_initializer="he_normal")(y)
filters=128,
kernel_size=3,
strides=1,
padding="same",
activation="relu",
kernel_initializer="he_normal")(y)
y = MaxPooling2D(pool_size=2, strides=2, padding="same")(y)
y = Convolution2D(
filters=256, kernel_size=3, strides=1, padding="same",
activation="relu", kernel_initializer="he_normal")(y)
filters=256,
kernel_size=3,
strides=1,
padding="same",
activation="relu",
kernel_initializer="he_normal")(y)
y = Convolution2D(
filters=256, kernel_size=3, strides=1, padding="same",
activation="relu", kernel_initializer="he_normal")(y)
filters=256,
kernel_size=3,
strides=1,
padding="same",
activation="relu",
kernel_initializer="he_normal")(y)
y = MaxPooling2D(pool_size=2, strides=2, padding="same")(y)
y = Flatten()(y)
@ -91,9 +113,10 @@ class Cifar10Model(Trainable):
model = self._build_model(x_train.shape[1:])
opt = tf.keras.optimizers.Adadelta()
model.compile(loss="categorical_crossentropy",
optimizer=opt,
metrics=["accuracy"])
model.compile(
loss="categorical_crossentropy",
optimizer=opt,
metrics=["accuracy"])
self.model = model
def _train(self):
@ -134,8 +157,7 @@ class Cifar10Model(Trainable):
# loss, accuracy
_, accuracy = self.model.evaluate(x_test, y_test, verbose=0)
return TrainingResult(timesteps_this_iter=10,
mean_accuracy=accuracy)
return TrainingResult(timesteps_this_iter=10, mean_accuracy=accuracy)
def _save(self, checkpoint_dir):
file_path = checkpoint_dir + "/model"
@ -154,15 +176,17 @@ class Cifar10Model(Trainable):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--smoke-test",
action="store_true",
help="Finish quickly for testing")
parser.add_argument(
"--smoke-test", action="store_true", help="Finish quickly for testing")
args, _ = parser.parse_known_args()
register_trainable("train_cifar10", Cifar10Model)
train_spec = {
"run": "train_cifar10",
"trial_resources": {"cpu": 1, "gpu": 1},
"trial_resources": {
"cpu": 1,
"gpu": 1
},
"stop": {
"mean_accuracy": 0.80,
"timesteps_total": 300,
@ -170,7 +194,7 @@ if __name__ == "__main__":
"config": {
"epochs": 1,
"batch_size": 64,
"lr": grid_search([10 ** -4, 10 ** -5]),
"lr": grid_search([10**-4, 10**-5]),
"decay": lambda spec: spec.config.lr / 100.0,
"dropout": grid_search([0.25, 0.5]),
},
@ -178,17 +202,17 @@ if __name__ == "__main__":
}
if args.smoke_test:
train_spec["config"]["lr"] = 10 ** -4
train_spec["config"]["lr"] = 10**-4
train_spec["config"]["dropout"] = 0.5
ray.init()
pbt = PopulationBasedTraining(
time_attr="timesteps_total", reward_attr="mean_accuracy",
time_attr="timesteps_total",
reward_attr="mean_accuracy",
perturbation_interval=10,
hyperparam_mutations={
"dropout": lambda _: np.random.uniform(0, 1),
})
run_experiments({"pbt_cifar10": train_spec},
scheduler=pbt)
run_experiments({"pbt_cifar10": train_spec}, scheduler=pbt)
View file
@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A deep MNIST classifier using convolutional layers.
See extensive documentation at
@ -42,7 +41,7 @@ import tensorflow as tf
FLAGS = None
status_reporter = None # used to report training status back to Ray
activation_fn = None # e.g. tf.nn.relu
activation_fn = None # e.g. tf.nn.relu
def deepnn(x):
@ -90,7 +89,7 @@ def deepnn(x):
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
h_fc1 = activation_fn(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
# Dropout - controls the complexity of the model, prevents co-adaptation of
@ -173,7 +172,10 @@ def main(_):
batch = mnist.train.next_batch(50)
if i % 10 == 0:
train_accuracy = accuracy.eval(feed_dict={
x: batch[0], y_: batch[1], keep_prob: 1.0})
x: batch[0],
y_: batch[1],
keep_prob: 1.0
})
# !!! Report status to ray.tune !!!
if status_reporter:
@ -181,11 +183,17 @@ def main(_):
timesteps_total=i, mean_accuracy=train_accuracy)
print('step %d, training accuracy %g' % (i, train_accuracy))
train_step.run(
feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
train_step.run(feed_dict={
x: batch[0],
y_: batch[1],
keep_prob: 0.5
})
print('test accuracy %g' % accuracy.eval(feed_dict={
x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
x: mnist.test.images,
y_: mnist.test.labels,
keep_prob: 1.0
}))
# !!! Entrypoint for ray.tune !!!
@ -195,7 +203,9 @@ def train(config={'activation': 'relu'}, reporter=None):
activation_fn = getattr(tf.nn, config['activation'])
parser = argparse.ArgumentParser()
parser.add_argument(
'--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',
'--data_dir',
type=str,
default='/tmp/tensorflow/mnist/input_data',
help='Directory for storing input data')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
@ -213,8 +223,8 @@ if __name__ == '__main__':
'run': 'train_mnist',
'repeat': 10,
'stop': {
'mean_accuracy': 0.99,
'timesteps_total': 600,
'mean_accuracy': 0.99,
'timesteps_total': 600,
},
'config': {
'activation': grid_search(['relu', 'elu', 'tanh']),
@ -228,8 +238,12 @@ if __name__ == '__main__':
ray.init()
from ray.tune.async_hyperband import AsyncHyperBandScheduler
run_experiments({'tune_mnist_test': mnist_spec},
scheduler=AsyncHyperBandScheduler(
time_attr="timesteps_total",
reward_attr="mean_accuracy",
max_t=600,))
run_experiments(
{
'tune_mnist_test': mnist_spec
},
scheduler=AsyncHyperBandScheduler(
time_attr="timesteps_total",
reward_attr="mean_accuracy",
max_t=600,
))
View file
@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A deep MNIST classifier using convolutional layers.
See extensive documentation at
@ -42,7 +41,7 @@ import tensorflow as tf
FLAGS = None
status_reporter = None # used to report training status back to Ray
activation_fn = None # e.g. tf.nn.relu
activation_fn = None # e.g. tf.nn.relu
def deepnn(x):
@ -90,7 +89,7 @@ def deepnn(x):
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
h_fc1 = activation_fn(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
# Dropout - controls the complexity of the model, prevents co-adaptation of
@ -173,7 +172,10 @@ def main(_):
batch = mnist.train.next_batch(50)
if i % 10 == 0:
train_accuracy = accuracy.eval(feed_dict={
x: batch[0], y_: batch[1], keep_prob: 1.0})
x: batch[0],
y_: batch[1],
keep_prob: 1.0
})
# !!! Report status to ray.tune !!!
if status_reporter:
@ -181,11 +183,17 @@ def main(_):
timesteps_total=i, mean_accuracy=train_accuracy)
print('step %d, training accuracy %g' % (i, train_accuracy))
train_step.run(
feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
train_step.run(feed_dict={
x: batch[0],
y_: batch[1],
keep_prob: 0.5
})
print('test accuracy %g' % accuracy.eval(feed_dict={
x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
x: mnist.test.images,
y_: mnist.test.labels,
keep_prob: 1.0
}))
# !!! Entrypoint for ray.tune !!!
@ -195,7 +203,9 @@ def train(config={'activation': 'relu'}, reporter=None):
activation_fn = getattr(tf.nn, config['activation'])
parser = argparse.ArgumentParser()
parser.add_argument(
'--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',
'--data_dir',
type=str,
default='/tmp/tensorflow/mnist/input_data',
help='Directory for storing input data')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
@ -212,8 +222,8 @@ if __name__ == '__main__':
mnist_spec = {
'run': 'train_mnist',
'stop': {
'mean_accuracy': 0.99,
'time_total_s': 600,
'mean_accuracy': 0.99,
'time_total_s': 600,
},
'config': {
'activation': grid_search(['relu', 'elu', 'tanh']),
View file
@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A deep MNIST classifier using convolutional layers.
See extensive documentation at
https://www.tensorflow.org/get_started/mnist/pros
@ -39,7 +38,7 @@ from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import numpy as np
activation_fn = None # e.g. tf.nn.relu
activation_fn = None # e.g. tf.nn.relu
def setupCNN(x):
@ -85,7 +84,7 @@ def setupCNN(x):
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
h_fc1 = activation_fn(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
# Dropout - controls the complexity of the model, prevents co-adaptation of
@ -182,14 +181,18 @@ class TrainMNIST(Trainable):
self.sess.run(
self.train_step,
feed_dict={
self.x: batch[0], self.y_: batch[1], self.keep_prob: 0.5
self.x: batch[0],
self.y_: batch[1],
self.keep_prob: 0.5
})
batch = self.mnist.train.next_batch(50)
train_accuracy = self.sess.run(
self.accuracy,
feed_dict={
self.x: batch[0], self.y_: batch[1], self.keep_prob: 1.0
self.x: batch[0],
self.y_: batch[1],
self.keep_prob: 1.0
})
self.iterations += 1
@ -215,11 +218,11 @@ if __name__ == '__main__':
mnist_spec = {
'run': 'my_class',
'stop': {
'mean_accuracy': 0.99,
'time_total_s': 600,
'mean_accuracy': 0.99,
'time_total_s': 600,
},
'config': {
'learning_rate': lambda spec: 10 ** np.random.uniform(-5, -3),
'learning_rate': lambda spec: 10**np.random.uniform(-5, -3),
'activation': grid_search(['relu', 'elu', 'tanh']),
},
"repeat": 10,
@ -231,8 +234,6 @@ if __name__ == '__main__':
ray.init()
hyperband = HyperBandScheduler(
time_attr="timesteps_total", reward_attr="mean_accuracy",
max_t=100)
time_attr="timesteps_total", reward_attr="mean_accuracy", max_t=100)
run_experiments(
{'mnist_hyperband_test': mnist_spec}, scheduler=hyperband)
run_experiments({'mnist_hyperband_test': mnist_spec}, scheduler=hyperband)
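With grid_search over three activations and 'repeat': 10, this call launches 3 x 10 = 30 trials under the HyperBand scheduler.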
View file
@ -35,14 +35,26 @@ class Experiment(object):
checkpoint at least this many times. Only applies if
checkpointing is enabled. Defaults to 3.
"""
def __init__(self, name, run, stop=None, config=None,
trial_resources=None, repeat=1, local_dir=None,
upload_dir="", checkpoint_freq=0, max_failures=3):
def __init__(self,
name,
run,
stop=None,
config=None,
trial_resources=None,
repeat=1,
local_dir=None,
upload_dir="",
checkpoint_freq=0,
max_failures=3):
spec = {
"run": run,
"stop": stop or {},
"config": config or {},
"trial_resources": trial_resources or {"cpu": 1, "gpu": 0},
"trial_resources": trial_resources or {
"cpu": 1,
"gpu": 0
},
"repeat": repeat,
"local_dir": local_dir or DEFAULT_RESULTS_DIR,
"upload_dir": upload_dir,
View file
@ -91,8 +91,8 @@ class FunctionRunner(Trainable):
for k in self._default_config:
if k in scrubbed_config:
del scrubbed_config[k]
self._runner = _RunnerThread(
entrypoint, scrubbed_config, self._status_reporter)
self._runner = _RunnerThread(entrypoint, scrubbed_config,
self._status_reporter)
self._start_time = time.time()
self._last_reported_timestep = 0
self._runner.start()
@ -104,9 +104,8 @@ class FunctionRunner(Trainable):
def _train(self):
time.sleep(
self.config.get(
"script_min_iter_time_s",
self._default_config["script_min_iter_time_s"]))
self.config.get("script_min_iter_time_s",
self._default_config["script_min_iter_time_s"]))
result = self._status_reporter._get_and_clear_status()
while result is None:
time.sleep(1)
View file
@ -102,9 +102,8 @@ class HyperOptScheduler(FIFOScheduler):
self._hpopt_trials.refresh()
# Get new suggestion from HyperOpt.
new_trials = self.algo(
new_ids, self.domain, self._hpopt_trials,
self.rstate.randint(2 ** 31 - 1))
new_trials = self.algo(new_ids, self.domain, self._hpopt_trials,
self.rstate.randint(2**31 - 1))
self._hpopt_trials.insert_trial_docs(new_trials)
self._hpopt_trials.refresh()
new_trial = new_trials[0]
@ -112,8 +111,11 @@ class HyperOptScheduler(FIFOScheduler):
suggested_config = hpo.base.spec_from_misc(new_trial["misc"])
new_cfg.update(suggested_config)
kv_str = "_".join(["{}={}".format(k, str(v)[:5])
for k, v in sorted(suggested_config.items())])
kv_str = "_".join([
"{}={}".format(k,
str(v)[:5])
for k, v in sorted(suggested_config.items())
])
experiment_tag = "{}_{}".format(new_trial_id, kv_str)
# Keep this consistent with tune.variant_generator
@ -166,8 +168,7 @@ class HyperOptScheduler(FIFOScheduler):
del self._tune_to_hp[trial]
def _to_hyperopt_result(self, result):
return {"loss": -getattr(result, self._reward_attr),
"status": "ok"}
return {"loss": -getattr(result, self._reward_attr), "status": "ok"}
def _get_hyperopt_trial(self, tid):
return [t for t in self._hpopt_trials.trials if t["tid"] == tid][0]
@ -183,8 +184,9 @@ class HyperOptScheduler(FIFOScheduler):
experiments and trials left to run. If self._max_concurrent is None,
the scheduler will add a new trial if none are pending.
"""
pending = [t for t in trial_runner.get_trials()
if t.status == Trial.PENDING]
pending = [
t for t in trial_runner.get_trials() if t.status == Trial.PENDING
]
if self._num_trials_left <= 0:
return
if self._max_concurrent is None:
View file
@ -66,9 +66,10 @@ class HyperBandScheduler(FIFOScheduler):
mentioned in the original HyperBand paper.
"""
def __init__(
self, time_attr='training_iteration',
reward_attr='episode_reward_mean', max_t=81):
def __init__(self,
time_attr='training_iteration',
reward_attr='episode_reward_mean',
max_t=81):
assert max_t > 0, "Max (time_attr) not valid!"
FIFOScheduler.__init__(self)
self._eta = 3
@ -78,13 +79,12 @@ class HyperBandScheduler(FIFOScheduler):
self._get_n0 = lambda s: int(
np.ceil(self._s_max_1/(s+1) * self._eta**s))
# bracket initial iterations
self._get_r0 = lambda s: int((max_t*self._eta**(-s)))
self._get_r0 = lambda s: int((max_t * self._eta**(-s)))
self._hyperbands = [[]] # list of hyperband iterations
self._trial_info = {} # Stores Trial -> Bracket, Band Iteration
# Tracks state for new trial add
self._state = {"bracket": None,
"band_idx": 0}
self._state = {"bracket": None, "band_idx": 0}
self._num_stopped = 0
self._reward_attr = reward_attr
self._time_attr = time_attr
@ -116,9 +116,9 @@ class HyperBandScheduler(FIFOScheduler):
cur_bracket = None
else:
retry = False
cur_bracket = Bracket(
self._time_attr, self._get_n0(s), self._get_r0(s),
self._max_t_attr, self._eta, s)
cur_bracket = Bracket(self._time_attr, self._get_n0(s),
self._get_r0(s), self._max_t_attr,
self._eta, s)
cur_band.append(cur_bracket)
self._state["bracket"] = cur_bracket
@ -217,11 +217,11 @@ class HyperBandScheduler(FIFOScheduler):
"""
for hyperband in self._hyperbands:
for bracket in sorted(hyperband,
key=lambda b: b.completion_percentage()):
for bracket in sorted(
hyperband, key=lambda b: b.completion_percentage()):
for trial in bracket.current_trials():
if (trial.status == Trial.PENDING and
trial_runner.has_resources(trial.resources)):
if (trial.status == Trial.PENDING
and trial_runner.has_resources(trial.resources)):
return trial
return None
@ -258,6 +258,7 @@ class Bracket():
Also keeps track of progress to ensure good scheduling.
"""
def __init__(self, time_attr, max_trials, init_t_attr, max_t_attr, eta, s):
self._live_trials = {} # maps trial -> current result
self._all_trials = []
@ -287,8 +288,9 @@ class Bracket():
"""Checks if all iterations have completed.
TODO(rliaw): also check that `t.iterations == self._r`"""
return all(self._get_result_time(result) >= self._cumul_r
for result in self._live_trials.values())
return all(
self._get_result_time(result) >= self._cumul_r
for result in self._live_trials.values())
def finished(self):
return self._halves == 0 and self.cur_iter_done()
@ -379,7 +381,7 @@ class Bracket():
def _calculate_total_work(self, n, r, s):
work = 0
cumulative_r = r
for i in range(s+1):
for i in range(s + 1):
work += int(n) * int(r)
n /= self._eta
n = int(np.ceil(n))
@ -389,11 +391,11 @@ class Bracket():
def __repr__(self):
status = ", ".join([
"Max Size (n)={}".format(self._n),
"Milestone (r)={}".format(self._cumul_r),
"completed={:.1%}".format(self.completion_percentage())
])
"Max Size (n)={}".format(self._n), "Milestone (r)={}".format(
self._cumul_r), "completed={:.1%}".format(
self.completion_percentage())
])
counts = collections.Counter([t.status for t in self._all_trials])
trial_statuses = ", ".join(sorted(
["{}: {}".format(k, v) for k, v in counts.items()]))
trial_statuses = ", ".join(
sorted(["{}: {}".format(k, v) for k, v in counts.items()]))
return "Bracket({}): {{{}}} ".format(status, trial_statuses)
View file
@ -13,7 +13,6 @@ from ray.tune.cluster_info import get_ssh_key, get_ssh_user
from ray.tune.error import TuneError
from ray.tune.result import DEFAULT_RESULTS_DIR
# Map from (logdir, remote_dir) -> syncer
_syncers = {}
@ -69,9 +68,8 @@ class _LogSyncer(object):
def sync_now(self, force=False):
self.last_sync_time = time.time()
if not self.worker_ip:
print(
"Worker ip unknown, skipping log sync for {}".format(
self.local_dir))
print("Worker ip unknown, skipping log sync for {}".format(
self.local_dir))
return
if self.worker_ip == self.local_ip:
@ -80,23 +78,21 @@ class _LogSyncer(object):
ssh_key = get_ssh_key()
ssh_user = get_ssh_user()
if ssh_key is None or ssh_user is None:
print(
"Error: log sync requires cluster to be setup with "
"`ray create_or_update`.")
print("Error: log sync requires cluster to be setup with "
"`ray create_or_update`.")
return
if not distutils.spawn.find_executable("rsync"):
print("Error: log sync requires rsync to be installed.")
return
worker_to_local_sync_cmd = (
("""rsync -avz -e "ssh -i '{}' -o ConnectTimeout=120s """
"""-o StrictHostKeyChecking=no" '{}@{}:{}/' '{}/'""").format(
worker_to_local_sync_cmd = ((
"""rsync -avz -e "ssh -i '{}' -o ConnectTimeout=120s """
"""-o StrictHostKeyChecking=no" '{}@{}:{}/' '{}/'""").format(
ssh_key, ssh_user, self.worker_ip,
pipes.quote(self.local_dir), pipes.quote(self.local_dir)))
if self.remote_dir:
local_to_remote_sync_cmd = (
"aws s3 sync '{}' '{}'".format(
pipes.quote(self.local_dir), pipes.quote(self.remote_dir)))
local_to_remote_sync_cmd = ("aws s3 sync '{}' '{}'".format(
pipes.quote(self.local_dir), pipes.quote(self.remote_dir)))
else:
local_to_remote_sync_cmd = None
View file
@ -110,9 +110,9 @@ def to_tf_values(result, path):
for attr, value in result.items():
if value is not None:
if type(value) in [int, float]:
values.append(tf.Summary.Value(
tag="/".join(path + [attr]),
simple_value=value))
values.append(
tf.Summary.Value(
tag="/".join(path + [attr]), simple_value=value))
elif type(value) is dict:
values.extend(to_tf_values(value, path + [attr]))
return values
@ -125,8 +125,8 @@ class _TFLogger(Logger):
def on_result(self, result):
tmp = result._asdict()
for k in [
"config", "pid", "timestamp", "time_total_s",
"timesteps_total"]:
"config", "pid", "timestamp", "time_total_s", "timesteps_total"
]:
del tmp[k] # not useful to tf log these
values = to_tf_values(tmp, ["ray", "tune"])
train_stats = tf.Summary(value=values)
@ -165,9 +165,9 @@ class _CustomEncoder(json.JSONEncoder):
return repr(o) if not np.isnan(o) else nan_str
_iterencode = json.encoder._make_iterencode(
None, self.default, _encoder, self.indent, floatstr,
self.key_separator, self.item_separator, self.sort_keys,
self.skipkeys, _one_shot)
None, self.default, _encoder, self.indent, floatstr,
self.key_separator, self.item_separator, self.sort_keys,
self.skipkeys, _one_shot)
return _iterencode(o, 0)
def default(self, value):
View file
@ -32,10 +32,13 @@ class MedianStoppingRule(FIFOScheduler):
time a trial reports. Defaults to True.
"""
def __init__(
self, time_attr="time_total_s", reward_attr="episode_reward_mean",
grace_period=60.0, min_samples_required=3,
hard_stop=True, verbose=True):
def __init__(self,
time_attr="time_total_s",
reward_attr="episode_reward_mean",
grace_period=60.0,
min_samples_required=3,
hard_stop=True,
verbose=True):
FIFOScheduler.__init__(self)
self._stopped_trials = set()
self._completed_trials = set()
@ -103,9 +106,10 @@ class MedianStoppingRule(FIFOScheduler):
results = self._results[trial]
# TODO(ekl) we could do interpolation to be more precise, but for now
# assume len(results) is large and the time diffs are roughly equal
return np.mean(
[getattr(r, self._reward_attr)
for r in results if getattr(r, self._time_attr) <= t_max])
return np.mean([
getattr(r, self._reward_attr) for r in results
if getattr(r, self._time_attr) <= t_max
])
def _best_result(self, trial):
results = self._results[trial]
View file
@ -11,7 +11,6 @@ from ray.tune.trial import Trial
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
from ray.tune.variant_generator import _format_vars
# Parameters are transferred from the top PBT_QUANTILE fraction of trials to
# the bottom PBT_QUANTILE fraction.
PBT_QUANTILE = 0.25
@ -27,9 +26,8 @@ class PBTTrialState(object):
self.last_perturbation_time = 0
def __repr__(self):
return str((
self.last_score, self.last_checkpoint,
self.last_perturbation_time))
return str((self.last_score, self.last_checkpoint,
self.last_perturbation_time))
def explore(config, mutations, resample_probability, custom_explore_fn):
@ -51,12 +49,13 @@ def explore(config, mutations, resample_probability, custom_explore_fn):
config[key] not in distribution:
new_config[key] = random.choice(distribution)
elif random.random() > 0.5:
new_config[key] = distribution[
max(0, distribution.index(config[key]) - 1)]
new_config[key] = distribution[max(
0,
distribution.index(config[key]) - 1)]
else:
new_config[key] = distribution[
min(len(distribution) - 1,
distribution.index(config[key]) + 1)]
new_config[key] = distribution[min(
len(distribution) - 1,
distribution.index(config[key]) + 1)]
else:
if random.random() < resample_probability:
new_config[key] = distribution()
@ -70,8 +69,8 @@ def explore(config, mutations, resample_probability, custom_explore_fn):
new_config = custom_explore_fn(new_config)
assert new_config is not None, \
"Custom explore fn failed to return new config"
print(
"[explore] perturbed config from {} -> {}".format(config, new_config))
print("[explore] perturbed config from {} -> {}".format(
config, new_config))
return new_config
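A worked example of the list-valued branch above, with a hypothetical mutation spec:

mutations = {"batch_size": [128, 256, 512, 1024]}  # hypothetical spec
config = {"batch_size": 512}
# explore(config, mutations, 0.25, ...) either resamples batch_size via
# random.choice (probability 0.25) or steps it to a neighboring entry,
# 256 or 1024 here; the max(0, ...) and min(len(distribution) - 1, ...)
# guards clamp the step at the ends of the list.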
@ -148,10 +147,13 @@ class PopulationBasedTraining(FIFOScheduler):
>>> run_experiments({...}, scheduler=pbt)
"""
def __init__(
self, time_attr="time_total_s", reward_attr="episode_reward_mean",
perturbation_interval=60.0, hyperparam_mutations={},
resample_probability=0.25, custom_explore_fn=None):
def __init__(self,
time_attr="time_total_s",
reward_attr="episode_reward_mean",
perturbation_interval=60.0,
hyperparam_mutations={},
resample_probability=0.25,
custom_explore_fn=None):
if not hyperparam_mutations and not custom_explore_fn:
raise TuneError(
"You must specify at least one of `hyperparam_mutations` or "
@ -209,14 +211,13 @@ class PopulationBasedTraining(FIFOScheduler):
if not new_state.last_checkpoint:
print("[pbt] warn: no checkpoint for trial, skip exploit", trial)
return
new_config = explore(
trial_to_clone.config, self._hyperparam_mutations,
self._resample_probability, self._custom_explore_fn)
print(
"[exploit] transferring weights from trial "
"{} (score {}) -> {} (score {})".format(
trial_to_clone, new_state.last_score, trial,
trial_state.last_score))
new_config = explore(trial_to_clone.config, self._hyperparam_mutations,
self._resample_probability,
self._custom_explore_fn)
print("[exploit] transferring weights from trial "
"{} (score {}) -> {} (score {})".format(
trial_to_clone, new_state.last_score, trial,
trial_state.last_score))
# TODO(ekl) restarting the trial is expensive. We should implement a
# lighter-weight reset() method that can alter the trial config.
trial.stop(stop_logger=False)
@ -242,9 +243,8 @@ class PopulationBasedTraining(FIFOScheduler):
if len(trials) <= 1:
return [], []
else:
return (
trials[:int(math.ceil(len(trials)*PBT_QUANTILE))],
trials[int(math.floor(-len(trials)*PBT_QUANTILE)):])
return (trials[:int(math.ceil(len(trials) * PBT_QUANTILE))],
trials[int(math.floor(-len(trials) * PBT_QUANTILE)):])
def choose_trial_to_run(self, trial_runner):
"""Ensures all trials get fair share of time (as defined by time_attr).
View file
@ -14,7 +14,8 @@ ENV_CREATOR = "env_creator"
RLLIB_MODEL = "rllib_model"
RLLIB_PREPROCESSOR = "rllib_preprocessor"
KNOWN_CATEGORIES = [
TRAINABLE_CLASS, ENV_CREATOR, RLLIB_MODEL, RLLIB_PREPROCESSOR]
TRAINABLE_CLASS, ENV_CREATOR, RLLIB_MODEL, RLLIB_PREPROCESSOR
]
def register_trainable(name, trainable):
@ -32,8 +33,8 @@ def register_trainable(name, trainable):
if isinstance(trainable, FunctionType):
trainable = wrap_function(trainable)
if not issubclass(trainable, Trainable):
raise TypeError(
"Second argument must be convertable to Trainable", trainable)
raise TypeError("Second argument must be convertable to Trainable",
trainable)
_default_registry.register(TRAINABLE_CLASS, name, trainable)
@ -46,8 +47,7 @@ def register_env(name, env_creator):
"""
if not isinstance(env_creator, FunctionType):
raise TypeError(
"Second argument must be a function.", env_creator)
raise TypeError("Second argument must be a function.", env_creator)
_default_registry.register(ENV_CREATOR, name, env_creator)
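A hypothetical usage sketch of the two helpers above; the ray.tune.registry import path and the gym environment are assumptions for illustration:

from ray.tune.registry import register_env, register_trainable


def create_cartpole(env_config):
    import gym
    return gym.make("CartPole-v0")


register_env("my_cartpole", create_cartpole)  # must be a plain function


def my_train(config, reporter):
    reporter(timesteps_total=1)


register_trainable("my_trainable", my_train)  # functions are wrapped into a
                                              # Trainable via wrap_function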
View file
@ -4,8 +4,6 @@ from __future__ import print_function
from collections import namedtuple
import os
"""
When using ray.tune with custom training scripts, you must periodically report
training status back to Ray by calling reporter(result).
@ -18,73 +16,78 @@ In RLlib, the supplied algorithms fill in TrainingResult for you.
# Where ray.tune writes result files by default
DEFAULT_RESULTS_DIR = os.path.expanduser("~/ray_results")
TrainingResult = namedtuple(
"TrainingResult",
[
# (Required) Accumulated timesteps for this entire experiment.
"timesteps_total",
TrainingResult = namedtuple("TrainingResult", [
# (Required) Accumulated timesteps for this entire experiment.
"timesteps_total",
# (Optional) If training is terminated.
"done",
# (Optional) If training is terminated.
"done",
# (Optional) Custom metadata to report for this iteration.
"info",
# (Optional) Custom metadata to report for this iteration.
"info",
# (Optional) The mean episode reward if applicable.
"episode_reward_mean",
# (Optional) The mean episode reward if applicable.
"episode_reward_mean",
# (Optional) The mean episode length if applicable.
"episode_len_mean",
# (Optional) The mean episode length if applicable.
"episode_len_mean",
# (Optional) The number of episodes total.
"episodes_total",
# (Optional) The number of episodes total.
"episodes_total",
# (Optional) The current training accuracy if applicable.
"mean_accuracy",
# (Optional) The current training accuracy if applicable.
"mean_accuracy",
# (Optional) The current validation accuracy if applicable.
"mean_validation_accuracy",
# (Optional) The current validation accuracy if applicable.
"mean_validation_accuracy",
# (Optional) The current training loss if applicable.
"mean_loss",
# (Optional) The current training loss if applicable.
"mean_loss",
# (Auto-filled) The negated current training loss.
"neg_mean_loss",
# (Auto-filled) The negated current training loss.
"neg_mean_loss",
# (Auto-filled) Unique string identifier for this experiment.
# This id is preserved across checkpoint / restore calls.
"experiment_id",
# (Auto-filled) Unique string identifier for this experiment. This id is
# preserved across checkpoint / restore calls.
"experiment_id",
# (Auto-filled) The index of this training iteration,
# e.g. call to train().
"training_iteration",
# (Auto-filled) The index of this training iteration, e.g. call to train().
"training_iteration",
# (Auto-filled) Number of timesteps in the simulator
# in this iteration.
"timesteps_this_iter",
# (Auto-filled) Number of timesteps in the simulator in this iteration.
"timesteps_this_iter",
# (Auto-filled) Time in seconds this iteration took to run. This may
# be overridden in order to override the system-computed
# time difference.
"time_this_iter_s",
# (Auto-filled) Time in seconds this iteration took to run. This may be
# overridden in order to override the system-computed time difference.
"time_this_iter_s",
# (Auto-filled) Accumulated time in seconds for this entire experiment.
"time_total_s",
# (Auto-filled) Accumulated time in seconds for this entire experiment.
"time_total_s",
# (Auto-filled) The pid of the training process.
"pid",
# (Auto-filled) The pid of the training process.
"pid",
# (Auto-filled) A formatted date of when the result was processed.
"date",
# (Auto-filled) A formatted date of when the result was processed.
"date",
# (Auto-filled) A UNIX timestamp of when the result was processed.
"timestamp",
# (Auto-filled) A UNIX timestamp of when the result was processed.
"timestamp",
# (Auto-filled) The hostname of the machine hosting the
# training process.
"hostname",
# (Auto-filled) The hostname of the machine hosting the training process.
"hostname",
# (Auto-filled) The node ip of the machine hosting the
# training process.
"node_ip",
# (Auto-filled) The node ip of the machine hosting the training process.
"node_ip",
# (Auto-filled) The current hyperparameter configuration.
"config",
])
# (Auto-filled) The current hyperparameter configuration.
"config",
])
TrainingResult.__new__.__defaults__ = (None,) * len(TrainingResult._fields)
TrainingResult.__new__.__defaults__ = (None, ) * len(TrainingResult._fields)
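The trailing __defaults__ assignment is what makes every TrainingResult field optional, so partially filled results are valid; for example:

r = TrainingResult(timesteps_total=100, episode_reward_mean=42.0)
assert r.mean_loss is None  # every unspecified field defaults to None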
View file
@ -20,7 +20,9 @@ if __name__ == "__main__":
run_experiments({
"test": {
"run": "my_class",
"stop": {"training_iteration": 1}
"stop": {
"training_iteration": 1
}
}
})
assert 'ray.rllib' not in sys.modules, "RLlib should not be imported"
View file
@ -60,163 +60,209 @@ class TrainableFunctionApiTest(unittest.TestCase):
def testRewriteEnv(self):
def train(config, reporter):
reporter(timesteps_total=1)
register_trainable("f1", train)
[trial] = run_experiments({"foo": {
"run": "f1",
"env": "CartPole-v0",
}})
[trial] = run_experiments({
"foo": {
"run": "f1",
"env": "CartPole-v0",
}
})
self.assertEqual(trial.config["env"], "CartPole-v0")
def testConfigPurity(self):
def train(config, reporter):
assert config == {"a": "b"}, config
reporter(timesteps_total=1)
register_trainable("f1", train)
run_experiments({"foo": {
"run": "f1",
"config": {"a": "b"},
}})
run_experiments({
"foo": {
"run": "f1",
"config": {
"a": "b"
},
}
})
def testLogdir(self):
def train(config, reporter):
assert "/tmp/logdir/foo" in os.getcwd(), os.getcwd()
reporter(timesteps_total=1)
register_trainable("f1", train)
run_experiments({"foo": {
"run": "f1",
"local_dir": "/tmp/logdir",
"config": {"a": "b"},
}})
run_experiments({
"foo": {
"run": "f1",
"local_dir": "/tmp/logdir",
"config": {
"a": "b"
},
}
})
def testLongFilename(self):
def train(config, reporter):
assert "/tmp/logdir/foo" in os.getcwd(), os.getcwd()
reporter(timesteps_total=1)
register_trainable("f1", train)
run_experiments({"foo": {
"run": "f1",
"local_dir": "/tmp/logdir",
"config": {
"a" * 50: lambda spec: 5.0 / 7,
"b" * 50: lambda spec: "long" * 40},
}})
run_experiments({
"foo": {
"run": "f1",
"local_dir": "/tmp/logdir",
"config": {
"a" * 50: lambda spec: 5.0 / 7,
"b" * 50: lambda spec: "long" * 40
},
}
})
def testBadParams(self):
def f():
run_experiments({"foo": {}})
self.assertRaises(TuneError, f)
def testBadParams2(self):
def f():
run_experiments({"foo": {
"run": "asdf",
"bah": "this param is not allowed",
}})
run_experiments({
"foo": {
"run": "asdf",
"bah": "this param is not allowed",
}
})
self.assertRaises(TuneError, f)
def testBadParams3(self):
def f():
run_experiments({"foo": {
"run": grid_search("invalid grid search"),
}})
run_experiments({
"foo": {
"run": grid_search("invalid grid search"),
}
})
self.assertRaises(TuneError, f)
def testBadParams4(self):
def f():
run_experiments({"foo": {
"run": "asdf",
}})
run_experiments({
"foo": {
"run": "asdf",
}
})
self.assertRaises(TuneError, f)
def testBadParams5(self):
def f():
run_experiments({"foo": {
"run": "PPO",
"stop": {"asdf": 1}
}})
run_experiments({"foo": {"run": "PPO", "stop": {"asdf": 1}}})
self.assertRaises(TuneError, f)
def testBadParams6(self):
def f():
run_experiments({"foo": {
"run": "PPO",
"trial_resources": {"asdf": 1}
}})
run_experiments({
"foo": {
"run": "PPO",
"trial_resources": {
"asdf": 1
}
}
})
self.assertRaises(TuneError, f)
def testBadReturn(self):
def train(config, reporter):
reporter()
register_trainable("f1", train)
def f():
run_experiments({"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}})
run_experiments({
"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}
})
self.assertRaises(TuneError, f)
def testEarlyReturn(self):
def train(config, reporter):
reporter(timesteps_total=100, done=True)
time.sleep(99999)
register_trainable("f1", train)
[trial] = run_experiments({"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}})
[trial] = run_experiments({
"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}
})
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result.timesteps_total, 100)
def testAbruptReturn(self):
def train(config, reporter):
reporter(timesteps_total=100)
register_trainable("f1", train)
[trial] = run_experiments({"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}})
[trial] = run_experiments({
"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}
})
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result.timesteps_total, 100)
def testErrorReturn(self):
def train(config, reporter):
raise Exception("uh oh")
register_trainable("f1", train)
def f():
run_experiments({"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}})
run_experiments({
"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}
})
self.assertRaises(TuneError, f)
def testSuccess(self):
def train(config, reporter):
for i in range(100):
reporter(timesteps_total=i)
register_trainable("f1", train)
[trial] = run_experiments({"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}})
[trial] = run_experiments({
"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}
})
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result.timesteps_total, 99)
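
The tests above exercise the function-trainable API repeatedly; a condensed sketch of the happy path, assuming ray.init() has already been called and using the same names this test module imports:

def train(config, reporter):
    for i in range(100):
        reporter(timesteps_total=i)  # report progress each iteration

register_trainable("my_func", train)
[trial] = run_experiments({
    "demo": {
        "run": "my_func",
        "config": {
            "script_min_iter_time_s": 0,
        },
    }
})
assert trial.status == Trial.TERMINATED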
class RunExperimentTest(unittest.TestCase):
def setUp(self):
ray.init()
@ -228,6 +274,7 @@ class RunExperimentTest(unittest.TestCase):
def train(config, reporter):
for i in range(100):
reporter(timesteps_total=i)
register_trainable("f1", train)
trials = run_experiments({
"foo": {
@ -251,13 +298,14 @@ class RunExperimentTest(unittest.TestCase):
def train(config, reporter):
for i in range(100):
reporter(timesteps_total=i)
register_trainable("f1", train)
exp1 = Experiment(**{
"name": "foo",
"run": "f1",
"config": {
"script_min_iter_time_s": 0
}
"name": "foo",
"run": "f1",
"config": {
"script_min_iter_time_s": 0
}
})
[trial] = run_experiments(exp1)
self.assertEqual(trial.status, Trial.TERMINATED)
@ -267,20 +315,21 @@ class RunExperimentTest(unittest.TestCase):
def train(config, reporter):
for i in range(100):
reporter(timesteps_total=i)
register_trainable("f1", train)
exp1 = Experiment(**{
"name": "foo",
"run": "f1",
"config": {
"script_min_iter_time_s": 0
}
"name": "foo",
"run": "f1",
"config": {
"script_min_iter_time_s": 0
}
})
exp2 = Experiment(**{
"name": "bar",
"run": "f1",
"config": {
"script_min_iter_time_s": 0
}
"name": "bar",
"run": "f1",
"config": {
"script_min_iter_time_s": 0
}
})
trials = run_experiments([exp1, exp2])
for trial in trials:
@ -306,9 +355,8 @@ class VariantGeneratorTest(unittest.TestCase):
self.assertEqual(trials[0].trainable_name, "PPO")
self.assertEqual(trials[0].experiment_tag, "0")
self.assertEqual(trials[0].max_failures, 5)
self.assertEqual(
trials[0].local_dir,
os.path.join(DEFAULT_RESULTS_DIR, "tune-pong"))
self.assertEqual(trials[0].local_dir,
os.path.join(DEFAULT_RESULTS_DIR, "tune-pong"))
self.assertEqual(trials[1].experiment_tag, "1")
def testEval(self):
@ -392,11 +440,13 @@ class VariantGeneratorTest(unittest.TestCase):
trials = generate_trials({
"run": "PPO",
"config": {
"x": grid_search([
"x":
grid_search([
lambda spec: spec.config.y * 100,
lambda spec: spec.config.y * 200
]),
"y": lambda spec: 1,
"y":
lambda spec: 1,
},
})
trials = list(trials)
@ -406,12 +456,13 @@ class VariantGeneratorTest(unittest.TestCase):
def testRecursiveDep(self):
try:
list(generate_trials({
"run": "PPO",
"config": {
"foo": lambda spec: spec.config.foo,
},
}))
list(
generate_trials({
"run": "PPO",
"config": {
"foo": lambda spec: spec.config.foo,
},
}))
except RecursiveDependencyError as e:
assert "`foo` recursively depends on" in str(e), e
else:
@ -442,12 +493,15 @@ class TrialRunnerTest(unittest.TestCase):
register_trainable("f1", train)
experiments = {"foo": {
"run": "f1",
"config": {
"a" * 50: lambda spec: 5.0 / 7,
"b" * 50: lambda spec: "long" * 40},
}}
experiments = {
"foo": {
"run": "f1",
"config": {
"a" * 50: lambda spec: 5.0 / 7,
"b" * 50: lambda spec: "long" * 40
},
}
}
for name, spec in experiments.items():
for trial in generate_trials(spec, name):
@ -468,12 +522,12 @@ class TrialRunnerTest(unittest.TestCase):
ray.init(num_cpus=4, num_gpus=2)
runner = TrialRunner()
kwargs = {
"stopping_criterion": {"training_iteration": 1},
"stopping_criterion": {
"training_iteration": 1
},
"resources": Resources(cpu=1, gpu=0, extra_cpu=3, extra_gpu=1),
}
trials = [
Trial("__fake", **kwargs),
Trial("__fake", **kwargs)]
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
for t in trials:
runner.add_trial(t)
@ -489,12 +543,12 @@ class TrialRunnerTest(unittest.TestCase):
ray.init(num_cpus=4, num_gpus=1)
runner = TrialRunner()
kwargs = {
"stopping_criterion": {"training_iteration": 1},
"stopping_criterion": {
"training_iteration": 1
},
"resources": Resources(cpu=1, gpu=1),
}
trials = [
Trial("__fake", **kwargs),
Trial("__fake", **kwargs)]
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
for t in trials:
runner.add_trial(t)
@ -518,12 +572,12 @@ class TrialRunnerTest(unittest.TestCase):
ray.init(num_cpus=4, num_gpus=2)
runner = TrialRunner()
kwargs = {
"stopping_criterion": {"training_iteration": 5},
"stopping_criterion": {
"training_iteration": 5
},
"resources": Resources(cpu=1, gpu=1),
}
trials = [
Trial("__fake", **kwargs),
Trial("__fake", **kwargs)]
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
for t in trials:
runner.add_trial(t)
@ -547,13 +601,13 @@ class TrialRunnerTest(unittest.TestCase):
ray.init(num_cpus=4, num_gpus=2)
runner = TrialRunner()
kwargs = {
"stopping_criterion": {"training_iteration": 1},
"stopping_criterion": {
"training_iteration": 1
},
"resources": Resources(cpu=1, gpu=1),
}
_default_registry.register(TRAINABLE_CLASS, "asdf", None)
trials = [
Trial("asdf", **kwargs),
Trial("__fake", **kwargs)]
trials = [Trial("asdf", **kwargs), Trial("__fake", **kwargs)]
for t in trials:
runner.add_trial(t)
@ -644,7 +698,9 @@ class TrialRunnerTest(unittest.TestCase):
ray.init(num_cpus=1, num_gpus=1)
runner = TrialRunner()
kwargs = {
"stopping_criterion": {"training_iteration": 1},
"stopping_criterion": {
"training_iteration": 1
},
"resources": Resources(cpu=1, gpu=1),
}
runner.add_trial(Trial("__fake", **kwargs))
@ -675,7 +731,9 @@ class TrialRunnerTest(unittest.TestCase):
ray.init(num_cpus=1, num_gpus=1)
runner = TrialRunner()
kwargs = {
"stopping_criterion": {"training_iteration": 2},
"stopping_criterion": {
"training_iteration": 2
},
"resources": Resources(cpu=1, gpu=1),
}
runner.add_trial(Trial("__fake", **kwargs))
@ -692,7 +750,9 @@ class TrialRunnerTest(unittest.TestCase):
ray.init(num_cpus=1, num_gpus=1)
runner = TrialRunner()
kwargs = {
"stopping_criterion": {"training_iteration": 2},
"stopping_criterion": {
"training_iteration": 2
},
"resources": Resources(cpu=1, gpu=1),
}
runner.add_trial(Trial("__fake", **kwargs))
@ -721,14 +781,17 @@ class TrialRunnerTest(unittest.TestCase):
ray.init(num_cpus=4, num_gpus=2)
runner = TrialRunner()
kwargs = {
"stopping_criterion": {"training_iteration": 5},
"stopping_criterion": {
"training_iteration": 5
},
"resources": Resources(cpu=1, gpu=1),
}
trials = [
Trial("__fake", **kwargs),
Trial("__fake", **kwargs),
Trial("__fake", **kwargs),
Trial("__fake", **kwargs)]
Trial("__fake", **kwargs)
]
for t in trials:
runner.add_trial(t)
runner.step()
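
Taken together, the TrialRunner tests above follow one pattern; a sketch of that loop, assuming ray.init() has been called and using the Trial, Resources, and TrialRunner names this test module imports:

runner = TrialRunner()  # defaults to FIFO scheduling
kwargs = {
    "stopping_criterion": {
        "training_iteration": 1
    },
    "resources": Resources(cpu=1, gpu=0),
}
runner.add_trial(Trial("__fake", **kwargs))
while not runner.is_finished():
    runner.step()  # launch trials and process one event per call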

View file

@ -19,9 +19,8 @@ _register_all()
def result(t, rew):
return TrainingResult(time_total_s=t,
episode_reward_mean=rew,
training_iteration=int(t))
return TrainingResult(
time_total_s=t, episode_reward_mean=rew, training_iteration=int(t))
class EarlyStoppingSuite(unittest.TestCase):
@ -76,8 +75,7 @@ class EarlyStoppingSuite(unittest.TestCase):
rule.on_trial_result(None, t3, result(2, 10)),
TrialScheduler.CONTINUE)
self.assertEqual(
rule.on_trial_result(None, t3, result(3, 10)),
TrialScheduler.STOP)
rule.on_trial_result(None, t3, result(3, 10)), TrialScheduler.STOP)
def testMedianStoppingMinSamples(self):
rule = MedianStoppingRule(grace_period=0, min_samples_required=2)
@ -89,8 +87,7 @@ class EarlyStoppingSuite(unittest.TestCase):
TrialScheduler.CONTINUE)
rule.on_trial_complete(None, t2, result(10, 1000))
self.assertEqual(
rule.on_trial_result(None, t3, result(3, 10)),
TrialScheduler.STOP)
rule.on_trial_result(None, t3, result(3, 10)), TrialScheduler.STOP)
def testMedianStoppingUsesMedian(self):
rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
@ -124,8 +121,10 @@ class EarlyStoppingSuite(unittest.TestCase):
return TrainingResult(training_iteration=t, neg_mean_loss=rew)
rule = MedianStoppingRule(
grace_period=0, min_samples_required=1,
time_attr='training_iteration', reward_attr='neg_mean_loss')
grace_period=0,
min_samples_required=1,
time_attr='training_iteration',
reward_attr='neg_mean_loss')
t1 = Trial("PPO") # mean is 450, max 900, t_max=10
t2 = Trial("PPO") # mean is 450, max 450, t_max=5
for i in range(10):
@ -185,7 +184,6 @@ class _MockTrialRunner():
class HyperbandSuite(unittest.TestCase):
def schedulerSetup(self, num_trials):
"""Setup a scheduler and Runner with max Iter = 9
@ -206,7 +204,10 @@ class HyperbandSuite(unittest.TestCase):
"""Default statistics for HyperBand"""
sched = HyperBandScheduler()
res = {
str(s): {"n": sched._get_n0(s), "r": sched._get_r0(s)}
str(s): {
"n": sched._get_n0(s),
"r": sched._get_r0(s)
}
for s in range(sched._s_max_1)
}
res["max_trials"] = sum(v["n"] for v in res.values())
@ -298,8 +299,8 @@ class HyperbandSuite(unittest.TestCase):
# Provides results from 0 to 8 in order, keeping last one running
for i, trl in enumerate(trials):
action = sched.on_trial_result(
mock_runner, trl, result(cur_units, i))
action = sched.on_trial_result(mock_runner, trl,
result(cur_units, i))
if i < current_length - 1:
self.assertEqual(action, TrialScheduler.PAUSE)
mock_runner.process_action(trl, action)
@ -321,8 +322,8 @@ class HyperbandSuite(unittest.TestCase):
# # Provides result in reverse order, killing the last one
cur_units = stats[str(1)]["r"]
for i, trl in reversed(list(enumerate(big_bracket.current_trials()))):
action = sched.on_trial_result(
mock_runner, trl, result(cur_units, i))
action = sched.on_trial_result(mock_runner, trl,
result(cur_units, i))
mock_runner.process_action(trl, action)
self.assertEqual(action, TrialScheduler.STOP)
@ -338,8 +339,8 @@ class HyperbandSuite(unittest.TestCase):
# # Provides result in reverse order, killing the last one
cur_units = stats[str(0)]["r"]
for i, trl in enumerate(big_bracket.current_trials()):
action = sched.on_trial_result(
mock_runner, trl, result(cur_units, i))
action = sched.on_trial_result(mock_runner, trl,
result(cur_units, i))
mock_runner.process_action(trl, action)
self.assertEqual(action, TrialScheduler.STOP)
@ -354,14 +355,12 @@ class HyperbandSuite(unittest.TestCase):
mock_runner._launch_trial(t)
sched.on_trial_error(mock_runner, t3)
self.assertEqual(
TrialScheduler.PAUSE,
sched.on_trial_result(
mock_runner, t1, result(stats[str(1)]["r"], 10)))
self.assertEqual(
TrialScheduler.CONTINUE,
sched.on_trial_result(
mock_runner, t2, result(stats[str(1)]["r"], 10)))
self.assertEqual(TrialScheduler.PAUSE,
sched.on_trial_result(mock_runner, t1,
result(stats[str(1)]["r"], 10)))
self.assertEqual(TrialScheduler.CONTINUE,
sched.on_trial_result(mock_runner, t2,
result(stats[str(1)]["r"], 10)))
def testTrialErrored2(self):
"""Check successive halving happened even when last trial failed"""
@ -371,13 +370,14 @@ class HyperbandSuite(unittest.TestCase):
trials = sched._state["bracket"].current_trials()
for t in trials[:-1]:
mock_runner._launch_trial(t)
sched.on_trial_result(
mock_runner, t, result(stats[str(1)]["r"], 10))
sched.on_trial_result(mock_runner, t, result(
stats[str(1)]["r"], 10))
mock_runner._launch_trial(trials[-1])
sched.on_trial_error(mock_runner, trials[-1])
self.assertEqual(len(sched._state["bracket"].current_trials()),
self.downscale(stats[str(1)]["n"], sched))
self.assertEqual(
len(sched._state["bracket"].current_trials()),
self.downscale(stats[str(1)]["n"], sched))
def testTrialEndedEarly(self):
"""Check successive halving happened even when one trial failed"""
@ -390,14 +390,12 @@ class HyperbandSuite(unittest.TestCase):
mock_runner._launch_trial(t)
sched.on_trial_complete(mock_runner, t3, result(1, 12))
self.assertEqual(
TrialScheduler.PAUSE,
sched.on_trial_result(
mock_runner, t1, result(stats[str(1)]["r"], 10)))
self.assertEqual(
TrialScheduler.CONTINUE,
sched.on_trial_result(
mock_runner, t2, result(stats[str(1)]["r"], 10)))
self.assertEqual(TrialScheduler.PAUSE,
sched.on_trial_result(mock_runner, t1,
result(stats[str(1)]["r"], 10)))
self.assertEqual(TrialScheduler.CONTINUE,
sched.on_trial_result(mock_runner, t2,
result(stats[str(1)]["r"], 10)))
def testTrialEndedEarly2(self):
"""Check successive halving happened even when last trial failed"""
@ -407,13 +405,14 @@ class HyperbandSuite(unittest.TestCase):
trials = sched._state["bracket"].current_trials()
for t in trials[:-1]:
mock_runner._launch_trial(t)
sched.on_trial_result(
mock_runner, t, result(stats[str(1)]["r"], 10))
sched.on_trial_result(mock_runner, t, result(
stats[str(1)]["r"], 10))
mock_runner._launch_trial(trials[-1])
sched.on_trial_complete(mock_runner, trials[-1], result(100, 12))
self.assertEqual(len(sched._state["bracket"].current_trials()),
self.downscale(stats[str(1)]["n"], sched))
self.assertEqual(
len(sched._state["bracket"].current_trials()),
self.downscale(stats[str(1)]["n"], sched))
def testAddAfterHalving(self):
stats = self.default_statistics()
@ -426,8 +425,8 @@ class HyperbandSuite(unittest.TestCase):
mock_runner._launch_trial(t)
for i, t in enumerate(bracket_trials):
action = sched.on_trial_result(
mock_runner, t, result(init_units, i))
action = sched.on_trial_result(mock_runner, t, result(
init_units, i))
self.assertEqual(action, TrialScheduler.CONTINUE)
t = Trial("__fake")
sched.on_trial_add(None, t)
@ -435,13 +434,13 @@ class HyperbandSuite(unittest.TestCase):
self.assertEqual(len(sched._state["bracket"].current_trials()), 2)
# Make sure that newly added trial gets fair computation (not just 1)
self.assertEqual(
TrialScheduler.CONTINUE,
sched.on_trial_result(mock_runner, t, result(init_units, 12)))
self.assertEqual(TrialScheduler.CONTINUE,
sched.on_trial_result(mock_runner, t,
result(init_units, 12)))
new_units = init_units + int(init_units * sched._eta)
self.assertEqual(
TrialScheduler.PAUSE,
sched.on_trial_result(mock_runner, t, result(new_units, 12)))
self.assertEqual(TrialScheduler.PAUSE,
sched.on_trial_result(mock_runner, t,
result(new_units, 12)))
def testAlternateMetrics(self):
"""Checking that alternate metrics will pass."""
@ -539,7 +538,6 @@ class _MockTrial(Trial):
class PopulationBasedTestingSuite(unittest.TestCase):
def basicSetup(self, resample_prob=0.0, explore=None):
pbt = PopulationBasedTraining(
time_attr="training_iteration",
@ -554,9 +552,12 @@ class PopulationBasedTestingSuite(unittest.TestCase):
runner = _MockTrialRunner(pbt)
for i in range(5):
trial = _MockTrial(
i,
{"id_factor": i, "float_factor": 2.0, "const_factor": 3,
"int_factor": 10})
i, {
"id_factor": i,
"float_factor": 2.0,
"const_factor": 3,
"int_factor": 10
})
runner.add_trial(trial)
trial.status = Trial.RUNNING
self.assertEqual(
@ -570,27 +571,23 @@ class PopulationBasedTestingSuite(unittest.TestCase):
trials = runner.get_trials()
# no checkpoint: haven't hit next perturbation interval yet
self.assertEqual(
pbt.last_scores(trials), [0, 50, 100, 150, 200])
self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(15, 200)),
TrialScheduler.CONTINUE)
self.assertEqual(
pbt.last_scores(trials), [0, 50, 100, 150, 200])
self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
self.assertEqual(pbt._num_checkpoints, 0)
# checkpoint: both past interval and upper quantile
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(20, 200)),
TrialScheduler.CONTINUE)
self.assertEqual(
pbt.last_scores(trials), [200, 50, 100, 150, 200])
self.assertEqual(pbt.last_scores(trials), [200, 50, 100, 150, 200])
self.assertEqual(pbt._num_checkpoints, 1)
self.assertEqual(
pbt.on_trial_result(runner, trials[1], result(30, 201)),
TrialScheduler.CONTINUE)
self.assertEqual(
pbt.last_scores(trials), [200, 201, 100, 150, 200])
self.assertEqual(pbt.last_scores(trials), [200, 201, 100, 150, 200])
self.assertEqual(pbt._num_checkpoints, 2)
# not upper quantile any more
@ -608,8 +605,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(15, -100)),
TrialScheduler.CONTINUE)
self.assertEqual(
pbt.last_scores(trials), [0, 50, 100, 150, 200])
self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
self.assertTrue("@perturbed" not in trials[0].experiment_tag)
self.assertEqual(pbt._num_perturbations, 0)
@ -617,8 +613,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(20, -100)),
TrialScheduler.CONTINUE)
self.assertEqual(
pbt.last_scores(trials), [-100, 50, 100, 150, 200])
self.assertEqual(pbt.last_scores(trials), [-100, 50, 100, 150, 200])
self.assertTrue("@perturbed" in trials[0].experiment_tag)
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
self.assertEqual(pbt._num_perturbations, 1)
@ -627,8 +622,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
self.assertEqual(
pbt.on_trial_result(runner, trials[2], result(20, 40)),
TrialScheduler.CONTINUE)
self.assertEqual(
pbt.last_scores(trials), [-100, 50, 40, 150, 200])
self.assertEqual(pbt.last_scores(trials), [-100, 50, 40, 150, 200])
self.assertEqual(pbt._num_perturbations, 2)
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
self.assertTrue("@perturbed" in trials[2].experiment_tag)
@ -662,7 +656,6 @@ class PopulationBasedTestingSuite(unittest.TestCase):
self.assertEqual(trials[0].config["const_factor"], 3)
def testPerturbationValues(self):
def assertProduces(fn, values):
random.seed(0)
seen = set()
@ -712,8 +705,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
self.assertEqual(
pbt.on_trial_result(runner, trials[1], result(20, 1000)),
TrialScheduler.PAUSE)
self.assertEqual(
pbt.last_scores(trials), [0, 1000, 100, 150, 200])
self.assertEqual(pbt.last_scores(trials), [0, 1000, 100, 150, 200])
self.assertEqual(pbt.choose_trial_to_run(runner), trials[0])
def testSchedulesMostBehindTrialToRun(self):
@ -748,6 +740,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
new_config["id_factor"] = 42
new_config["float_factor"] = 43
return new_config
pbt, runner = self.basicSetup(resample_prob=0.0, explore=explore)
trials = runner.get_trials()
self.assertEqual(
@ -774,8 +767,7 @@ class AsyncHyperBandSuite(unittest.TestCase):
return t1, t2
def testAsyncHBOnComplete(self):
scheduler = AsyncHyperBandScheduler(
max_t=10, brackets=1)
scheduler = AsyncHyperBandScheduler(max_t=10, brackets=1)
t1, t2 = self.basicSetup(scheduler)
t3 = Trial("PPO")
scheduler.on_trial_add(None, t3)
@ -803,8 +795,7 @@ class AsyncHyperBandSuite(unittest.TestCase):
TrialScheduler.STOP)
def testAsyncHBAllCompletes(self):
scheduler = AsyncHyperBandScheduler(
max_t=10, brackets=10)
scheduler = AsyncHyperBandScheduler(max_t=10, brackets=10)
trials = [Trial("PPO") for i in range(10)]
for t in trials:
scheduler.on_trial_add(None, t)
@ -834,8 +825,10 @@ class AsyncHyperBandSuite(unittest.TestCase):
return TrainingResult(training_iteration=t, neg_mean_loss=rew)
scheduler = AsyncHyperBandScheduler(
grace_period=1, time_attr='training_iteration',
reward_attr='neg_mean_loss', brackets=1)
grace_period=1,
time_attr='training_iteration',
reward_attr='neg_mean_loss',
brackets=1)
t1 = Trial("PPO") # mean is 450, max 900, t_max=10
t2 = Trial("PPO") # mean is 450, max 450, t_max=5
scheduler.on_trial_add(None, t1)
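
A sketch of the scheduler construction these AsyncHyperBand tests rely on, combining the keyword arguments that appear above; the import is whatever this test module already uses:

scheduler = AsyncHyperBandScheduler(
    grace_period=1,
    time_attr='training_iteration',
    reward_attr='neg_mean_loss',
    brackets=1,
    max_t=10)
scheduler.on_trial_add(None, Trial("PPO"))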

View file

@ -30,16 +30,15 @@ class TuneServerSuite(unittest.TestCase):
def basicSetup(self):
ray.init(num_cpus=4, num_gpus=1)
port = get_valid_port()
self.runner = TrialRunner(
launch_web_server=True, server_port=port)
self.runner = TrialRunner(launch_web_server=True, server_port=port)
runner = self.runner
kwargs = {
"stopping_criterion": {"training_iteration": 3},
"stopping_criterion": {
"training_iteration": 3
},
"resources": Resources(cpu=1, gpu=1),
}
trials = [
Trial("__fake", **kwargs),
Trial("__fake", **kwargs)]
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
for t in trials:
runner.add_trial(t)
client = TuneClient("localhost:{}".format(port))
@ -61,7 +60,9 @@ class TuneServerSuite(unittest.TestCase):
runner.step()
spec = {
"run": "__fake",
"stop": {"training_iteration": 3},
"stop": {
"training_iteration": 3
},
"trial_resources": dict(cpu=1, gpu=1),
}
client.add_trial("test", spec)

View file

@ -114,8 +114,8 @@ class Trainable(object):
time_this_iter = time.time() - start
if result.timesteps_this_iter is None:
raise TuneError(
"Must specify timesteps_this_iter in result", result)
raise TuneError("Must specify timesteps_this_iter in result",
result)
self._time_total += time_this_iter
self._timesteps_total += result.timesteps_this_iter
@ -159,10 +159,10 @@ class Trainable(object):
"""
checkpoint_path = self._save(checkpoint_dir or self.logdir)
pickle.dump(
[self._experiment_id, self._iteration, self._timesteps_total,
self._time_total],
open(checkpoint_path + ".tune_metadata", "wb"))
pickle.dump([
self._experiment_id, self._iteration, self._timesteps_total,
self._time_total
], open(checkpoint_path + ".tune_metadata", "wb"))
return checkpoint_path
def save_to_object(self):
@ -186,8 +186,10 @@ class Trainable(object):
out = io.BytesIO()
with gzip.GzipFile(fileobj=out, mode="wb") as f:
compressed = pickle.dumps({
"checkpoint_name": os.path.basename(checkpoint_prefix),
"data": data,
"checkpoint_name":
os.path.basename(checkpoint_prefix),
"data":
data,
})
if len(compressed) > 10e6: # getting pretty large
print("Checkpoint size is {} bytes".format(len(compressed)))

View file

@ -42,12 +42,12 @@ class Resources(
__slots__ = ()
def __new__(cls, cpu, gpu, extra_cpu=0, extra_gpu=0):
return super(Resources, cls).__new__(
cls, cpu, gpu, extra_cpu, extra_gpu)
return super(Resources, cls).__new__(cls, cpu, gpu, extra_cpu,
extra_gpu)
def summary_string(self):
return "{} CPUs, {} GPUs".format(
self.cpu + self.extra_cpu, self.gpu + self.extra_gpu)
return "{} CPUs, {} GPUs".format(self.cpu + self.extra_cpu,
self.gpu + self.extra_gpu)
def cpu_total(self):
return self.cpu + self.extra_cpu
@ -77,11 +77,17 @@ class Trial(object):
TERMINATED = "TERMINATED"
ERROR = "ERROR"
def __init__(
self, trainable_name, config=None, local_dir=DEFAULT_RESULTS_DIR,
experiment_tag="", resources=Resources(cpu=1, gpu=0),
stopping_criterion=None, checkpoint_freq=0,
restore_path=None, upload_dir=None, max_failures=0):
def __init__(self,
trainable_name,
config=None,
local_dir=DEFAULT_RESULTS_DIR,
experiment_tag="",
resources=Resources(cpu=1, gpu=0),
stopping_criterion=None,
checkpoint_freq=0,
restore_path=None,
upload_dir=None,
max_failures=0):
"""Initialize a new trial.
The args here take the same meaning as the command line flags defined
@ -166,19 +172,20 @@ class Trial(object):
try:
if error_msg and self.logdir:
self.num_failures += 1
error_file = os.path.join(
self.logdir, "error_{}.txt".format(date_str()))
error_file = os.path.join(self.logdir, "error_{}.txt".format(
date_str()))
with open(error_file, "w") as f:
f.write(error_msg)
self.error_file = error_file
if self.runner:
stop_tasks = []
stop_tasks.append(self.runner.stop.remote())
stop_tasks.append(self.runner.__ray_terminate__.remote(
self.runner._ray_actor_id.id()))
stop_tasks.append(
self.runner.__ray_terminate__.remote(
self.runner._ray_actor_id.id()))
# TODO(ekl) seems like wait hangs when killing actors
_, unfinished = ray.wait(
stop_tasks, num_returns=2, timeout=250)
stop_tasks, num_returns=2, timeout=250)
except Exception:
print("Error stopping runner:", traceback.format_exc())
self.status = Trial.ERROR
@ -252,12 +259,12 @@ class Trial(object):
return '{} pid={}'.format(hostname, pid)
pieces = [
'{} [{}]'.format(
self._status_string(),
location_string(
self.last_result.hostname, self.last_result.pid)),
'{} s'.format(int(self.last_result.time_total_s)),
'{} ts'.format(int(self.last_result.timesteps_total))]
'{} [{}]'.format(self._status_string(),
location_string(self.last_result.hostname,
self.last_result.pid)),
'{} s'.format(int(self.last_result.time_total_s)), '{} ts'.format(
int(self.last_result.timesteps_total))
]
if self.last_result.episode_reward_mean is not None:
pieces.append('{} rew'.format(
@ -274,10 +281,8 @@ class Trial(object):
return ', '.join(pieces)
def _status_string(self):
return "{}{}".format(
self.status,
", {} failures: {}".format(self.num_failures, self.error_file)
if self.error_file else "")
return "{}{}".format(self.status, ", {} failures: {}".format(
self.num_failures, self.error_file) if self.error_file else "")
def has_checkpoint(self):
return self._checkpoint_path is not None or \
@ -335,9 +340,8 @@ class Trial(object):
def update_last_result(self, result, terminate=False):
if terminate:
result = result._replace(done=True)
if self.verbose and (
terminate or
time.time() - self.last_debug > DEBUG_PRINT_INTERVAL):
if self.verbose and (terminate or time.time() - self.last_debug >
DEBUG_PRINT_INTERVAL):
print("TrainingResult for {}:".format(self))
print(" {}".format(pretty_print(result).replace("\n", "\n ")))
self.last_debug = time.time()
@ -358,8 +362,8 @@ class Trial(object):
prefix="{}_{}".format(
str(self)[:MAX_LEN_IDENTIFIER], date_str()),
dir=self.local_dir)
self.result_logger = UnifiedLogger(
self.config, self.logdir, self.upload_dir)
self.result_logger = UnifiedLogger(self.config, self.logdir,
self.upload_dir)
remote_logdir = self.logdir
def logger_creator(config):
@ -372,7 +376,8 @@ class Trial(object):
# Logging for trials is handled centrally by TrialRunner, so
# configure the remote runner to use a noop-logger.
self.runner = cls.remote(
config=self.config, registry=ray.tune.registry.get_registry(),
config=self.config,
registry=ray.tune.registry.get_registry(),
logger_creator=logger_creator)
def set_verbose(self, verbose):
@ -387,8 +392,8 @@ class Trial(object):
def __str__(self):
"""Combines ``env`` with ``trainable_name`` and ``experiment_tag``."""
if "env" in self.config:
identifier = "{}_{}".format(
self.trainable_name, self.config["env"])
identifier = "{}_{}".format(self.trainable_name,
self.config["env"])
else:
identifier = self.trainable_name
if self.experiment_tag:

View file

@ -13,7 +13,6 @@ from ray.tune.web_server import TuneServer
from ray.tune.trial import Trial, Resources
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
MAX_DEBUG_TRIALS = 20
@ -39,8 +38,11 @@ class TrialRunner(object):
misleading benchmark results.
"""
def __init__(self, scheduler=None, launch_web_server=False,
server_port=TuneServer.DEFAULT_PORT, verbose=True):
def __init__(self,
scheduler=None,
launch_web_server=False,
server_port=TuneServer.DEFAULT_PORT,
verbose=True):
"""Initializes a new TrialRunner.
Args:
@ -73,9 +75,8 @@ class TrialRunner(object):
"""Returns whether all trials have finished running."""
if self._total_time > self._global_time_limit:
print(
"Exceeded global time limit {} / {}".format(
self._total_time, self._global_time_limit))
print("Exceeded global time limit {} / {}".format(
self._total_time, self._global_time_limit))
return True
for t in self._trials:
@ -98,12 +99,12 @@ class TrialRunner(object):
for trial in self._trials:
if trial.status == Trial.PENDING:
if not self.has_resources(trial.resources):
raise TuneError((
"Insufficient cluster resources to launch trial: "
"trial requested {} but the cluster only has {} "
"available.").format(
trial.resources.summary_string(),
self._avail_resources.summary_string()))
raise TuneError(
("Insufficient cluster resources to launch trial: "
"trial requested {} but the cluster only has {} "
"available.").format(
trial.resources.summary_string(),
self._avail_resources.summary_string()))
elif trial.status == Trial.PAUSED:
raise TuneError(
"There are paused trials, but no more pending "
@ -165,24 +166,20 @@ class TrialRunner(object):
for state, trials in sorted(states.items()):
limit = limit_per_state[state]
messages.append("{} trials:".format(state))
for t in sorted(
trials, key=lambda t: t.experiment_tag)[:limit]:
for t in sorted(trials, key=lambda t: t.experiment_tag)[:limit]:
messages.append(" - {}:\t{}".format(t, t.progress_string()))
if len(trials) > limit:
messages.append(" ... {} more not shown".format(
len(trials) - limit))
messages.append(
" ... {} more not shown".format(len(trials) - limit))
return "\n".join(messages) + "\n"
def _debug_messages(self):
messages = ["== Status =="]
messages.append(self._scheduler_alg.debug_string())
if self._resources_initialized:
messages.append(
"Resources used: {}/{} CPUs, {}/{} GPUs".format(
self._committed_resources.cpu,
self._avail_resources.cpu,
self._committed_resources.gpu,
self._avail_resources.gpu))
messages.append("Resources used: {}/{} CPUs, {}/{} GPUs".format(
self._committed_resources.cpu, self._avail_resources.cpu,
self._committed_resources.gpu, self._avail_resources.gpu))
return messages
def has_resources(self, resources):
@ -190,9 +187,8 @@ class TrialRunner(object):
cpu_avail = self._avail_resources.cpu - self._committed_resources.cpu
gpu_avail = self._avail_resources.gpu - self._committed_resources.gpu
return (
resources.cpu_total() <= cpu_avail and
resources.gpu_total() <= gpu_avail)
return (resources.cpu_total() <= cpu_avail
and resources.gpu_total() <= gpu_avail)
def _get_next_trial(self):
self._update_avail_resources()
@ -307,8 +303,9 @@ class TrialRunner(object):
self._scheduler_alg.on_trial_remove(self, trial)
elif trial.status is Trial.RUNNING:
# NOTE: There should only be one...
result_id = [rid for rid, t in self._running.items()
if t is trial][0]
result_id = [
rid for rid, t in self._running.items() if t is trial
][0]
self._running.pop(result_id)
try:
result = ray.get(result_id)
@ -339,9 +336,8 @@ class TrialRunner(object):
def _update_avail_resources(self):
clients = ray.global_state.client_table()
local_schedulers = [
entry for client in clients.values() for entry in client
if (entry['ClientType'] == 'local_scheduler' and not
entry['Deleted'])
entry for client in clients.values() for entry in client if
(entry['ClientType'] == 'local_scheduler' and not entry['Deleted'])
]
num_cpus = sum(ls['CPU'] for ls in local_schedulers)
num_gpus = sum(ls.get('GPU', 0) for ls in local_schedulers)

View file

@ -99,12 +99,12 @@ class FIFOScheduler(TrialScheduler):
def choose_trial_to_run(self, trial_runner):
for trial in trial_runner.get_trials():
if (trial.status == Trial.PENDING and
trial_runner.has_resources(trial.resources)):
if (trial.status == Trial.PENDING
and trial_runner.has_resources(trial.resources)):
return trial
for trial in trial_runner.get_trials():
if (trial.status == Trial.PAUSED and
trial_runner.has_resources(trial.resources)):
if (trial.status == Trial.PAUSED
and trial_runner.has_resources(trial.resources)):
return trial
return None

View file

@ -16,7 +16,6 @@ from ray.tune.trial_scheduler import FIFOScheduler
from ray.tune.web_server import TuneServer
from ray.tune.experiment import Experiment
_SCHEDULERS = {
"FIFO": FIFOScheduler,
"MedianStopping": MedianStoppingRule,
@ -30,13 +29,15 @@ def _make_scheduler(args):
if args.scheduler in _SCHEDULERS:
return _SCHEDULERS[args.scheduler](**args.scheduler_config)
else:
raise TuneError(
"Unknown scheduler: {}, should be one of {}".format(
args.scheduler, _SCHEDULERS.keys()))
raise TuneError("Unknown scheduler: {}, should be one of {}".format(
args.scheduler, _SCHEDULERS.keys()))
def run_experiments(experiments, scheduler=None, with_server=False,
server_port=TuneServer.DEFAULT_PORT, verbose=True):
def run_experiments(experiments,
scheduler=None,
with_server=False,
server_port=TuneServer.DEFAULT_PORT,
verbose=True):
"""Tunes experiments.
Args:
@ -54,17 +55,21 @@ def run_experiments(experiments, scheduler=None, with_server=False,
scheduler = FIFOScheduler()
runner = TrialRunner(
scheduler, launch_web_server=with_server, server_port=server_port,
scheduler,
launch_web_server=with_server,
server_port=server_port,
verbose=verbose)
exp_list = experiments
if isinstance(experiments, Experiment):
exp_list = [experiments]
elif type(experiments) is dict:
exp_list = [Experiment.from_json(name, spec)
for name, spec in experiments.items()]
exp_list = [
Experiment.from_json(name, spec)
for name, spec in experiments.items()
]
if (type(exp_list) is list and
all(isinstance(exp, Experiment) for exp in exp_list)):
if (type(exp_list) is list
and all(isinstance(exp, Experiment) for exp in exp_list)):
for experiment in exp_list:
scheduler.add_experiment(experiment, runner)
else:

View file

@ -7,7 +7,6 @@ import base64
import ray
from ray.tune.registry import _to_pinnable, _from_pinnable
_pinned_objects = []
PINNED_OBJECT_PREFIX = "ray.tune.PinnedObject:"
@ -15,14 +14,15 @@ PINNED_OBJECT_PREFIX = "ray.tune.PinnedObject:"
def pin_in_object_store(obj):
obj_id = ray.put(_to_pinnable(obj))
_pinned_objects.append(ray.get(obj_id))
return "{}{}".format(
PINNED_OBJECT_PREFIX, base64.b64encode(obj_id.id()).decode("utf-8"))
return "{}{}".format(PINNED_OBJECT_PREFIX,
base64.b64encode(obj_id.id()).decode("utf-8"))
def get_pinned_object(pinned_id):
from ray.local_scheduler import ObjectID
return _from_pinnable(ray.get(ObjectID(
base64.b64decode(pinned_id[len(PINNED_OBJECT_PREFIX):]))))
return _from_pinnable(
ray.get(
ObjectID(base64.b64decode(pinned_id[len(PINNED_OBJECT_PREFIX):]))))
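
A sketch of the pin/get round trip defined above, assuming ray.init() has been called in this process:

pinned_id = pin_in_object_store({"weights": [1, 2, 3]})
assert pinned_id.startswith(PINNED_OBJECT_PREFIX)
restored = get_pinned_object(pinned_id)
assert restored == {"weights": [1, 2, 3]}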
if __name__ == '__main__':

View file

@ -163,8 +163,8 @@ def _generate_variants(spec):
for path, value in grid_vars:
resolved_vars[path] = _get_value(spec, path)
for k, v in resolved.items():
if (k in resolved_vars and v != resolved_vars[k] and
_is_resolved(resolved_vars[k])):
if (k in resolved_vars and v != resolved_vars[k]
and _is_resolved(resolved_vars[k])):
raise ValueError(
"The variable `{}` could not be unambiguously "
"resolved to a single value. Consider simplifying "
@ -262,16 +262,16 @@ def _unresolved_values(spec):
for k, v in spec.items():
resolved, v = _try_resolve(v)
if not resolved:
found[(k,)] = v
found[(k, )] = v
elif isinstance(v, dict):
# Recurse into a dict
for (path, value) in _unresolved_values(v).items():
found[(k,) + path] = value
found[(k, ) + path] = value
elif isinstance(v, list):
# Recurse into a list
for i, elem in enumerate(v):
for (path, value) in _unresolved_values({i: elem}).items():
found[(k,) + path] = value
found[(k, ) + path] = value
return found

View file

@ -61,8 +61,10 @@ def _resolve(directory, result_fname):
def load_results_to_df(directory, result_name="result.json"):
exp_directories = [dirpath for dirpath, dirs, files in os.walk(directory)
for f in files if f == result_name]
exp_directories = [
dirpath for dirpath, dirs, files in os.walk(directory) for f in files
if f == result_name
]
data = [_resolve(d, result_name) for d in exp_directories]
data = [d for d in data if d]
return pd.DataFrame(data)
@ -76,8 +78,9 @@ def generate_plotly_dim_dict(df, field):
dim_dict["values"] = column
elif is_string_dtype(column):
texts = column.unique()
dim_dict["values"] = [np.argwhere(texts == x).flatten()[0]
for x in column]
dim_dict["values"] = [
np.argwhere(texts == x).flatten()[0] for x in column
]
dim_dict["tickvals"] = list(range(len(texts)))
dim_dict["ticktext"] = texts
else:

View file

@ -39,28 +39,30 @@ class TuneClient(object):
def get_all_trials(self):
"""Returns a list of all trials (trial_id, config, status)."""
return self._get_response(
{"command": TuneClient.GET_LIST})
return self._get_response({"command": TuneClient.GET_LIST})
def get_trial(self, trial_id):
"""Returns the last result for queried trial."""
return self._get_response(
{"command": TuneClient.GET_TRIAL,
"trial_id": trial_id})
return self._get_response({
"command": TuneClient.GET_TRIAL,
"trial_id": trial_id
})
def add_trial(self, name, trial_spec):
"""Adds a trial of `name` with configurations."""
# TODO(rliaw): have better way of specifying a new trial
return self._get_response(
{"command": TuneClient.ADD,
"name": name,
"spec": trial_spec})
return self._get_response({
"command": TuneClient.ADD,
"name": name,
"spec": trial_spec
})
def stop_trial(self, trial_id):
"""Requests to stop trial."""
return self._get_response(
{"command": TuneClient.STOP,
"trial_id": trial_id})
return self._get_response({
"command": TuneClient.STOP,
"trial_id": trial_id
})
def _get_response(self, data):
payload = json.dumps(data).encode()
@ -71,7 +73,6 @@ class TuneClient(object):
def RunnerHandler(runner):
class Handler(SimpleHTTPRequestHandler):
def do_GET(self):
content_len = int(self.headers.get('Content-Length'), 0)
raw_body = self.rfile.read(content_len)
@ -82,8 +83,7 @@ def RunnerHandler(runner):
else:
self.send_response(400)
self.end_headers()
self.wfile.write(json.dumps(
response).encode())
self.wfile.write(json.dumps(response).encode())
def trial_info(self, trial):
if trial.last_result:
@ -112,8 +112,9 @@ def RunnerHandler(runner):
response = {}
try:
if command == TuneClient.GET_LIST:
response["trials"] = [self.trial_info(t)
for t in runner.get_trials()]
response["trials"] = [
self.trial_info(t) for t in runner.get_trials()
]
elif command == TuneClient.GET_TRIAL:
trial = get_trial()
response["trial_info"] = self.trial_info(trial)
@ -147,8 +148,7 @@ class TuneServer(threading.Thread):
self._port = port if port else self.DEFAULT_PORT
address = ('localhost', self._port)
print("Starting Tune Server...")
self._server = HTTPServer(
address, RunnerHandler(runner))
self._server = HTTPServer(address, RunnerHandler(runner))
self.start()
def run(self):

View file

@ -46,7 +46,10 @@ def format_error_message(exception_message, task_exception=False):
return "\n".join(lines)
def push_error_to_driver(redis_client, error_type, message, driver_id=None,
def push_error_to_driver(redis_client,
error_type,
message,
driver_id=None,
data=None):
"""Push an error message to the driver to be printed in the background.
@ -64,9 +67,11 @@ def push_error_to_driver(redis_client, error_type, message, driver_id=None,
driver_id = DRIVER_ID_LENGTH * b"\x00"
error_key = ERROR_KEY_PREFIX + driver_id + b":" + _random_string()
data = {} if data is None else data
redis_client.hmset(error_key, {"type": error_type,
"message": message,
"data": data})
redis_client.hmset(error_key, {
"type": error_type,
"message": message,
"data": data
})
redis_client.rpush("ErrorKeys", error_key)
@ -134,10 +139,8 @@ def hex_to_binary(hex_identifier):
return binascii.unhexlify(hex_identifier)
FunctionProperties = collections.namedtuple("FunctionProperties",
["num_return_vals",
"resources",
"max_calls"])
FunctionProperties = collections.namedtuple(
"FunctionProperties", ["num_return_vals", "resources", "max_calls"])
"""FunctionProperties: A named tuple storing remote functions information."""

File diff suppressed because it is too large

View file

@ -8,34 +8,51 @@ import traceback
import ray
import ray.actor
parser = argparse.ArgumentParser(description=("Parse addresses for the worker "
"to connect to."))
parser.add_argument("--node-ip-address", required=True, type=str,
help="the ip address of the worker's node")
parser.add_argument("--redis-address", required=True, type=str,
help="the address to use for Redis")
parser.add_argument("--object-store-name", required=True, type=str,
help="the object store's name")
parser.add_argument("--object-store-manager-name", required=False, type=str,
help="the object store manager's name")
parser.add_argument("--local-scheduler-name", required=False, type=str,
help="the local scheduler's name")
parser.add_argument("--raylet-name", required=False, type=str,
help="the raylet's name")
parser = argparse.ArgumentParser(
description=("Parse addresses for the worker "
"to connect to."))
parser.add_argument(
"--node-ip-address",
required=True,
type=str,
help="the ip address of the worker's node")
parser.add_argument(
"--redis-address",
required=True,
type=str,
help="the address to use for Redis")
parser.add_argument(
"--object-store-name",
required=True,
type=str,
help="the object store's name")
parser.add_argument(
"--object-store-manager-name",
required=False,
type=str,
help="the object store manager's name")
parser.add_argument(
"--local-scheduler-name",
required=False,
type=str,
help="the local scheduler's name")
parser.add_argument(
"--raylet-name", required=False, type=str, help="the raylet's name")
if __name__ == "__main__":
args = parser.parse_args()
info = {"node_ip_address": args.node_ip_address,
"redis_address": args.redis_address,
"store_socket_name": args.object_store_name,
"manager_socket_name": args.object_store_manager_name,
"local_scheduler_socket_name": args.local_scheduler_name,
"raylet_socket_name": args.raylet_name}
info = {
"node_ip_address": args.node_ip_address,
"redis_address": args.redis_address,
"store_socket_name": args.object_store_name,
"manager_socket_name": args.object_store_manager_name,
"local_scheduler_socket_name": args.local_scheduler_name,
"raylet_socket_name": args.raylet_name
}
ray.worker.connect(info, mode=ray.WORKER_MODE,
use_raylet=(args.raylet_name is not None))
ray.worker.connect(
info, mode=ray.WORKER_MODE, use_raylet=(args.raylet_name is not None))
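
For reference, a sketch of how this worker script is launched with the flags parsed above; all socket names and addresses are placeholders:

# python default_worker.py \
#     --node-ip-address=127.0.0.1 \
#     --redis-address=127.0.0.1:6379 \
#     --object-store-name=/tmp/plasma_store \
#     --object-store-manager-name=/tmp/plasma_manager \
#     --local-scheduler-name=/tmp/scheduler
# Passing --raylet-name instead switches connect() to the raylet path,
# per the use_raylet flag above.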
error_explanation = """
This error is unexpected and should not have happened. Somehow a worker
@ -54,8 +71,8 @@ if __name__ == "__main__":
traceback_str = traceback.format_exc() + error_explanation
# Create a Redis client.
redis_client = ray.services.create_redis_client(args.redis_address)
ray.utils.push_error_to_driver(redis_client, "worker_crash",
traceback_str, driver_id=None)
ray.utils.push_error_to_driver(
redis_client, "worker_crash", traceback_str, driver_id=None)
# TODO(rkn): Note that if the worker was in the middle of executing
# a task, then any worker or driver that is blocking in a get call
# and waiting for the output of that task will hang. We need to

View file

@ -18,13 +18,11 @@ import setuptools.command.build_ext as _build_ext
ray_files = [
"ray/core/src/common/thirdparty/redis/src/redis-server",
"ray/core/src/common/redis_module/libray_redis_module.so",
"ray/core/src/plasma/plasma_store",
"ray/core/src/plasma/plasma_manager",
"ray/core/src/plasma/plasma_store", "ray/core/src/plasma/plasma_manager",
"ray/core/src/local_scheduler/local_scheduler",
"ray/core/src/local_scheduler/liblocal_scheduler_library.so",
"ray/core/src/global_scheduler/global_scheduler",
"ray/core/src/ray/raylet/raylet_monitor",
"ray/core/src/ray/raylet/raylet",
"ray/core/src/ray/raylet/raylet_monitor", "ray/core/src/ray/raylet/raylet",
"ray/WebUI.ipynb"
]
@ -35,14 +33,14 @@ ray_ui_files = [
"ray/core/src/catapult_files/trace_viewer_full.html"
]
ray_autoscaler_files = [
"ray/autoscaler/aws/example-full.yaml"
]
ray_autoscaler_files = ["ray/autoscaler/aws/example-full.yaml"]
if "RAY_USE_NEW_GCS" in os.environ and os.environ["RAY_USE_NEW_GCS"] == "on":
ray_files += ["ray/core/src/credis/build/src/libmember.so",
"ray/core/src/credis/build/src/libmaster.so",
"ray/core/src/credis/redis/src/redis-server"]
ray_files += [
"ray/core/src/credis/build/src/libmember.so",
"ray/core/src/credis/build/src/libmaster.so",
"ray/core/src/credis/redis/src/redis-server"
]
# The UI files are mandatory if the INCLUDE_UI environment variable equals 1.
# Otherwise, they are optional.
@ -54,9 +52,8 @@ else:
optional_ray_files += ray_autoscaler_files
extras = {
"rllib": [
"tensorflow", "pyyaml", "gym[atari]", "opencv-python",
"lz4", "scipy"]
"rllib":
["tensorflow", "pyyaml", "gym[atari]", "opencv-python", "lz4", "scipy"]
}
@ -73,8 +70,9 @@ class build_ext(_build_ext.build_ext):
pyarrow_files = [
os.path.join("ray/pyarrow_files/pyarrow", filename)
for filename in os.listdir("./ray/pyarrow_files/pyarrow")
if not os.path.isdir(os.path.join("ray/pyarrow_files/pyarrow",
filename))]
if not os.path.isdir(
os.path.join("ray/pyarrow_files/pyarrow", filename))
]
files_to_include = ray_files + pyarrow_files
@ -84,8 +82,8 @@ class build_ext(_build_ext.build_ext):
generated_python_directory = "ray/core/generated"
for filename in os.listdir(generated_python_directory):
if filename[-3:] == ".py":
self.move_file(os.path.join(generated_python_directory,
filename))
self.move_file(
os.path.join(generated_python_directory, filename))
# Try to copy over the optional files.
for filename in optional_ray_files:
@ -114,27 +112,30 @@ class BinaryDistribution(Distribution):
return True
setup(name="ray",
# The version string is also in __init__.py. TODO(pcm): Fix this.
version="0.4.0",
packages=find_packages(),
cmdclass={"build_ext": build_ext},
# The BinaryDistribution argument triggers build_ext.
distclass=BinaryDistribution,
install_requires=["numpy",
"funcsigs",
"click",
"colorama",
"psutil",
"pytest",
"pyyaml",
"redis",
# The six module is required by pyarrow.
"six >= 1.0.0",
"flatbuffers"],
setup_requires=["cython >= 0.23"],
extras_require=extras,
entry_points={"console_scripts": ["ray=ray.scripts.scripts:main"]},
include_package_data=True,
zip_safe=False,
license="Apache 2.0")
setup(
name="ray",
# The version string is also in __init__.py. TODO(pcm): Fix this.
version="0.4.0",
packages=find_packages(),
cmdclass={"build_ext": build_ext},
# The BinaryDistribution argument triggers build_ext.
distclass=BinaryDistribution,
install_requires=[
"numpy",
"funcsigs",
"click",
"colorama",
"psutil",
"pytest",
"pyyaml",
"redis",
# The six module is required by pyarrow.
"six >= 1.0.0",
"flatbuffers"
],
setup_requires=["cython >= 0.23"],
extras_require=extras,
entry_points={"console_scripts": ["ray=ray.scripts.scripts:main"]},
include_package_data=True,
zip_safe=False,
license="Apache 2.0")

View file

@ -15,7 +15,6 @@ import ray.test.test_utils
class ActorAPI(unittest.TestCase):
def tearDown(self):
ray.worker.cleanup()
@ -39,20 +38,22 @@ class ActorAPI(unittest.TestCase):
self.assertEqual(ray.get(actor.get_values.remote(2, 3)), (3, 5, "ab"))
actor = Actor.remote(1, 2, "c")
self.assertEqual(ray.get(actor.get_values.remote(2, 3, "d")),
(3, 5, "cd"))
self.assertEqual(
ray.get(actor.get_values.remote(2, 3, "d")), (3, 5, "cd"))
actor = Actor.remote(1, arg2="c")
self.assertEqual(ray.get(actor.get_values.remote(0, arg2="d")),
(1, 3, "cd"))
self.assertEqual(ray.get(actor.get_values.remote(0, arg2="d", arg1=0)),
(1, 1, "cd"))
self.assertEqual(
ray.get(actor.get_values.remote(0, arg2="d")), (1, 3, "cd"))
self.assertEqual(
ray.get(actor.get_values.remote(0, arg2="d", arg1=0)),
(1, 1, "cd"))
actor = Actor.remote(1, arg2="c", arg1=2)
self.assertEqual(ray.get(actor.get_values.remote(0, arg2="d")),
(1, 4, "cd"))
self.assertEqual(ray.get(actor.get_values.remote(0, arg2="d", arg1=0)),
(1, 2, "cd"))
self.assertEqual(
ray.get(actor.get_values.remote(0, arg2="d")), (1, 4, "cd"))
self.assertEqual(
ray.get(actor.get_values.remote(0, arg2="d", arg1=0)),
(1, 2, "cd"))
# Make sure we get an exception if the constructor is called
# incorrectly.
@ -84,16 +85,18 @@ class ActorAPI(unittest.TestCase):
self.assertEqual(ray.get(actor.get_values.remote(1)), (1, 3, (), ()))
actor = Actor.remote(1, 2)
self.assertEqual(ray.get(actor.get_values.remote(2, 3)),
(3, 5, (), ()))
self.assertEqual(
ray.get(actor.get_values.remote(2, 3)), (3, 5, (), ()))
actor = Actor.remote(1, 2, "c")
self.assertEqual(ray.get(actor.get_values.remote(2, 3, "d")),
(3, 5, ("c",), ("d",)))
self.assertEqual(
ray.get(actor.get_values.remote(2, 3, "d")), (3, 5, ("c", ),
("d", )))
actor = Actor.remote(1, 2, "a", "b", "c", "d")
self.assertEqual(ray.get(actor.get_values.remote(2, 3, 1, 2, 3, 4)),
(3, 5, ("a", "b", "c", "d"), (1, 2, 3, 4)))
self.assertEqual(
ray.get(actor.get_values.remote(2, 3, 1, 2, 3, 4)),
(3, 5, ("a", "b", "c", "d"), (1, 2, 3, 4)))
@ray.remote
class Actor(object):
@ -106,7 +109,7 @@ class ActorAPI(unittest.TestCase):
a = Actor.remote()
self.assertEqual(ray.get(a.get_values.remote()), ((), ()))
a = Actor.remote(1)
self.assertEqual(ray.get(a.get_values.remote(2)), ((1,), (2,)))
self.assertEqual(ray.get(a.get_values.remote(2)), ((1, ), (2, )))
a = Actor.remote(1, 2)
self.assertEqual(ray.get(a.get_values.remote(3, 4)), ((1, 2), (3, 4)))
@ -191,6 +194,7 @@ class ActorAPI(unittest.TestCase):
# This is an invalid way of using the actor decorator.
with self.assertRaises(Exception):
@ray.remote()
class Actor(object):
def __init__(self):
@ -198,6 +202,7 @@ class ActorAPI(unittest.TestCase):
# This is an invalid way of using the actor decorator.
with self.assertRaises(Exception):
@ray.remote(invalid_kwarg=0) # noqa: F811
class Actor(object):
def __init__(self):
@ -205,6 +210,7 @@ class ActorAPI(unittest.TestCase):
# This is an invalid way of using the actor decorator.
with self.assertRaises(Exception):
@ray.remote(num_cpus=0, invalid_kwarg=0) # noqa: F811
class Actor(object):
def __init__(self):
@ -300,7 +306,6 @@ class ActorAPI(unittest.TestCase):
class ActorMethods(unittest.TestCase):
def tearDown(self):
ray.worker.cleanup()
@ -417,8 +422,9 @@ class ActorMethods(unittest.TestCase):
results = []
# Call each actor's method a bunch of times.
for i in range(num_actors):
results += [actors[i].increase.remote()
for _ in range(num_increases)]
results += [
actors[i].increase.remote() for _ in range(num_increases)
]
result_values = ray.get(results)
for i in range(num_actors):
self.assertEqual(
@ -440,7 +446,6 @@ class ActorMethods(unittest.TestCase):
class ActorNesting(unittest.TestCase):
def tearDown(self):
ray.worker.cleanup()
@ -510,6 +515,7 @@ class ActorNesting(unittest.TestCase):
def get_value(self):
return self.x
self.actor2 = Actor2.remote(z)
def get_values(self, z):
@ -556,12 +562,14 @@ class ActorNesting(unittest.TestCase):
def get_value(self):
return self.x
actor = Actor1.remote(x)
return ray.get([actor.get_value.remote() for _ in range(n)])
self.assertEqual(ray.get(f.remote(3, 1)), [3])
self.assertEqual(ray.get([f.remote(i, 20) for i in range(10)]),
[20 * [i] for i in range(10)])
self.assertEqual(
ray.get([f.remote(i, 20) for i in range(10)]),
[20 * [i] for i in range(10)])
def testUseActorWithinRemoteFunction(self):
# Make sure we can create and use actors within remote functions.
@ -591,6 +599,7 @@ class ActorNesting(unittest.TestCase):
# Export a bunch of remote functions.
num_remote_functions = 50
for i in range(num_remote_functions):
@ray.remote
def f():
return i
@ -613,7 +622,6 @@ class ActorNesting(unittest.TestCase):
class ActorInheritance(unittest.TestCase):
def tearDown(self):
ray.worker.cleanup()
@ -646,7 +654,6 @@ class ActorInheritance(unittest.TestCase):
class ActorSchedulingProperties(unittest.TestCase):
def tearDown(self):
ray.worker.cleanup()
@ -674,7 +681,6 @@ class ActorSchedulingProperties(unittest.TestCase):
class ActorsOnMultipleNodes(unittest.TestCase):
def tearDown(self):
ray.worker.cleanup()
@ -692,8 +698,10 @@ class ActorsOnMultipleNodes(unittest.TestCase):
def testActorLoadBalancing(self):
num_local_schedulers = 3
ray.worker._init(start_ray_local=True, num_workers=0,
num_local_schedulers=num_local_schedulers)
ray.worker._init(
start_ray_local=True,
num_workers=0,
num_local_schedulers=num_local_schedulers)
@ray.remote
class Actor1(object):
@ -712,13 +720,13 @@ class ActorsOnMultipleNodes(unittest.TestCase):
attempts = 0
while attempts < num_attempts:
actors = [Actor1.remote() for _ in range(num_actors)]
locations = ray.get([actor.get_location.remote()
for actor in actors])
locations = ray.get(
[actor.get_location.remote() for actor in actors])
names = set(locations)
counts = [locations.count(name) for name in names]
print("Counts are {}.".format(counts))
if (len(names) == num_local_schedulers and
all([count >= minimum_count for count in counts])):
if (len(names) == num_local_schedulers
and all([count >= minimum_count for count in counts])):
break
attempts += 1
self.assertLess(attempts, num_attempts)
@ -732,18 +740,17 @@ class ActorsOnMultipleNodes(unittest.TestCase):
class ActorsWithGPUs(unittest.TestCase):
def tearDown(self):
ray.worker.cleanup()
@unittest.skipIf(
os.environ.get('RAY_USE_NEW_GCS', False),
"Crashing with new GCS API.")
os.environ.get('RAY_USE_NEW_GCS', False), "Crashing with new GCS API.")
def testActorGPUs(self):
num_local_schedulers = 3
num_gpus_per_scheduler = 4
ray.worker._init(
start_ray_local=True, num_workers=0,
start_ray_local=True,
num_workers=0,
num_local_schedulers=num_local_schedulers,
num_cpus=(num_local_schedulers * [10 * num_gpus_per_scheduler]),
num_gpus=(num_local_schedulers * [num_gpus_per_scheduler]))
@ -760,19 +767,21 @@ class ActorsWithGPUs(unittest.TestCase):
tuple(self.gpu_ids))
# Create one actor per GPU.
actors = [Actor1.remote() for _
in range(num_local_schedulers * num_gpus_per_scheduler)]
actors = [
Actor1.remote()
for _ in range(num_local_schedulers * num_gpus_per_scheduler)
]
# Make sure that no two actors are assigned to the same GPU.
locations_and_ids = ray.get([actor.get_location_and_ids.remote()
for actor in actors])
locations_and_ids = ray.get(
[actor.get_location_and_ids.remote() for actor in actors])
node_names = set([location for location, gpu_id in locations_and_ids])
self.assertEqual(len(node_names), num_local_schedulers)
location_actor_combinations = []
for node_name in node_names:
for gpu_id in range(num_gpus_per_scheduler):
location_actor_combinations.append((node_name, (gpu_id,)))
self.assertEqual(set(locations_and_ids),
set(location_actor_combinations))
location_actor_combinations.append((node_name, (gpu_id, )))
self.assertEqual(
set(locations_and_ids), set(location_actor_combinations))
# Creating a new actor should fail because all of the GPUs are being
# used.
@ -784,7 +793,8 @@ class ActorsWithGPUs(unittest.TestCase):
num_local_schedulers = 3
num_gpus_per_scheduler = 5
ray.worker._init(
start_ray_local=True, num_workers=0,
start_ray_local=True,
num_workers=0,
num_local_schedulers=num_local_schedulers,
num_cpus=(num_local_schedulers * [10 * num_gpus_per_scheduler]),
num_gpus=(num_local_schedulers * [num_gpus_per_scheduler]))
@ -803,8 +813,8 @@ class ActorsWithGPUs(unittest.TestCase):
# Create some actors.
actors1 = [Actor1.remote() for _ in range(num_local_schedulers * 2)]
# Make sure that no two actors are assigned to the same GPU.
locations_and_ids = ray.get([actor.get_location_and_ids.remote()
for actor in actors1])
locations_and_ids = ray.get(
[actor.get_location_and_ids.remote() for actor in actors1])
node_names = set([location for location, gpu_id in locations_and_ids])
self.assertEqual(len(node_names), num_local_schedulers)
@ -835,11 +845,11 @@ class ActorsWithGPUs(unittest.TestCase):
# Create some actors.
actors2 = [Actor2.remote() for _ in range(num_local_schedulers)]
# Make sure that no two actors are assigned to the same GPU.
locations_and_ids = ray.get([actor.get_location_and_ids.remote()
for actor in actors2])
self.assertEqual(node_names,
set([location for location, gpu_id
in locations_and_ids]))
locations_and_ids = ray.get(
[actor.get_location_and_ids.remote() for actor in actors2])
self.assertEqual(
node_names,
set([location for location, gpu_id in locations_and_ids]))
for location, gpu_ids in locations_and_ids:
gpus_in_use[location].extend(gpu_ids)
for node_name in node_names:
@ -855,9 +865,12 @@ class ActorsWithGPUs(unittest.TestCase):
def testActorDifferentNumbersOfGPUs(self):
# Test that we can create actors on two nodes that have different
# numbers of GPUs.
ray.worker._init(start_ray_local=True, num_workers=0,
num_local_schedulers=3, num_cpus=[10, 10, 10],
num_gpus=[0, 5, 10])
ray.worker._init(
start_ray_local=True,
num_workers=0,
num_local_schedulers=3,
num_cpus=[10, 10, 10],
num_gpus=[0, 5, 10])
@ray.remote(num_gpus=1)
class Actor1(object):
@ -872,16 +885,19 @@ class ActorsWithGPUs(unittest.TestCase):
# Create some actors.
actors = [Actor1.remote() for _ in range(0 + 5 + 10)]
# Make sure that no two actors are assigned to the same GPU.
locations_and_ids = ray.get([actor.get_location_and_ids.remote()
for actor in actors])
locations_and_ids = ray.get(
[actor.get_location_and_ids.remote() for actor in actors])
node_names = set([location for location, gpu_id in locations_and_ids])
self.assertEqual(len(node_names), 2)
for node_name in node_names:
node_gpu_ids = [gpu_id for location, gpu_id in locations_and_ids
if location == node_name]
node_gpu_ids = [
gpu_id for location, gpu_id in locations_and_ids
if location == node_name
]
self.assertIn(len(node_gpu_ids), [5, 10])
self.assertEqual(set(node_gpu_ids),
set([(i,) for i in range(len(node_gpu_ids))]))
self.assertEqual(
set(node_gpu_ids),
set([(i, ) for i in range(len(node_gpu_ids))]))
# Creating a new actor should fail because all of the GPUs are being
# used.
@ -893,8 +909,10 @@ class ActorsWithGPUs(unittest.TestCase):
num_local_schedulers = 10
num_gpus_per_scheduler = 10
ray.worker._init(
start_ray_local=True, num_workers=0,
num_local_schedulers=num_local_schedulers, redirect_output=True,
start_ray_local=True,
num_workers=0,
num_local_schedulers=num_local_schedulers,
redirect_output=True,
num_cpus=(num_local_schedulers * [10 * num_gpus_per_scheduler]),
num_gpus=(num_local_schedulers * [num_gpus_per_scheduler]))
@ -906,15 +924,17 @@ class ActorsWithGPUs(unittest.TestCase):
self.gpu_ids = ray.get_gpu_ids()
def get_location_and_ids(self):
return ((ray.worker.global_worker.plasma_client
.store_socket_name),
tuple(self.gpu_ids))
return ((ray.worker.global_worker.plasma_client.
store_socket_name), tuple(self.gpu_ids))
# Create n actors.
for _ in range(n):
Actor.remote()
ray.get([create_actors.remote(num_gpus_per_scheduler)
for _ in range(num_local_schedulers)])
ray.get([
create_actors.remote(num_gpus_per_scheduler)
for _ in range(num_local_schedulers)
])
@ray.remote(num_gpus=1)
class Actor(object):
@ -936,7 +956,8 @@ class ActorsWithGPUs(unittest.TestCase):
num_local_schedulers = 3
num_gpus_per_scheduler = 6
ray.worker._init(
start_ray_local=True, num_workers=0,
start_ray_local=True,
num_workers=0,
num_local_schedulers=num_local_schedulers,
num_cpus=num_gpus_per_scheduler,
num_gpus=(num_local_schedulers * [num_gpus_per_scheduler]))
@ -951,11 +972,11 @@ class ActorsWithGPUs(unittest.TestCase):
self.assertLess(first_interval[0], first_interval[1])
self.assertLess(second_interval[0], second_interval[1])
intervals_nonoverlapping = (
first_interval[1] <= second_interval[0] or
second_interval[1] <= first_interval[0])
first_interval[1] <= second_interval[0]
or second_interval[1] <= first_interval[0])
assert intervals_nonoverlapping, (
"Intervals {} and {} are overlapping."
.format(first_interval, second_interval))
"Intervals {} and {} are overlapping.".format(
first_interval, second_interval))
@ray.remote(num_gpus=1)
def f1():
@ -995,13 +1016,16 @@ class ActorsWithGPUs(unittest.TestCase):
def locations_to_intervals_for_many_tasks():
# Launch a bunch of GPU tasks.
locations_ids_and_intervals = ray.get(
[f1.remote() for _
in range(5 * num_local_schedulers * num_gpus_per_scheduler)] +
[f2.remote() for _
in range(5 * num_local_schedulers * num_gpus_per_scheduler)] +
[f1.remote() for _
in range(5 * num_local_schedulers * num_gpus_per_scheduler)])
locations_ids_and_intervals = ray.get([
f1.remote() for _ in range(
5 * num_local_schedulers * num_gpus_per_scheduler)
] + [
f2.remote() for _ in range(
5 * num_local_schedulers * num_gpus_per_scheduler)
] + [
f1.remote() for _ in range(
5 * num_local_schedulers * num_gpus_per_scheduler)
])
locations_to_intervals = collections.defaultdict(lambda: [])
for location, gpu_ids, interval in locations_ids_and_intervals:
@ -1012,8 +1036,9 @@ class ActorsWithGPUs(unittest.TestCase):
# Run a bunch of GPU tasks.
locations_to_intervals = locations_to_intervals_for_many_tasks()
# Make sure that all GPUs were used.
self.assertEqual(len(locations_to_intervals),
num_local_schedulers * num_gpus_per_scheduler)
self.assertEqual(
len(locations_to_intervals),
num_local_schedulers * num_gpus_per_scheduler)
# For each GPU, verify that the set of tasks that used this specific
# GPU did not overlap in time.
for locations in locations_to_intervals:
@ -1030,8 +1055,9 @@ class ActorsWithGPUs(unittest.TestCase):
# Run a bunch of GPU tasks.
locations_to_intervals = locations_to_intervals_for_many_tasks()
# Make sure that all but one of the GPUs were used.
self.assertEqual(len(locations_to_intervals),
num_local_schedulers * num_gpus_per_scheduler - 1)
self.assertEqual(
len(locations_to_intervals),
num_local_schedulers * num_gpus_per_scheduler - 1)
# For each GPU, verify that the set of tasks that used this specific
# GPU did not overlap in time.
for locations in locations_to_intervals:
@ -1041,14 +1067,15 @@ class ActorsWithGPUs(unittest.TestCase):
# Create several more actors that use GPUs.
actors = [Actor1.remote() for _ in range(3)]
actor_locations = ray.get([actor.get_location_and_ids.remote()
for actor in actors])
actor_locations = ray.get(
[actor.get_location_and_ids.remote() for actor in actors])
# Run a bunch of GPU tasks.
locations_to_intervals = locations_to_intervals_for_many_tasks()
# Make sure that all but 4 (1 + 3) of the GPUs were used.
self.assertEqual(len(locations_to_intervals),
num_local_schedulers * num_gpus_per_scheduler - 1 - 3)
self.assertEqual(
len(locations_to_intervals),
num_local_schedulers * num_gpus_per_scheduler - 1 - 3)
# For each GPU, verify that the set of tasks that used this specific
# GPU did not overlap in time.
for locations in locations_to_intervals:
@ -1059,9 +1086,10 @@ class ActorsWithGPUs(unittest.TestCase):
self.assertNotIn(location, locations_to_intervals)
# Create more actors to fill up all the GPUs.
more_actors = [Actor1.remote() for _ in
range(num_local_schedulers *
num_gpus_per_scheduler - 1 - 3)]
more_actors = [
Actor1.remote() for _ in range(
num_local_schedulers * num_gpus_per_scheduler - 1 - 3)
]
# Wait for the actors to finish being created.
ray.get([actor.get_location_and_ids.remote() for actor in more_actors])
@ -1195,16 +1223,17 @@ class ActorsWithGPUs(unittest.TestCase):
class ActorReconstruction(unittest.TestCase):
def tearDown(self):
ray.worker.cleanup()
@unittest.skipIf(
os.environ.get('RAY_USE_NEW_GCS', False),
"Hanging with new GCS API.")
os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
def testLocalSchedulerDying(self):
ray.worker._init(start_ray_local=True, num_local_schedulers=2,
num_workers=0, redirect_output=True)
ray.worker._init(
start_ray_local=True,
num_local_schedulers=2,
num_workers=0,
redirect_output=True)
@ray.remote
class Counter(object):
@ -1243,8 +1272,7 @@ class ActorReconstruction(unittest.TestCase):
self.assertEqual(results, list(range(1, 1 + len(results))))
@unittest.skipIf(
os.environ.get('RAY_USE_NEW_GCS', False),
"Hanging with new GCS API.")
os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
def testManyLocalSchedulersDying(self):
# This test can be made more stressful by increasing the numbers below.
# The total number of actors created will be
@ -1253,9 +1281,11 @@ class ActorReconstruction(unittest.TestCase):
num_actors_at_a_time = 3
num_function_calls_at_a_time = 10
ray.worker._init(start_ray_local=True,
num_local_schedulers=num_local_schedulers,
num_workers=0, redirect_output=True)
ray.worker._init(
start_ray_local=True,
num_local_schedulers=num_local_schedulers,
num_workers=0,
redirect_output=True)
@ray.remote
class SlowCounter(object):
@ -1281,14 +1311,13 @@ class ActorReconstruction(unittest.TestCase):
# a local scheduler, and run some more methods.
for i in range(num_local_schedulers - 1):
# Create some actors.
actors.extend([SlowCounter.remote()
for _ in range(num_actors_at_a_time)])
actors.extend(
[SlowCounter.remote() for _ in range(num_actors_at_a_time)])
# Run some methods.
for j in range(len(actors)):
actor = actors[j]
for _ in range(num_function_calls_at_a_time):
result_ids[actor].append(
actor.inc.remote(j ** 2 * 0.000001))
result_ids[actor].append(actor.inc.remote(j**2 * 0.000001))
# Kill a plasma store to get rid of the cached objects and trigger
# exit of the corresponding local scheduler. Don't kill the first
# local scheduler since that is the one that the driver is
@ -1302,18 +1331,24 @@ class ActorReconstruction(unittest.TestCase):
for j in range(len(actors)):
actor = actors[j]
for _ in range(num_function_calls_at_a_time):
result_ids[actor].append(
actor.inc.remote(j ** 2 * 0.000001))
result_ids[actor].append(actor.inc.remote(j**2 * 0.000001))
# Get the results and check that they have the correct values.
for _, result_id_list in result_ids.items():
self.assertEqual(ray.get(result_id_list),
list(range(1, len(result_id_list) + 1)))
self.assertEqual(
ray.get(result_id_list), list(
range(1,
len(result_id_list) + 1)))
def setup_counter_actor(self, test_checkpoint=False, save_exception=False,
def setup_counter_actor(self,
test_checkpoint=False,
save_exception=False,
resume_exception=False):
ray.worker._init(start_ray_local=True, num_local_schedulers=2,
num_workers=0, redirect_output=True)
ray.worker._init(
start_ray_local=True,
num_local_schedulers=2,
num_workers=0,
redirect_output=True)
# Only set the checkpoint interval if we're testing with checkpointing.
checkpoint_interval = -1
@ -1371,8 +1406,7 @@ class ActorReconstruction(unittest.TestCase):
return actor, ids
@unittest.skipIf(
os.environ.get('RAY_USE_NEW_GCS', False),
"Hanging with new GCS API.")
os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
def testCheckpointing(self):
actor, ids = self.setup_counter_actor(test_checkpoint=True)
# Wait for the last task to finish running.
@ -1397,8 +1431,7 @@ class ActorReconstruction(unittest.TestCase):
self.assertLess(num_inc_calls, x)
@unittest.skipIf(
os.environ.get('RAY_USE_NEW_GCS', False),
"Hanging with new GCS API.")
os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
def testRemoteCheckpoint(self):
actor, ids = self.setup_counter_actor(test_checkpoint=True)
@ -1424,8 +1457,7 @@ class ActorReconstruction(unittest.TestCase):
self.assertEqual(x, 101)
@unittest.skipIf(
os.environ.get('RAY_USE_NEW_GCS', False),
"Hanging with new GCS API.")
os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
def testLostCheckpoint(self):
actor, ids = self.setup_counter_actor(test_checkpoint=True)
# Wait for the first fraction of tasks to finish running.
@ -1451,11 +1483,10 @@ class ActorReconstruction(unittest.TestCase):
self.assertLess(5, num_inc_calls)
@unittest.skipIf(
os.environ.get('RAY_USE_NEW_GCS', False),
"Hanging with new GCS API.")
os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
def testCheckpointException(self):
actor, ids = self.setup_counter_actor(test_checkpoint=True,
save_exception=True)
actor, ids = self.setup_counter_actor(
test_checkpoint=True, save_exception=True)
# Wait for the last task to finish running.
ray.get(ids[-1])
@ -1481,11 +1512,10 @@ class ActorReconstruction(unittest.TestCase):
self.assertEqual(error[b"type"], b"checkpoint")
@unittest.skipIf(
os.environ.get('RAY_USE_NEW_GCS', False),
"Hanging with new GCS API.")
os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
def testCheckpointResumeException(self):
actor, ids = self.setup_counter_actor(test_checkpoint=True,
resume_exception=True)
actor, ids = self.setup_counter_actor(
test_checkpoint=True, resume_exception=True)
# Wait for the last task to finish running.
ray.get(ids[-1])
@ -1527,8 +1557,9 @@ class ActorReconstruction(unittest.TestCase):
count = ray.get(ids[-1])
num_incs = 100
num_iters = 10
forks = [fork_many_incs.remote(counter, num_incs) for _ in
range(num_iters)]
forks = [
fork_many_incs.remote(counter, num_incs) for _ in range(num_iters)
]
ray.wait(forks, num_returns=len(forks))
count += num_incs * num_iters
@ -1547,8 +1578,7 @@ class ActorReconstruction(unittest.TestCase):
self.assertEqual(x, count + 1)
@unittest.skipIf(
os.environ.get('RAY_USE_NEW_GCS', False),
"Hanging with new GCS API.")
os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
def testRemoteCheckpointDistributedHandle(self):
counter, ids = self.setup_counter_actor(test_checkpoint=True)
@ -1564,8 +1594,9 @@ class ActorReconstruction(unittest.TestCase):
count = ray.get(ids[-1])
num_incs = 100
num_iters = 10
forks = [fork_many_incs.remote(counter, num_incs) for _ in
range(num_iters)]
forks = [
fork_many_incs.remote(counter, num_incs) for _ in range(num_iters)
]
ray.wait(forks, num_returns=len(forks))
ray.wait([counter.__ray_checkpoint__.remote()])
count += num_incs * num_iters
@ -1605,8 +1636,9 @@ class ActorReconstruction(unittest.TestCase):
count = ray.get(ids[-1])
num_incs = 100
num_iters = 10
forks = [fork_many_incs.remote(counter, num_incs) for _ in
range(num_iters)]
forks = [
fork_many_incs.remote(counter, num_incs) for _ in range(num_iters)
]
ray.wait(forks, num_returns=len(forks))
count += num_incs * num_iters
@ -1624,11 +1656,13 @@ class ActorReconstruction(unittest.TestCase):
x = ray.get(counter.inc.remote())
self.assertEqual(x, count + 1)
def _testNondeterministicReconstruction(self, num_forks,
num_items_per_fork,
num_forks_to_wait):
ray.worker._init(start_ray_local=True, num_local_schedulers=2,
num_workers=0, redirect_output=True)
def _testNondeterministicReconstruction(
self, num_forks, num_items_per_fork, num_forks_to_wait):
ray.worker._init(
start_ray_local=True,
num_local_schedulers=2,
num_workers=0,
redirect_output=True)
# Make a shared queue.
@ray.remote
@ -1668,8 +1702,9 @@ class ActorReconstruction(unittest.TestCase):
# unique objects to push onto the shared queue.
enqueue_tasks = []
for fork in range(num_forks):
enqueue_tasks.append(enqueue.remote(
actor, [(fork, i) for i in range(num_items_per_fork)]))
enqueue_tasks.append(
enqueue.remote(actor,
[(fork, i) for i in range(num_items_per_fork)]))
# Wait for the forks to complete their tasks.
enqueue_tasks = ray.get(enqueue_tasks)
enqueue_tasks = [fork_ids[0] for fork_ids in enqueue_tasks]
@ -1689,8 +1724,8 @@ class ActorReconstruction(unittest.TestCase):
ray.get(enqueue_tasks)
reconstructed_queue = ray.get(actor.read.remote())
# Make sure the final queue has all items from all forks.
self.assertEqual(len(reconstructed_queue), num_forks *
num_items_per_fork)
self.assertEqual(
len(reconstructed_queue), num_forks * num_items_per_fork)
# Make sure that the prefix of the final queue matches the queue from
# the initial execution.
self.assertEqual(queue, reconstructed_queue[:len(queue)])
@ -1709,7 +1744,6 @@ class ActorReconstruction(unittest.TestCase):
class DistributedActorHandles(unittest.TestCase):
def tearDown(self):
ray.worker.cleanup()
@ -1757,8 +1791,9 @@ class DistributedActorHandles(unittest.TestCase):
# Fork num_forks times.
num_forks = 10
num_items_per_fork = 100
ray.get([fork.remote(queue, i, num_items_per_fork) for i in
range(num_forks)])
ray.get([
fork.remote(queue, i, num_items_per_fork) for i in range(num_forks)
])
items = ray.get(queue.read.remote())
for i in range(num_forks):
filtered_items = [item[1] for item in items if item[0] == i]
@ -1812,7 +1847,6 @@ class DistributedActorHandles(unittest.TestCase):
class ActorPlacementAndResources(unittest.TestCase):
def tearDown(self):
ray.worker.cleanup()
@ -1836,14 +1870,20 @@ class ActorPlacementAndResources(unittest.TestCase):
actor2s = [Actor2.remote() for _ in range(2)]
results = [a.method.remote() for a in actor2s]
ready_ids, remaining_ids = ray.wait(results, num_returns=len(results),
timeout=1000)
ready_ids, remaining_ids = ray.wait(
results, num_returns=len(results), timeout=1000)
self.assertEqual(len(ready_ids), 1)
def testCustomLabelPlacement(self):
ray.worker._init(start_ray_local=True, num_local_schedulers=2,
num_workers=0, resources=[{"CustomResource1": 2},
{"CustomResource2": 2}])
ray.worker._init(
start_ray_local=True,
num_local_schedulers=2,
num_workers=0,
resources=[{
"CustomResource1": 2
}, {
"CustomResource2": 2
}])
@ray.remote(resources={"CustomResource1": 1})
class ResourceActor1(object):
@ -1868,8 +1908,11 @@ class ActorPlacementAndResources(unittest.TestCase):
self.assertNotEqual(location, local_plasma)
def testCreatingMoreActorsThanResources(self):
ray.init(num_workers=0, num_cpus=10, num_gpus=2,
resources={"CustomResource1": 1})
ray.init(
num_workers=0,
num_cpus=10,
num_gpus=2,
resources={"CustomResource1": 1})
@ray.remote(num_gpus=1)
class ResourceActor1(object):


@ -20,8 +20,9 @@ class RemoteArrayTest(unittest.TestCase):
ray.worker.cleanup()
def testMethods(self):
for module in [ra.core, ra.random, ra.linalg, da.core, da.random,
da.linalg]:
for module in [
ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg
]:
reload(module)
ray.init()
@ -56,8 +57,9 @@ class DistributedArrayTest(unittest.TestCase):
ray.worker.cleanup()
def testAssemble(self):
for module in [ra.core, ra.random, ra.linalg, da.core, da.random,
da.linalg]:
for module in [
ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg
]:
reload(module)
ray.init()
@ -66,15 +68,18 @@ class DistributedArrayTest(unittest.TestCase):
x = da.DistArray([2 * da.BLOCK_SIZE, da.BLOCK_SIZE],
np.array([[a], [b]]))
assert_equal(x.assemble(),
np.vstack([np.ones([da.BLOCK_SIZE, da.BLOCK_SIZE]),
np.zeros([da.BLOCK_SIZE, da.BLOCK_SIZE])]))
np.vstack([
np.ones([da.BLOCK_SIZE, da.BLOCK_SIZE]),
np.zeros([da.BLOCK_SIZE, da.BLOCK_SIZE])
]))
def testMethods(self):
for module in [ra.core, ra.random, ra.linalg, da.core, da.random,
da.linalg]:
for module in [
ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg
]:
reload(module)
ray.worker._init(start_ray_local=True, num_local_schedulers=2,
num_cpus=[10, 10])
ray.worker._init(
start_ray_local=True, num_local_schedulers=2, num_cpus=[10, 10])
x = da.zeros.remote([9, 25, 51], "float")
assert_equal(ray.get(da.assemble.remote(x)), np.zeros([9, 25, 51]))
@ -84,21 +89,23 @@ class DistributedArrayTest(unittest.TestCase):
x = da.random.normal.remote([11, 25, 49])
y = da.copy.remote(x)
assert_equal(ray.get(da.assemble.remote(x)),
ray.get(da.assemble.remote(y)))
assert_equal(
ray.get(da.assemble.remote(x)), ray.get(da.assemble.remote(y)))
x = da.eye.remote(25, dtype_name="float")
assert_equal(ray.get(da.assemble.remote(x)), np.eye(25))
x = da.random.normal.remote([25, 49])
y = da.triu.remote(x)
assert_equal(ray.get(da.assemble.remote(y)),
np.triu(ray.get(da.assemble.remote(x))))
assert_equal(
ray.get(da.assemble.remote(y)),
np.triu(ray.get(da.assemble.remote(x))))
x = da.random.normal.remote([25, 49])
y = da.tril.remote(x)
assert_equal(ray.get(da.assemble.remote(y)),
np.tril(ray.get(da.assemble.remote(x))))
assert_equal(
ray.get(da.assemble.remote(y)),
np.tril(ray.get(da.assemble.remote(x))))
x = da.random.normal.remote([25, 49])
y = da.random.normal.remote([49, 18])
@ -113,31 +120,31 @@ class DistributedArrayTest(unittest.TestCase):
x = da.random.normal.remote([23, 42])
y = da.random.normal.remote([23, 42])
z = da.add.remote(x, y)
assert_almost_equal(ray.get(da.assemble.remote(z)),
ray.get(da.assemble.remote(x)) +
ray.get(da.assemble.remote(y)))
assert_almost_equal(
ray.get(da.assemble.remote(z)),
ray.get(da.assemble.remote(x)) + ray.get(da.assemble.remote(y)))
# test subtract
x = da.random.normal.remote([33, 40])
y = da.random.normal.remote([33, 40])
z = da.subtract.remote(x, y)
assert_almost_equal(ray.get(da.assemble.remote(z)),
ray.get(da.assemble.remote(x)) -
ray.get(da.assemble.remote(y)))
assert_almost_equal(
ray.get(da.assemble.remote(z)),
ray.get(da.assemble.remote(x)) - ray.get(da.assemble.remote(y)))
# test transpose
x = da.random.normal.remote([234, 432])
y = da.transpose.remote(x)
assert_equal(ray.get(da.assemble.remote(x)).T,
ray.get(da.assemble.remote(y)))
assert_equal(
ray.get(da.assemble.remote(x)).T, ray.get(da.assemble.remote(y)))
# test numpy_to_dist
x = da.random.normal.remote([23, 45])
y = da.assemble.remote(x)
z = da.numpy_to_dist.remote(y)
w = da.assemble.remote(z)
assert_equal(ray.get(da.assemble.remote(x)),
ray.get(da.assemble.remote(z)))
assert_equal(
ray.get(da.assemble.remote(x)), ray.get(da.assemble.remote(z)))
assert_equal(ray.get(y), ray.get(w))
# test da.tsqr
@ -157,8 +164,8 @@ class DistributedArrayTest(unittest.TestCase):
# test da.linalg.modified_lu
def test_modified_lu(d1, d2):
print("testing dist_modified_lu with d1 = " + str(d1) +
", d2 = " + str(d2))
print("testing dist_modified_lu with d1 = " + str(d1) + ", d2 = " +
str(d2))
assert d1 >= d2
m = ra.random.normal.remote([d1, d2])
q, r = ra.linalg.qr.remote(m)
@ -178,8 +185,8 @@ class DistributedArrayTest(unittest.TestCase):
# Check that l is lower triangular.
assert_equal(np.tril(l_val), l_val)
for d1, d2 in [(100, 100), (99, 98), (7, 5), (7, 7), (20, 7),
(20, 10)]:
for d1, d2 in [(100, 100), (99, 98), (7, 5), (7, 7), (20, 7), (20,
10)]:
test_modified_lu(d1, d2)
# test dist_tsqr_hr


@ -56,7 +56,8 @@ class MockProvider(NodeProvider):
raise Exception("oops")
return [
n.node_id for n in self.mock_nodes.values()
if n.matches(tag_filters) and n.state != "terminated"]
if n.matches(tag_filters) and n.state != "terminated"
]
def is_running(self, node_id):
return self.mock_nodes[node_id].state == "running"
@ -101,7 +102,6 @@ SMALL_CLUSTER = {
"docker": {
"image": "example",
"container_name": "mock",
},
"auth": {
"ssh_user": "ubuntu",
@ -269,8 +269,11 @@ class AutoscalingTest(unittest.TestCase):
config_path = self.write_config(SMALL_CLUSTER)
self.provider = MockProvider()
autoscaler = StandardAutoscaler(
config_path, LoadMetrics(), max_concurrent_launches=5,
max_failures=0, update_interval_s=0)
config_path,
LoadMetrics(),
max_concurrent_launches=5,
max_failures=0,
update_interval_s=0)
self.assertEqual(len(self.provider.nodes({})), 0)
autoscaler.update()
self.assertEqual(len(self.provider.nodes({})), 2)
@ -295,8 +298,11 @@ class AutoscalingTest(unittest.TestCase):
config_path = self.write_config(SMALL_CLUSTER)
self.provider = MockProvider()
autoscaler = StandardAutoscaler(
config_path, LoadMetrics(), max_concurrent_launches=5,
max_failures=0, update_interval_s=10)
config_path,
LoadMetrics(),
max_concurrent_launches=5,
max_failures=0,
update_interval_s=10)
autoscaler.update()
self.assertEqual(len(self.provider.nodes({})), 2)
new_config = SMALL_CLUSTER.copy()
@ -328,8 +334,11 @@ class AutoscalingTest(unittest.TestCase):
config_path = self.write_config(SMALL_CLUSTER)
self.provider = MockProvider()
autoscaler = StandardAutoscaler(
config_path, LoadMetrics(), max_concurrent_launches=10,
max_failures=0, update_interval_s=0)
config_path,
LoadMetrics(),
max_concurrent_launches=10,
max_failures=0,
update_interval_s=0)
autoscaler.update()
# Write a corrupted config
@ -383,16 +392,22 @@ class AutoscalingTest(unittest.TestCase):
self.provider = MockProvider()
runner = MockProcessRunner()
autoscaler = StandardAutoscaler(
config_path, LoadMetrics(), max_failures=0, process_runner=runner,
verbose_updates=True, node_updater_cls=NodeUpdaterThread,
config_path,
LoadMetrics(),
max_failures=0,
process_runner=runner,
verbose_updates=True,
node_updater_cls=NodeUpdaterThread,
update_interval_s=0)
autoscaler.update()
autoscaler.update()
self.assertEqual(len(self.provider.nodes({})), 2)
for node in self.provider.mock_nodes.values():
node.state = "running"
assert len(self.provider.nodes(
{TAG_RAY_NODE_STATUS: "Uninitialized"})) == 2
assert len(
self.provider.nodes({
TAG_RAY_NODE_STATUS: "Uninitialized"
})) == 2
autoscaler.update()
self.waitFor(
lambda: len(self.provider.nodes(
@ -403,16 +418,22 @@ class AutoscalingTest(unittest.TestCase):
self.provider = MockProvider()
runner = MockProcessRunner(fail_cmds=["cmd1"])
autoscaler = StandardAutoscaler(
config_path, LoadMetrics(), max_failures=0, process_runner=runner,
verbose_updates=True, node_updater_cls=NodeUpdaterThread,
config_path,
LoadMetrics(),
max_failures=0,
process_runner=runner,
verbose_updates=True,
node_updater_cls=NodeUpdaterThread,
update_interval_s=0)
autoscaler.update()
autoscaler.update()
self.assertEqual(len(self.provider.nodes({})), 2)
for node in self.provider.mock_nodes.values():
node.state = "running"
assert len(self.provider.nodes(
{TAG_RAY_NODE_STATUS: "Uninitialized"})) == 2
assert len(
self.provider.nodes({
TAG_RAY_NODE_STATUS: "Uninitialized"
})) == 2
autoscaler.update()
self.waitFor(
lambda: len(self.provider.nodes(
@ -423,8 +444,12 @@ class AutoscalingTest(unittest.TestCase):
self.provider = MockProvider()
runner = MockProcessRunner()
autoscaler = StandardAutoscaler(
config_path, LoadMetrics(), max_failures=0, process_runner=runner,
verbose_updates=True, node_updater_cls=NodeUpdaterThread,
config_path,
LoadMetrics(),
max_failures=0,
process_runner=runner,
verbose_updates=True,
node_updater_cls=NodeUpdaterThread,
update_interval_s=0)
autoscaler.update()
autoscaler.update()
@ -490,8 +515,12 @@ class AutoscalingTest(unittest.TestCase):
runner = MockProcessRunner()
lm = LoadMetrics()
autoscaler = StandardAutoscaler(
config_path, lm, max_failures=0, process_runner=runner,
verbose_updates=True, node_updater_cls=NodeUpdaterThread,
config_path,
lm,
max_failures=0,
process_runner=runner,
verbose_updates=True,
node_updater_cls=NodeUpdaterThread,
update_interval_s=0)
autoscaler.update()
for node in self.provider.mock_nodes.values():
@ -512,7 +541,7 @@ class AutoscalingTest(unittest.TestCase):
config["provider"] = {
"type": "external",
"module": "ray.autoscaler.node_provider.NodeProvider",
}
}
config_path = self.write_config(config)
autoscaler = StandardAutoscaler(
config_path, LoadMetrics(), max_failures=0, update_interval_s=0)
@ -523,7 +552,7 @@ class AutoscalingTest(unittest.TestCase):
config["provider"] = {
"type": "external",
"module": "mymodule.provider_class",
}
}
invalid_provider = self.write_config(config)
self.assertRaises(
ImportError,
@ -535,7 +564,7 @@ class AutoscalingTest(unittest.TestCase):
config["provider"] = {
"type": "external",
"module": "does-not-exist",
}
}
invalid_provider = self.write_config(config)
self.assertRaises(
ValueError,


@ -11,7 +11,6 @@ import pyarrow as pa
class ComponentFailureTest(unittest.TestCase):
def tearDown(self):
ray.worker.cleanup()
@ -24,54 +23,20 @@ class ComponentFailureTest(unittest.TestCase):
def f():
ray.worker.global_worker.plasma_client.get(obj_id)
ray.worker._init(num_workers=1,
driver_mode=ray.SILENT_MODE,
start_workers_from_local_scheduler=False,
start_ray_local=True,
redirect_output=True)
ray.worker._init(
num_workers=1,
driver_mode=ray.SILENT_MODE,
start_workers_from_local_scheduler=False,
start_ray_local=True,
redirect_output=True)
# Have the worker wait in a get call.
f.remote()
# Kill the worker.
time.sleep(1)
(ray.services
.all_processes[ray.services.PROCESS_TYPE_WORKER][0].terminate())
time.sleep(0.1)
# Seal the object so the store attempts to notify the worker that the
# get has been fulfilled.
ray.worker.global_worker.plasma_client.create(
pa.plasma.ObjectID(obj_id), 100)
ray.worker.global_worker.plasma_client.seal(pa.plasma.ObjectID(obj_id))
time.sleep(0.1)
# Make sure that nothing has died.
self.assertTrue(ray.services.all_processes_alive(
exclude=[ray.services.PROCESS_TYPE_WORKER]))
# This test checks that when a worker dies in the middle of a wait, the
# plasma store and manager will not die.
def testDyingWorkerWait(self):
obj_id = 20 * b"a"
@ray.remote
def f():
ray.worker.global_worker.plasma_client.wait([obj_id])
ray.worker._init(num_workers=1,
driver_mode=ray.SILENT_MODE,
start_workers_from_local_scheduler=False,
start_ray_local=True,
redirect_output=True)
# Have the worker wait in a wait call.
f.remote()
# Kill the worker.
time.sleep(1)
(ray.services
.all_processes[ray.services.PROCESS_TYPE_WORKER][0].terminate())
(ray.services.all_processes[ray.services.PROCESS_TYPE_WORKER][0]
.terminate())
time.sleep(0.1)
# Seal the object so the store attempts to notify the worker that the
@ -82,8 +47,46 @@ class ComponentFailureTest(unittest.TestCase):
time.sleep(0.1)
# Make sure that nothing has died.
self.assertTrue(ray.services.all_processes_alive(
exclude=[ray.services.PROCESS_TYPE_WORKER]))
self.assertTrue(
ray.services.all_processes_alive(
exclude=[ray.services.PROCESS_TYPE_WORKER]))
# This test checks that when a worker dies in the middle of a wait, the
# plasma store and manager will not die.
def testDyingWorkerWait(self):
obj_id = 20 * b"a"
@ray.remote
def f():
ray.worker.global_worker.plasma_client.wait([obj_id])
ray.worker._init(
num_workers=1,
driver_mode=ray.SILENT_MODE,
start_workers_from_local_scheduler=False,
start_ray_local=True,
redirect_output=True)
# Have the worker wait in a wait call.
f.remote()
# Kill the worker.
time.sleep(1)
(ray.services.all_processes[ray.services.PROCESS_TYPE_WORKER][0]
.terminate())
time.sleep(0.1)
# Seal the object so the store attempts to notify the worker that the
# get has been fulfilled.
ray.worker.global_worker.plasma_client.create(
pa.plasma.ObjectID(obj_id), 100)
ray.worker.global_worker.plasma_client.seal(pa.plasma.ObjectID(obj_id))
time.sleep(0.1)
# Make sure that nothing has died.
self.assertTrue(
ray.services.all_processes_alive(
exclude=[ray.services.PROCESS_TYPE_WORKER]))
def _testWorkerFailed(self, num_local_schedulers):
@ray.remote
@ -92,23 +95,25 @@ class ComponentFailureTest(unittest.TestCase):
return x
num_initial_workers = 4
ray.worker._init(num_workers=(num_initial_workers *
num_local_schedulers),
num_local_schedulers=num_local_schedulers,
start_workers_from_local_scheduler=False,
start_ray_local=True,
num_cpus=[num_initial_workers] * num_local_schedulers,
redirect_output=True)
ray.worker._init(
num_workers=(num_initial_workers * num_local_schedulers),
num_local_schedulers=num_local_schedulers,
start_workers_from_local_scheduler=False,
start_ray_local=True,
num_cpus=[num_initial_workers] * num_local_schedulers,
redirect_output=True)
# Submit more tasks than there are workers so that all workers and
# cores are utilized.
object_ids = [f.remote(i) for i
in range(num_initial_workers * num_local_schedulers)]
object_ids = [
f.remote(i)
for i in range(num_initial_workers * num_local_schedulers)
]
object_ids += [f.remote(object_id) for object_id in object_ids]
# Allow the tasks some time to begin executing.
time.sleep(0.1)
# Kill the workers as the tasks execute.
for worker in (ray.services
.all_processes[ray.services.PROCESS_TYPE_WORKER]):
for worker in (
ray.services.all_processes[ray.services.PROCESS_TYPE_WORKER]):
worker.terminate()
time.sleep(0.1)
# Make sure that we can still get the objects after the executing tasks
@ -123,6 +128,7 @@ class ComponentFailureTest(unittest.TestCase):
def _testComponentFailed(self, component_type):
"""Kill a component on all worker nodes and check workload succeeds."""
@ray.remote
def f(x, j):
time.sleep(0.2)
@ -140,9 +146,10 @@ class ComponentFailureTest(unittest.TestCase):
# Submit more tasks than there are workers so that all workers and
# cores are utilized.
object_ids = [f.remote(i, 0) for i
in range(num_workers_per_scheduler *
num_local_schedulers)]
object_ids = [
f.remote(i, 0)
for i in range(num_workers_per_scheduler * num_local_schedulers)
]
object_ids += [f.remote(object_id, 1) for object_id in object_ids]
object_ids += [f.remote(object_id, 2) for object_id in object_ids]
@ -162,8 +169,8 @@ class ComponentFailureTest(unittest.TestCase):
# Make sure that we can still get the objects after the executing tasks
# died.
results = ray.get(object_ids)
expected_results = 4 * list(range(
num_workers_per_scheduler * num_local_schedulers))
expected_results = 4 * list(
range(num_workers_per_scheduler * num_local_schedulers))
self.assertEqual(results, expected_results)
def check_components_alive(self, component_type, check_component_alive):
@ -182,8 +189,7 @@ class ComponentFailureTest(unittest.TestCase):
self.assertTrue(not component.poll() is None)
@unittest.skipIf(
os.environ.get('RAY_USE_NEW_GCS', False),
"Hanging with new GCS API.")
os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
def testLocalSchedulerFailed(self):
# Kill all local schedulers on worker nodes.
self._testComponentFailed(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER)
@ -198,8 +204,7 @@ class ComponentFailureTest(unittest.TestCase):
False)
@unittest.skipIf(
os.environ.get('RAY_USE_NEW_GCS', False),
"Hanging with new GCS API.")
os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
def testPlasmaManagerFailed(self):
# Kill all plasma managers on worker nodes.
self._testComponentFailed(ray.services.PROCESS_TYPE_PLASMA_MANAGER)
@ -214,8 +219,7 @@ class ComponentFailureTest(unittest.TestCase):
False)
@unittest.skipIf(
os.environ.get('RAY_USE_NEW_GCS', False),
"Hanging with new GCS API.")
os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
def testPlasmaStoreFailed(self):
# Kill all plasma stores on worker nodes.
self._testComponentFailed(ray.services.PROCESS_TYPE_PLASMA_STORE)
@ -235,7 +239,8 @@ class ComponentFailureTest(unittest.TestCase):
all_processes[ray.services.PROCESS_TYPE_PLASMA_STORE][0],
all_processes[ray.services.PROCESS_TYPE_PLASMA_MANAGER][0],
all_processes[ray.services.PROCESS_TYPE_LOCAL_SCHEDULER][0],
all_processes[ray.services.PROCESS_TYPE_GLOBAL_SCHEDULER][0]]
all_processes[ray.services.PROCESS_TYPE_GLOBAL_SCHEDULER][0]
]
# Kill all the components sequentially.
for process in processes:
@ -253,7 +258,8 @@ class ComponentFailureTest(unittest.TestCase):
all_processes[ray.services.PROCESS_TYPE_PLASMA_STORE][0],
all_processes[ray.services.PROCESS_TYPE_PLASMA_MANAGER][0],
all_processes[ray.services.PROCESS_TYPE_LOCAL_SCHEDULER][0],
all_processes[ray.services.PROCESS_TYPE_GLOBAL_SCHEDULER][0]]
all_processes[ray.services.PROCESS_TYPE_GLOBAL_SCHEDULER][0]
]
# Kill all the components in parallel.
for process in processes:


@ -9,9 +9,8 @@ import unittest
import ray
@unittest.skipIf(
not os.environ.get('RAY_USE_NEW_GCS', False),
"Tests functionality of the new GCS.")
@unittest.skipIf(not os.environ.get('RAY_USE_NEW_GCS', False),
"Tests functionality of the new GCS.")
class CredisTest(unittest.TestCase):
def setUp(self):
self.config = ray.init(num_workers=0)
@ -22,8 +21,8 @@ class CredisTest(unittest.TestCase):
def test_credis_started(self):
assert "credis_address" in self.config
credis_address, credis_port = self.config["credis_address"].split(":")
credis_client = redis.StrictRedis(host=credis_address,
port=credis_port)
credis_client = redis.StrictRedis(
host=credis_address, port=credis_port)
assert credis_client.ping() is True
redis_client = ray.worker.global_state.redis_client


@ -108,6 +108,7 @@ def temporary_helper_function():
def f(worker):
if ray.worker.global_worker.mode == ray.WORKER_MODE:
raise Exception("Function to run failed.")
ray.worker.global_worker.run_function_on_all_workers(f)
wait_for_errors(b"function_to_run", 2)
# Check that the error message is in the task info.
@ -348,12 +349,14 @@ class PutErrorTest(unittest.TestCase):
ray.worker.cleanup()
def testPutError1(self):
store_size = 10 ** 6
ray.worker._init(start_ray_local=True, driver_mode=ray.SILENT_MODE,
object_store_memory=store_size)
store_size = 10**6
ray.worker._init(
start_ray_local=True,
driver_mode=ray.SILENT_MODE,
object_store_memory=store_size)
num_objects = 3
object_size = 4 * 10 ** 5
object_size = 4 * 10**5
# Define a task with a single dependency, a numpy array, that returns
# another array.
@ -369,8 +372,9 @@ class PutErrorTest(unittest.TestCase):
# on the one before it. The result of the first task should get
# evicted.
args = []
arg = single_dependency.remote(0, np.zeros(object_size,
dtype=np.uint8))
arg = single_dependency.remote(0,
np.zeros(
object_size, dtype=np.uint8))
for i in range(num_objects):
arg = single_dependency.remote(i, arg)
args.append(arg)
@ -393,12 +397,14 @@ class PutErrorTest(unittest.TestCase):
def testPutError2(self):
# This is the same as the previous test, but it calls ray.put directly.
store_size = 10 ** 6
ray.worker._init(start_ray_local=True, driver_mode=ray.SILENT_MODE,
object_store_memory=store_size)
store_size = 10**6
ray.worker._init(
start_ray_local=True,
driver_mode=ray.SILENT_MODE,
object_store_memory=store_size)
num_objects = 3
object_size = 4 * 10 ** 5
object_size = 4 * 10**5
# Define a task with a single dependency, a numpy array, that returns
# another array.


@ -58,6 +58,7 @@ class DockerRunner(object):
head_container_ip: The IP address of the docker container that runs the
head node.
"""
def __init__(self):
"""Initialize the DockerRunner."""
self.head_container_id = None
@ -91,11 +92,14 @@ class DockerRunner(object):
Returns:
The IP address of the container.
"""
proc = subprocess.Popen(["docker", "inspect",
"--format={{.NetworkSettings.Networks.bridge"
".IPAddress}}",
container_id],
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
proc = subprocess.Popen(
[
"docker", "inspect",
"--format={{.NetworkSettings.Networks.bridge"
".IPAddress}}", container_id
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
stdout_data, _ = wait_for_output(proc)
p = re.compile("([0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})")
m = p.match(stdout_data)
@ -110,23 +114,23 @@ class DockerRunner(object):
"""Start the Ray head node inside a docker container."""
mem_arg = ["--memory=" + mem_size] if mem_size else []
shm_arg = ["--shm-size=" + shm_size] if shm_size else []
volume_arg = (["-v",
"{}:{}".format(os.path.dirname(
os.path.realpath(__file__)),
"/ray/test/jenkins_tests")]
if development_mode else [])
volume_arg = ([
"-v", "{}:{}".format(
os.path.dirname(os.path.realpath(__file__)),
"/ray/test/jenkins_tests")
] if development_mode else [])
command = (["docker", "run", "-d"] + mem_arg + shm_arg + volume_arg +
[docker_image, "ray", "start", "--head", "--block",
"--redis-port=6379",
"--num-redis-shards={}".format(num_redis_shards),
"--num-cpus={}".format(num_cpus),
"--num-gpus={}".format(num_gpus),
"--no-ui"])
command = (["docker", "run", "-d"] + mem_arg + shm_arg + volume_arg + [
docker_image, "ray", "start", "--head", "--block",
"--redis-port=6379",
"--num-redis-shards={}".format(num_redis_shards),
"--num-cpus={}".format(num_cpus), "--num-gpus={}".format(num_gpus),
"--no-ui"
])
print("Starting head node with command:{}".format(command))
proc = subprocess.Popen(command,
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
proc = subprocess.Popen(
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout_data, _ = wait_for_output(proc)
container_id = self._get_container_id(stdout_data)
if container_id is None:
@ -139,29 +143,34 @@ class DockerRunner(object):
"""Start a Ray worker node inside a docker container."""
mem_arg = ["--memory=" + mem_size] if mem_size else []
shm_arg = ["--shm-size=" + shm_size] if shm_size else []
volume_arg = (["-v",
"{}:{}".format(os.path.dirname(
os.path.realpath(__file__)),
"/ray/test/jenkins_tests")]
if development_mode else [])
command = (["docker", "run", "-d"] + mem_arg + shm_arg + volume_arg +
["--shm-size=" + shm_size, docker_image,
"ray", "start", "--block",
"--redis-address={:s}:6379".format(self.head_container_ip),
"--num-cpus={}".format(num_cpus),
"--num-gpus={}".format(num_gpus)])
volume_arg = ([
"-v", "{}:{}".format(
os.path.dirname(os.path.realpath(__file__)),
"/ray/test/jenkins_tests")
] if development_mode else [])
command = (["docker", "run", "-d"] + mem_arg + shm_arg + volume_arg + [
"--shm-size=" + shm_size, docker_image, "ray", "start", "--block",
"--redis-address={:s}:6379".format(self.head_container_ip),
"--num-cpus={}".format(num_cpus), "--num-gpus={}".format(num_gpus)
])
print("Starting worker node with command:{}".format(command))
proc = subprocess.Popen(command, stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
proc = subprocess.Popen(
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout_data, _ = wait_for_output(proc)
container_id = self._get_container_id(stdout_data)
if container_id is None:
raise RuntimeError("Failed to find container id")
self.worker_container_ids.append(container_id)
def start_ray(self, docker_image=None, mem_size=None, shm_size=None,
num_nodes=None, num_redis_shards=1, num_cpus=None,
num_gpus=None, development_mode=None):
def start_ray(self,
docker_image=None,
mem_size=None,
shm_size=None,
num_nodes=None,
num_redis_shards=1,
num_cpus=None,
num_gpus=None,
development_mode=None):
"""Start a Ray cluster within docker.
This starts one docker container running the head node and
@ -200,24 +209,31 @@ class DockerRunner(object):
def _stop_node(self, container_id):
"""Stop a node in the Ray cluster."""
proc = subprocess.Popen(["docker", "kill", container_id],
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
proc = subprocess.Popen(
["docker", "kill", container_id],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
stdout_data, _ = wait_for_output(proc)
stopped_container_id = self._get_container_id(stdout_data)
if not container_id == stopped_container_id:
raise Exception("Failed to stop container {}."
.format(container_id))
proc = subprocess.Popen(["docker", "rm", "-f", container_id],
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
proc = subprocess.Popen(
["docker", "rm", "-f", container_id],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
stdout_data, _ = wait_for_output(proc)
removed_container_id = self._get_container_id(stdout_data)
if not container_id == removed_container_id:
raise Exception("Failed to remove container {}."
.format(container_id))
print("stop_node", {"container_id": container_id,
"is_head": container_id == self.head_container_id})
print(
"stop_node", {
"container_id": container_id,
"is_head": container_id == self.head_container_id
})
def stop_ray(self):
"""Stop the Ray cluster."""
@ -236,7 +252,10 @@ class DockerRunner(object):
return success
def run_test(self, test_script, num_drivers, driver_locations=None,
def run_test(self,
test_script,
num_drivers,
driver_locations=None,
timeout_seconds=600):
"""Run a test script.
@ -258,11 +277,13 @@ class DockerRunner(object):
Raises:
Exception: An exception is raised if the timeout expires.
"""
all_container_ids = ([self.head_container_id] +
self.worker_container_ids)
all_container_ids = (
[self.head_container_id] + self.worker_container_ids)
if driver_locations is None:
driver_locations = [np.random.randint(0, len(all_container_ids))
for _ in range(num_drivers)]
driver_locations = [
np.random.randint(0, len(all_container_ids))
for _ in range(num_drivers)
]
# Define a signal handler and set an alarm to go off in
# timeout_seconds.
@ -278,13 +299,15 @@ class DockerRunner(object):
for i in range(len(driver_locations)):
# Get the container ID to run the ith driver in.
container_id = all_container_ids[driver_locations[i]]
command = ["docker", "exec", container_id, "/bin/bash", "-c",
("RAY_REDIS_ADDRESS={}:6379 RAY_DRIVER_INDEX={} python "
"{}".format(self.head_container_ip, i, test_script))]
command = [
"docker", "exec", container_id, "/bin/bash", "-c",
("RAY_REDIS_ADDRESS={}:6379 RAY_DRIVER_INDEX={} python "
"{}".format(self.head_container_ip, i, test_script))
]
print("Starting driver with command {}.".format(test_script))
# Start the driver.
p = subprocess.Popen(command, stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
p = subprocess.Popen(
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
driver_processes.append(p)
# Wait for the drivers to finish.
@ -295,8 +318,10 @@ class DockerRunner(object):
print(stdout_data)
print("STDERR:")
print(stderr_data)
results.append({"success": p.returncode == 0,
"return_code": p.returncode})
results.append({
"success": p.returncode == 0,
"return_code": p.returncode
})
# Disable the alarm.
signal.alarm(0)
@ -307,29 +332,43 @@ class DockerRunner(object):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Run multinode tests in Docker.")
parser.add_argument("--docker-image", default="ray-project/deploy",
help="docker image")
parser.add_argument(
"--docker-image", default="ray-project/deploy", help="docker image")
parser.add_argument("--mem-size", help="memory size")
parser.add_argument("--shm-size", default="1G", help="shared memory size")
parser.add_argument("--num-nodes", default=1, type=int,
help="number of nodes to use in the cluster")
parser.add_argument("--num-redis-shards", default=1, type=int,
help=("the number of Redis shards to start on the "
"head node"))
parser.add_argument("--num-cpus", type=str,
help=("a comma separated list of values representing "
"the number of CPUs to start each node with"))
parser.add_argument("--num-gpus", type=str,
help=("a comma separated list of values representing "
"the number of GPUs to start each node with"))
parser.add_argument("--num-drivers", default=1, type=int,
help="number of drivers to run")
parser.add_argument("--driver-locations", type=str,
help=("a comma separated list of indices of the "
"containers to run the drivers in"))
parser.add_argument(
"--num-nodes",
default=1,
type=int,
help="number of nodes to use in the cluster")
parser.add_argument(
"--num-redis-shards",
default=1,
type=int,
help=("the number of Redis shards to start on the "
"head node"))
parser.add_argument(
"--num-cpus",
type=str,
help=("a comma separated list of values representing "
"the number of CPUs to start each node with"))
parser.add_argument(
"--num-gpus",
type=str,
help=("a comma separated list of values representing "
"the number of GPUs to start each node with"))
parser.add_argument(
"--num-drivers", default=1, type=int, help="number of drivers to run")
parser.add_argument(
"--driver-locations",
type=str,
help=("a comma separated list of indices of the "
"containers to run the drivers in"))
parser.add_argument("--test-script", required=True, help="test script")
parser.add_argument("--development-mode", action="store_true",
help="use local copies of the test scripts")
parser.add_argument(
"--development-mode",
action="store_true",
help="use local copies of the test scripts")
args = parser.parse_args()
# Parse the number of CPUs and GPUs to use for each worker.
@ -340,18 +379,24 @@ if __name__ == "__main__":
if args.num_gpus is not None else num_nodes * [0])
# Parse the driver locations.
driver_locations = (None if args.driver_locations is None
else [int(i) for i
in args.driver_locations.split(",")])
driver_locations = (None if args.driver_locations is None else
[int(i) for i in args.driver_locations.split(",")])
d = DockerRunner()
d.start_ray(docker_image=args.docker_image, mem_size=args.mem_size,
shm_size=args.shm_size, num_nodes=num_nodes,
num_redis_shards=args.num_redis_shards, num_cpus=num_cpus,
num_gpus=num_gpus, development_mode=args.development_mode)
d.start_ray(
docker_image=args.docker_image,
mem_size=args.mem_size,
shm_size=args.shm_size,
num_nodes=num_nodes,
num_redis_shards=args.num_redis_shards,
num_cpus=num_cpus,
num_gpus=num_gpus,
development_mode=args.development_mode)
try:
run_results = d.run_test(args.test_script, args.num_drivers,
driver_locations=driver_locations)
run_results = d.run_test(
args.test_script,
args.num_drivers,
driver_locations=driver_locations)
finally:
successfully_stopped = d.stop_ray()


@ -6,19 +6,17 @@ import numpy as np
import ray
if __name__ == "__main__":
ray.init(num_workers=0)
A = np.ones(2 ** 31 + 1, dtype="int8")
A = np.ones(2**31 + 1, dtype="int8")
a = ray.put(A)
assert np.sum(ray.get(a)) == np.sum(A)
del A
del a
print("Successfully put A.")
B = {"hello": np.zeros(2 ** 30 + 1),
"world": np.ones(2 ** 30 + 1)}
B = {"hello": np.zeros(2**30 + 1), "world": np.ones(2**30 + 1)}
b = ray.put(B)
assert np.sum(ray.get(b)["hello"]) == np.sum(B["hello"])
assert np.sum(ray.get(b)["world"]) == np.sum(B["world"])
@ -26,7 +24,7 @@ if __name__ == "__main__":
del b
print("Successfully put B.")
C = [np.ones(2 ** 30 + 1), 42.0 * np.ones(2 ** 30 + 1)]
C = [np.ones(2**30 + 1), 42.0 * np.ones(2**30 + 1)]
c = ray.put(C)
assert np.sum(ray.get(c)[0]) == np.sum(C[0])
assert np.sum(ray.get(c)[1]) == np.sum(C[1])


@ -6,8 +6,7 @@ import os
import time
import ray
from ray.test.test_utils import (_wait_for_nodes_to_join,
_broadcast_event,
from ray.test.test_utils import (_wait_for_nodes_to_join, _broadcast_event,
_wait_for_event)
# This test should be run with 5 nodes, which have 0, 0, 5, 6, and 50 GPUs for


@ -6,10 +6,8 @@ import os
import time
import ray
from ray.test.test_utils import (_wait_for_nodes_to_join,
_broadcast_event,
_wait_for_event,
wait_for_pid_to_exit)
from ray.test.test_utils import (_wait_for_nodes_to_join, _broadcast_event,
_wait_for_event, wait_for_pid_to_exit)
# This test should be run with 5 nodes, which have 0, 1, 2, 3, and 4 GPUs for a
# total of 10 GPUs. It should be run with 7 drivers. Drivers 2 through 6 must
@ -28,9 +26,10 @@ def remote_function_event_name(driver_index, task_index):
@ray.remote
def long_running_task(driver_index, task_index, redis_address):
_broadcast_event(remote_function_event_name(driver_index, task_index),
redis_address,
data=(ray.services.get_node_ip_address(), os.getpid()))
_broadcast_event(
remote_function_event_name(driver_index, task_index),
redis_address,
data=(ray.services.get_node_ip_address(), os.getpid()))
# Loop forever.
while True:
time.sleep(100)
@ -42,10 +41,10 @@ num_long_running_tasks_per_driver = 2
@ray.remote
class Actor0(object):
def __init__(self, driver_index, actor_index, redis_address):
_broadcast_event(actor_event_name(driver_index, actor_index),
redis_address,
data=(ray.services.get_node_ip_address(),
os.getpid()))
_broadcast_event(
actor_event_name(driver_index, actor_index),
redis_address,
data=(ray.services.get_node_ip_address(), os.getpid()))
assert len(ray.get_gpu_ids()) == 0
def check_ids(self):
@ -60,10 +59,10 @@ class Actor0(object):
@ray.remote(num_gpus=1)
class Actor1(object):
def __init__(self, driver_index, actor_index, redis_address):
_broadcast_event(actor_event_name(driver_index, actor_index),
redis_address,
data=(ray.services.get_node_ip_address(),
os.getpid()))
_broadcast_event(
actor_event_name(driver_index, actor_index),
redis_address,
data=(ray.services.get_node_ip_address(), os.getpid()))
assert len(ray.get_gpu_ids()) == 1
def check_ids(self):
@ -78,10 +77,10 @@ class Actor1(object):
@ray.remote(num_gpus=2)
class Actor2(object):
def __init__(self, driver_index, actor_index, redis_address):
_broadcast_event(actor_event_name(driver_index, actor_index),
redis_address,
data=(ray.services.get_node_ip_address(),
os.getpid()))
_broadcast_event(
actor_event_name(driver_index, actor_index),
redis_address,
data=(ray.services.get_node_ip_address(), os.getpid()))
assert len(ray.get_gpu_ids()) == 2
def check_ids(self):
@ -110,11 +109,13 @@ def driver_0(redis_address, driver_index):
long_running_task.remote(driver_index, i, redis_address)
# Create some actors that require one GPU.
actors_one_gpu = [Actor1.remote(driver_index, i, redis_address)
for i in range(5)]
actors_one_gpu = [
Actor1.remote(driver_index, i, redis_address) for i in range(5)
]
# Create some actors that don't require any GPUs.
actors_no_gpus = [Actor0.remote(driver_index, 5 + i, redis_address)
for i in range(5)]
actors_no_gpus = [
Actor0.remote(driver_index, 5 + i, redis_address) for i in range(5)
]
for _ in range(1000):
ray.get([actor.check_ids.remote() for actor in actors_one_gpu])
@ -145,14 +146,17 @@ def driver_1(redis_address, driver_index):
long_running_task.remote(driver_index, i, redis_address)
# Create an actor that requires two GPUs.
actors_two_gpus = [Actor2.remote(driver_index, i, redis_address)
for i in range(1)]
actors_two_gpus = [
Actor2.remote(driver_index, i, redis_address) for i in range(1)
]
# Create some actors that require one GPU.
actors_one_gpu = [Actor1.remote(driver_index, 1 + i, redis_address)
for i in range(3)]
actors_one_gpu = [
Actor1.remote(driver_index, 1 + i, redis_address) for i in range(3)
]
# Create some actors that don't require any GPUs.
actors_no_gpus = [Actor0.remote(driver_index, 1 + 3 + i, redis_address)
for i in range(5)]
actors_no_gpus = [
Actor0.remote(driver_index, 1 + 3 + i, redis_address) for i in range(5)
]
for _ in range(1000):
ray.get([actor.check_ids.remote() for actor in actors_two_gpus])
@ -179,8 +183,9 @@ def cleanup_driver(redis_address, driver_index):
# We go ahead and create some actors that don't require any GPUs. We
# don't need to wait for the other drivers to finish. We call methods
# on these actors later to make sure they haven't been killed.
actors_no_gpus = [Actor0.remote(driver_index, i, redis_address)
for i in range(10)]
actors_no_gpus = [
Actor0.remote(driver_index, i, redis_address) for i in range(10)
]
_wait_for_event("DRIVER_0_DONE", redis_address)
_wait_for_event("DRIVER_1_DONE", redis_address)
@ -206,13 +211,13 @@ def cleanup_driver(redis_address, driver_index):
# Create some actors that require two GPUs.
actors_two_gpus = []
for i in range(3):
actors_two_gpus.append(try_to_create_actor(Actor2, driver_index,
10 + i))
actors_two_gpus.append(
try_to_create_actor(Actor2, driver_index, 10 + i))
# Create some actors that require one GPU.
actors_one_gpu = []
for i in range(4):
actors_one_gpu.append(try_to_create_actor(Actor1, driver_index,
10 + 3 + i))
actors_one_gpu.append(
try_to_create_actor(Actor1, driver_index, 10 + 3 + i))
removed_workers = 0
@ -233,14 +238,14 @@ def cleanup_driver(redis_address, driver_index):
# Make sure that the PIDs for the actors from driver 0 and driver 1 have
# been killed.
for i in range(10):
node_ip_address, pid = _wait_for_event(actor_event_name(0, i),
redis_address)
node_ip_address, pid = _wait_for_event(
actor_event_name(0, i), redis_address)
if node_ip_address == ray.services.get_node_ip_address():
wait_for_pid_to_exit(pid)
removed_workers += 1
for i in range(9):
node_ip_address, pid = _wait_for_event(actor_event_name(1, i),
redis_address)
node_ip_address, pid = _wait_for_event(
actor_event_name(1, i), redis_address)
if node_ip_address == ray.services.get_node_ip_address():
wait_for_pid_to_exit(pid)
removed_workers += 1


@ -25,8 +25,9 @@ if __name__ == "__main__":
for i in range(num_attempts):
ip_addresses = ray.get([f.remote() for i in range(1000)])
distinct_addresses = set(ip_addresses)
counts = [ip_addresses.count(address) for address
in distinct_addresses]
counts = [
ip_addresses.count(address) for address in distinct_addresses
]
print("Counts are {}".format(counts))
if len(counts) == 5:
break


@ -25,19 +25,17 @@ def run_string_as_driver(driver_script):
with tempfile.NamedTemporaryFile() as f:
f.write(driver_script.encode("ascii"))
f.flush()
out = subprocess.check_output([sys.executable,
f.name]).decode("ascii")
out = subprocess.check_output([sys.executable, f.name]).decode("ascii")
return out
class MultiNodeTest(unittest.TestCase):
def setUp(self):
out = run_and_get_output(["ray", "start", "--head"])
# Get the redis address from the output.
redis_substring_prefix = "redis_address=\""
redis_address_location = (out.find(redis_substring_prefix) +
len(redis_substring_prefix))
redis_address_location = (
out.find(redis_substring_prefix) + len(redis_substring_prefix))
redis_address = out[redis_address_location:]
self.redis_address = redis_address.split("\"")[0]
@ -196,7 +194,6 @@ print("success")
class StartRayScriptTest(unittest.TestCase):
def testCallingStartRayHead(self):
# Test that we can call start-ray.sh with various command line
# parameters. TODO(rkn): This test only tests the --head code path. We
@@ -207,69 +204,65 @@ class StartRayScriptTest(unittest.TestCase):
subprocess.Popen(["ray", "stop"]).wait()
# Test starting Ray with a number of workers specified.
run_and_get_output(["ray", "start", "--head", "--num-workers",
"20"])
run_and_get_output(["ray", "start", "--head", "--num-workers", "20"])
subprocess.Popen(["ray", "stop"]).wait()
# Test starting Ray with a redis port specified.
run_and_get_output(["ray", "start", "--head",
"--redis-port", "6379"])
run_and_get_output(["ray", "start", "--head", "--redis-port", "6379"])
subprocess.Popen(["ray", "stop"]).wait()
# Test starting Ray with redis shard ports specified.
run_and_get_output(["ray", "start", "--head",
"--redis-shard-ports", "6380,6381,6382"])
run_and_get_output([
"ray", "start", "--head", "--redis-shard-ports", "6380,6381,6382"
])
subprocess.Popen(["ray", "stop"]).wait()
# Test starting Ray with a node IP address specified.
run_and_get_output(["ray", "start", "--head",
"--node-ip-address", "127.0.0.1"])
run_and_get_output(
["ray", "start", "--head", "--node-ip-address", "127.0.0.1"])
subprocess.Popen(["ray", "stop"]).wait()
# Test starting Ray with an object manager port specified.
run_and_get_output(["ray", "start", "--head",
"--object-manager-port", "12345"])
run_and_get_output(
["ray", "start", "--head", "--object-manager-port", "12345"])
subprocess.Popen(["ray", "stop"]).wait()
# Test starting Ray with the number of CPUs specified.
run_and_get_output(["ray", "start", "--head",
"--num-cpus", "100"])
run_and_get_output(["ray", "start", "--head", "--num-cpus", "100"])
subprocess.Popen(["ray", "stop"]).wait()
# Test starting Ray with the number of GPUs specified.
run_and_get_output(["ray", "start", "--head",
"--num-gpus", "100"])
run_and_get_output(["ray", "start", "--head", "--num-gpus", "100"])
subprocess.Popen(["ray", "stop"]).wait()
# Test starting Ray with the max redis clients specified.
run_and_get_output(["ray", "start", "--head",
"--redis-max-clients", "100"])
run_and_get_output(
["ray", "start", "--head", "--redis-max-clients", "100"])
subprocess.Popen(["ray", "stop"]).wait()
# Test starting Ray with all arguments specified.
run_and_get_output(["ray", "start", "--head",
"--num-workers", "20",
"--redis-port", "6379",
"--redis-shard-ports", "6380,6381,6382",
"--object-manager-port", "12345",
"--num-cpus", "100",
"--num-gpus", "0",
"--redis-max-clients", "100",
"--resources", "{\"Custom\": 1}"])
run_and_get_output([
"ray", "start", "--head", "--num-workers", "20", "--redis-port",
"6379", "--redis-shard-ports", "6380,6381,6382",
"--object-manager-port", "12345", "--num-cpus", "100",
"--num-gpus", "0", "--redis-max-clients", "100", "--resources",
"{\"Custom\": 1}"
])
subprocess.Popen(["ray", "stop"]).wait()
# Test starting Ray with invalid arguments.
with self.assertRaises(Exception):
run_and_get_output(["ray", "start", "--head",
"--redis-address", "127.0.0.1:6379"])
run_and_get_output([
"ray", "start", "--head", "--redis-address", "127.0.0.1:6379"
])
subprocess.Popen(["ray", "stop"]).wait()
def testUsingHostnames(self):
# Start the Ray processes on this machine.
run_and_get_output(
["ray", "start", "--head",
"--node-ip-address=localhost",
"--redis-port=6379"])
run_and_get_output([
"ray", "start", "--head", "--node-ip-address=localhost",
"--redis-port=6379"
])
ray.init(node_ip_address="localhost", redis_address="localhost:6379")


@@ -184,12 +184,12 @@ DICT_OBJECTS = (
[{
obj: obj
} for obj in PRIMITIVE_OBJECTS
if (obj.__hash__ is not None and type(obj).__module__ != "numpy")] + [{
0:
obj
} for obj in BASE_OBJECTS] + [{
Foo(123): Foo(456)
}])
if (obj.__hash__ is not None and type(obj).__module__ != "numpy")] +
[{
0: obj
} for obj in BASE_OBJECTS] + [{
Foo(123): Foo(456)
}])
RAY_TEST_OBJECTS = BASE_OBJECTS + LIST_OBJECTS + TUPLE_OBJECTS + DICT_OBJECTS
@@ -359,25 +359,29 @@ class APITest(unittest.TestCase):
def custom_deserializer(serialized_obj):
return serialized_obj, "string2"
ray.register_custom_serializer(Foo, serializer=custom_serializer,
deserializer=custom_deserializer)
ray.register_custom_serializer(
Foo,
serializer=custom_serializer,
deserializer=custom_deserializer)
self.assertEqual(ray.get(ray.put(Foo())),
((3, "string1", Foo.__name__), "string2"))
self.assertEqual(
ray.get(ray.put(Foo())), ((3, "string1", Foo.__name__), "string2"))
class Bar(object):
def __init__(self):
self.x = 3
ray.register_custom_serializer(Bar, serializer=custom_serializer,
deserializer=custom_deserializer)
ray.register_custom_serializer(
Bar,
serializer=custom_serializer,
deserializer=custom_deserializer)
@ray.remote
def f():
return Bar()
self.assertEqual(ray.get(f.remote()),
((3, "string1", Bar.__name__), "string2"))
self.assertEqual(
ray.get(f.remote()), ((3, "string1", Bar.__name__), "string2"))
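The assertions above pin down the contract of ray.register_custom_serializer: both ray.put round trips and remote return values pass through the user-supplied hooks. A condensed sketch with a hypothetical class Baz and lambda hooks (assumes ray.init() has already been called):

class Baz(object):
    pass


# Serialize to a tagged tuple; tag again on the way out.
ray.register_custom_serializer(
    Baz,
    serializer=lambda obj: (type(obj).__name__, "string1"),
    deserializer=lambda data: (data, "string2"))

assert ray.get(ray.put(Baz())) == (("Baz", "string1"), "string2")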
def testRegisterClass(self):
self.init_ray(num_workers=2)
@@ -700,10 +704,10 @@ class APITest(unittest.TestCase):
assert ray.get(f._submit(args=[1], num_return_vals=1)) == [0]
assert ray.get(f._submit(args=[2], num_return_vals=2)) == [0, 1]
assert ray.get(f._submit(args=[3], num_return_vals=3)) == [0, 1, 2]
assert ray.get(g._submit(args=[],
num_cpus=1,
num_gpus=1,
resources={"Custom": 1})) == [0]
assert ray.get(
g._submit(
args=[], num_cpus=1, num_gpus=1, resources={"Custom":
1})) == [0]
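_submit is the underscore-prefixed internal counterpart of .remote(): it lets a single call override the number of return values and the resource requirements that were fixed at decoration time. A hedged sketch of the pattern these assertions exercise, using a hypothetical function h (internal API, so subject to change):

@ray.remote
def h(n):
    # Return one value per requested return slot.
    return tuple(range(n))


# Ask for three object IDs from one invocation.
object_ids = h._submit(args=[3], num_return_vals=3)
assert ray.get(object_ids) == [0, 1, 2]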
def testGetMultiple(self):
self.init_ray()
@@ -1234,8 +1238,8 @@ class ResourcesTest(unittest.TestCase):
time.sleep(0.1)
gpu_ids = ray.get_gpu_ids()
assert len(gpu_ids) == 0
assert (os.environ["CUDA_VISIBLE_DEVICES"] ==
",".join([str(i) for i in gpu_ids]))
assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join(
[str(i) for i in gpu_ids]))
for gpu_id in gpu_ids:
assert gpu_id in range(num_gpus)
return gpu_ids
@@ -1245,8 +1249,8 @@ class ResourcesTest(unittest.TestCase):
time.sleep(0.1)
gpu_ids = ray.get_gpu_ids()
assert len(gpu_ids) == 1
assert (os.environ["CUDA_VISIBLE_DEVICES"] ==
",".join([str(i) for i in gpu_ids]))
assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join(
[str(i) for i in gpu_ids]))
for gpu_id in gpu_ids:
assert gpu_id in range(num_gpus)
return gpu_ids
@@ -1256,8 +1260,8 @@ class ResourcesTest(unittest.TestCase):
time.sleep(0.1)
gpu_ids = ray.get_gpu_ids()
assert len(gpu_ids) == 2
assert (os.environ["CUDA_VISIBLE_DEVICES"] ==
",".join([str(i) for i in gpu_ids]))
assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join(
[str(i) for i in gpu_ids]))
for gpu_id in gpu_ids:
assert gpu_id in range(num_gpus)
return gpu_ids
@@ -1267,8 +1271,8 @@ class ResourcesTest(unittest.TestCase):
time.sleep(0.1)
gpu_ids = ray.get_gpu_ids()
assert len(gpu_ids) == 3
assert (os.environ["CUDA_VISIBLE_DEVICES"] ==
",".join([str(i) for i in gpu_ids]))
assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join(
[str(i) for i in gpu_ids]))
for gpu_id in gpu_ids:
assert gpu_id in range(num_gpus)
return gpu_ids
@@ -1278,8 +1282,8 @@ class ResourcesTest(unittest.TestCase):
time.sleep(0.1)
gpu_ids = ray.get_gpu_ids()
assert len(gpu_ids) == 4
assert (os.environ["CUDA_VISIBLE_DEVICES"] ==
",".join([str(i) for i in gpu_ids]))
assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join(
[str(i) for i in gpu_ids]))
for gpu_id in gpu_ids:
assert gpu_id in range(num_gpus)
return gpu_ids
@@ -1289,8 +1293,8 @@ class ResourcesTest(unittest.TestCase):
time.sleep(0.1)
gpu_ids = ray.get_gpu_ids()
assert len(gpu_ids) == 5
assert (os.environ["CUDA_VISIBLE_DEVICES"] ==
",".join([str(i) for i in gpu_ids]))
assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join(
[str(i) for i in gpu_ids]))
for gpu_id in gpu_ids:
assert gpu_id in range(num_gpus)
return gpu_ids
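Every one of these remote functions checks the same invariant: Ray exposes the GPUs it assigned to a worker through CUDA_VISIBLE_DEVICES, as the comma-joined form of ray.get_gpu_ids(). The mapping itself, as a standalone sketch:

import os

# Suppose Ray assigned this worker GPUs 2 and 5.
gpu_ids = [2, 5]

# Ray sets the variable before running the task, so CUDA-based libraries
# only see the devices the scheduler granted.
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
assert os.environ["CUDA_VISIBLE_DEVICES"] == "2,5"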
@@ -1342,16 +1346,16 @@ class ResourcesTest(unittest.TestCase):
def __init__(self):
gpu_ids = ray.get_gpu_ids()
assert len(gpu_ids) == 0
assert (os.environ["CUDA_VISIBLE_DEVICES"] ==
",".join([str(i) for i in gpu_ids]))
assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join(
[str(i) for i in gpu_ids]))
# Set self.x to make sure that we got here.
self.x = 1
def test(self):
gpu_ids = ray.get_gpu_ids()
assert len(gpu_ids) == 0
assert (os.environ["CUDA_VISIBLE_DEVICES"] ==
",".join([str(i) for i in gpu_ids]))
assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join(
[str(i) for i in gpu_ids]))
return self.x
@ray.remote(num_gpus=1)
@@ -1359,16 +1363,16 @@ class ResourcesTest(unittest.TestCase):
def __init__(self):
gpu_ids = ray.get_gpu_ids()
assert len(gpu_ids) == 1
assert (os.environ["CUDA_VISIBLE_DEVICES"] ==
",".join([str(i) for i in gpu_ids]))
assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join(
[str(i) for i in gpu_ids]))
# Set self.x to make sure that we got here.
self.x = 1
def test(self):
gpu_ids = ray.get_gpu_ids()
assert len(gpu_ids) == 1
assert (os.environ["CUDA_VISIBLE_DEVICES"] ==
",".join([str(i) for i in gpu_ids]))
assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join(
[str(i) for i in gpu_ids]))
return self.x
a0 = Actor0.remote()
@@ -1379,9 +1383,7 @@ class ResourcesTest(unittest.TestCase):
def testZeroCPUs(self):
ray.worker._init(
start_ray_local=True,
num_local_schedulers=2,
num_cpus=[0, 2])
start_ray_local=True, num_local_schedulers=2, num_cpus=[0, 2])
local_plasma = ray.worker.global_worker.plasma_client.store_socket_name
@@ -1484,9 +1486,9 @@ class ResourcesTest(unittest.TestCase):
elif name == "run_on_2":
self.assertIn(result, [store_names[2]])
elif name == "run_on_0_1_2":
self.assertIn(result, [
store_names[0], store_names[1], store_names[2]
])
self.assertIn(
result,
[store_names[0], store_names[1], store_names[2]])
elif name == "run_on_1_2":
self.assertIn(result, [store_names[1], store_names[2]])
elif name == "run_on_0_2":
@@ -1518,7 +1520,11 @@ class ResourcesTest(unittest.TestCase):
start_ray_local=True,
num_local_schedulers=2,
num_cpus=[3, 3],
resources=[{"CustomResource": 0}, {"CustomResource": 1}])
resources=[{
"CustomResource": 0
}, {
"CustomResource": 1
}])
@ray.remote
def f():
@@ -1554,8 +1560,13 @@ class ResourcesTest(unittest.TestCase):
start_ray_local=True,
num_local_schedulers=2,
num_cpus=[3, 3],
resources=[{"CustomResource1": 1, "CustomResource2": 2},
{"CustomResource1": 3, "CustomResource2": 4}])
resources=[{
"CustomResource1": 1,
"CustomResource2": 2
}, {
"CustomResource1": 3,
"CustomResource2": 4
}])
@ray.remote(resources={"CustomResource1": 1})
def f():
@@ -1595,14 +1606,16 @@ class ResourcesTest(unittest.TestCase):
# Make sure that tasks with unsatisfied custom resource requirements do
# not get scheduled.
ready_ids, remaining_ids = ray.wait([j.remote(), k.remote()],
timeout=500)
ready_ids, remaining_ids = ray.wait(
[j.remote(), k.remote()], timeout=500)
self.assertEqual(ready_ids, [])
def testManyCustomResources(self):
num_custom_resources = 10000
total_resources = {str(i): np.random.randint(1, 7)
for i in range(num_custom_resources)}
total_resources = {
str(i): np.random.randint(1, 7)
for i in range(num_custom_resources)
}
ray.init(num_cpus=5, resources=total_resources)
def f():
@@ -1612,9 +1625,11 @@ class ResourcesTest(unittest.TestCase):
for _ in range(20):
num_resources = np.random.randint(0, num_custom_resources + 1)
permuted_resources = np.random.permutation(
num_custom_resources)[:num_resources]
random_resources = {str(i): total_resources[str(i)]
for i in permuted_resources}
num_custom_resources)[:num_resources]
random_resources = {
str(i): total_resources[str(i)]
for i in permuted_resources
}
remote_function = ray.remote(resources=random_resources)(f)
remote_functions.append(remote_function)
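Because ray.remote also works as a plain function, the loop above can mint a new remote function per iteration, each wrapped with a freshly sampled resource requirement. A reduced sketch of that pattern under a hypothetical two-resource cluster:

import ray


def f():
    return 1


total_resources = {"0": 3, "1": 5}
ray.init(num_cpus=5, resources=total_resources)

# Wrap f so it demands a subset of the declared custom resources.
remote_f = ray.remote(resources={"0": total_resources["0"]})(f)
assert ray.get(remote_f.remote()) == 1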
@@ -1634,8 +1649,7 @@ class CudaVisibleDevicesTest(unittest.TestCase):
def setUp(self):
# Record the current value of this environment variable so that we can
# reset it after the test.
self.original_gpu_ids = os.environ.get(
"CUDA_VISIBLE_DEVICES", None)
self.original_gpu_ids = os.environ.get("CUDA_VISIBLE_DEVICES", None)
def tearDown(self):
ray.worker.cleanup()
@@ -2095,9 +2109,9 @@ class GlobalStateAPI(unittest.TestCase):
for object_info in object_table.values():
if len(object_info) != 5:
tables_ready = False
if (object_info["ManagerIDs"] is None or
object_info["DataSize"] == -1 or
object_info["Hash"] == ""):
if (object_info["ManagerIDs"] is None
or object_info["DataSize"] == -1
or object_info["Hash"] == ""):
tables_ready = False
if len(task_table) != 10 + 1:


@@ -10,14 +10,15 @@ import time
class TaskTests(unittest.TestCase):
def testSubmittingTasks(self):
for num_local_schedulers in [1, 4]:
for num_workers_per_scheduler in [4]:
num_workers = num_local_schedulers * num_workers_per_scheduler
ray.worker._init(start_ray_local=True, num_workers=num_workers,
num_local_schedulers=num_local_schedulers,
num_cpus=100)
ray.worker._init(
start_ray_local=True,
num_workers=num_workers,
num_local_schedulers=num_local_schedulers,
num_cpus=100)
@ray.remote
def f(x):
@@ -42,9 +43,11 @@ class TaskTests(unittest.TestCase):
for num_local_schedulers in [1, 4]:
for num_workers_per_scheduler in [4]:
num_workers = num_local_schedulers * num_workers_per_scheduler
ray.worker._init(start_ray_local=True, num_workers=num_workers,
num_local_schedulers=num_local_schedulers,
num_cpus=100)
ray.worker._init(
start_ray_local=True,
num_workers=num_workers,
num_local_schedulers=num_local_schedulers,
num_cpus=100)
@ray.remote
def f(x):
@@ -89,7 +92,7 @@ class TaskTests(unittest.TestCase):
ray.init(num_workers=1)
for n in range(8):
x = np.zeros(10 ** n)
x = np.zeros(10**n)
for _ in range(100):
ray.put(x)
@@ -108,7 +111,7 @@ class TaskTests(unittest.TestCase):
def f():
return 1
n = 10 ** 4 # TODO(pcm): replace by 10 ** 5 once this is faster.
n = 10**4 # TODO(pcm): replace by 10 ** 5 once this is faster.
lst = ray.get([f.remote() for _ in range(n)])
self.assertEqual(lst, n * [1])
@@ -119,9 +122,11 @@ class TaskTests(unittest.TestCase):
for num_local_schedulers in [1, 4]:
for num_workers_per_scheduler in [4]:
num_workers = num_local_schedulers * num_workers_per_scheduler
ray.worker._init(start_ray_local=True, num_workers=num_workers,
num_local_schedulers=num_local_schedulers,
num_cpus=100)
ray.worker._init(
start_ray_local=True,
num_workers=num_workers,
num_local_schedulers=num_local_schedulers,
num_cpus=100)
@ray.remote
def f(x):
@@ -138,8 +143,10 @@ class TaskTests(unittest.TestCase):
time.sleep(x)
for i in range(1, 5):
x_ids = [g.remote(np.random.uniform(0, i))
for _ in range(2 * num_workers)]
x_ids = [
g.remote(np.random.uniform(0, i))
for _ in range(2 * num_workers)
]
ray.wait(x_ids, num_returns=len(x_ids))
self.assertTrue(ray.services.all_processes_alive())
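Calling ray.wait(x_ids, num_returns=len(x_ids)) blocks until every task has finished, which is what makes this a saturation test. For contrast, a hedged sketch of the non-blocking form used later in these tests:

# Take whatever is ready within 500 ms; the rest stay pending.
ready_ids, remaining_ids = ray.wait(x_ids, num_returns=1, timeout=500)
# Every submitted ID ends up in exactly one of the two lists.
assert len(ready_ids) + len(remaining_ids) == len(x_ids)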
@@ -159,34 +166,40 @@ class ReconstructionTests(unittest.TestCase):
time.sleep(0.1)
# Start the Plasma store instances with a total of 1GB memory.
self.plasma_store_memory = 10 ** 9
self.plasma_store_memory = 10**9
plasma_addresses = []
objstore_memory = (self.plasma_store_memory //
self.num_local_schedulers)
objstore_memory = (
self.plasma_store_memory // self.num_local_schedulers)
for i in range(self.num_local_schedulers):
store_stdout_file, store_stderr_file = ray.services.new_log_files(
"plasma_store_{}".format(i), True)
manager_stdout_file, manager_stderr_file = (
ray.services.new_log_files("plasma_manager_{}"
.format(i), True))
plasma_addresses.append(ray.services.start_objstore(
node_ip_address, redis_address,
objstore_memory=objstore_memory,
store_stdout_file=store_stdout_file,
store_stderr_file=store_stderr_file,
manager_stdout_file=manager_stdout_file,
manager_stderr_file=manager_stderr_file))
ray.services.new_log_files("plasma_manager_{}".format(i),
True))
plasma_addresses.append(
ray.services.start_objstore(
node_ip_address,
redis_address,
objstore_memory=objstore_memory,
store_stdout_file=store_stdout_file,
store_stderr_file=store_stderr_file,
manager_stdout_file=manager_stdout_file,
manager_stderr_file=manager_stderr_file))
# Start the rest of the services in the Ray cluster.
address_info = {"redis_address": redis_address,
"redis_shards": redis_shards,
"object_store_addresses": plasma_addresses}
ray.worker._init(address_info=address_info, start_ray_local=True,
num_workers=1,
num_local_schedulers=self.num_local_schedulers,
num_cpus=[1] * self.num_local_schedulers,
redirect_output=True,
driver_mode=ray.SILENT_MODE)
address_info = {
"redis_address": redis_address,
"redis_shards": redis_shards,
"object_store_addresses": plasma_addresses
}
ray.worker._init(
address_info=address_info,
start_ray_local=True,
num_workers=1,
num_local_schedulers=self.num_local_schedulers,
num_cpus=[1] * self.num_local_schedulers,
redirect_output=True,
driver_mode=ray.SILENT_MODE)
def tearDown(self):
self.assertTrue(ray.services.all_processes_alive())
@@ -197,8 +210,8 @@ class ReconstructionTests(unittest.TestCase):
state._initialize_global_state(self.redis_ip_address, self.redis_port)
if os.environ.get('RAY_USE_NEW_GCS', False):
tasks = state.task_table()
local_scheduler_ids = set(task["LocalSchedulerID"] for task in
tasks.values())
local_scheduler_ids = set(
task["LocalSchedulerID"] for task in tasks.values())
# Make sure that all nodes in the cluster were used by checking that
# the set of local scheduler IDs that had a task scheduled or submitted
@@ -208,8 +221,8 @@ class ReconstructionTests(unittest.TestCase):
# with the driver task, since it is not scheduled by a particular local
# scheduler.
if os.environ.get('RAY_USE_NEW_GCS', False):
self.assertEqual(len(local_scheduler_ids),
self.num_local_schedulers + 1)
self.assertEqual(
len(local_scheduler_ids), self.num_local_schedulers + 1)
# Clean up the Ray cluster.
ray.worker.cleanup()
@@ -254,8 +267,7 @@ class ReconstructionTests(unittest.TestCase):
del values
@unittest.skipIf(
os.environ.get('RAY_USE_NEW_GCS', False),
"Failing with new GCS API.")
os.environ.get('RAY_USE_NEW_GCS', False), "Failing with new GCS API.")
def testRecursive(self):
# Define the size of one task's return argument so that the combined
# sum of all objects' sizes is at least twice the plasma stores'
@@ -308,8 +320,7 @@ class ReconstructionTests(unittest.TestCase):
del values
@unittest.skipIf(
os.environ.get('RAY_USE_NEW_GCS', False),
"Failing with new GCS API.")
os.environ.get('RAY_USE_NEW_GCS', False), "Failing with new GCS API.")
def testMultipleRecursive(self):
# Define the size of one task's return argument so that the combined
# sum of all objects' sizes is at least twice the plasma stores'
@@ -375,8 +386,7 @@ class ReconstructionTests(unittest.TestCase):
return errors
@unittest.skipIf(
os.environ.get('RAY_USE_NEW_GCS', False),
"Hanging with new GCS API.")
os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
def testNondeterministicTask(self):
# Define the size of one task's return argument so that the combined
# sum of all objects' sizes is at least twice the plasma stores'
@@ -432,14 +442,14 @@ class ReconstructionTests(unittest.TestCase):
# reexecuted.
min_errors = 1
return len(errors) >= min_errors
errors = self.wait_for_errors(error_check)
# Make sure all the errors have the correct type.
self.assertTrue(all(error[b"type"] == b"object_hash_mismatch"
for error in errors))
self.assertTrue(
all(error[b"type"] == b"object_hash_mismatch" for error in errors))
@unittest.skipIf(
os.environ.get('RAY_USE_NEW_GCS', False),
"Hanging with new GCS API.")
os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
def testDriverPutErrors(self):
# Define the size of one task's return argument so that the combined
# sum of all objects' sizes is at least twice the plasma stores'
@@ -479,9 +489,10 @@ class ReconstructionTests(unittest.TestCase):
def error_check(errors):
return len(errors) > 1
errors = self.wait_for_errors(error_check)
self.assertTrue(all(error[b"type"] == b"put_reconstruction"
for error in errors))
self.assertTrue(
all(error[b"type"] == b"put_reconstruction" for error in errors))
class ReconstructionTestsMultinode(ReconstructionTests):
@@ -490,6 +501,7 @@ class ReconstructionTestsMultinode(ReconstructionTests):
# one worker each.
num_local_schedulers = 4
# NOTE(swang): This test tries to launch 1000 workers and breaks.
# class WorkerPoolTests(unittest.TestCase):
#
@@ -512,6 +524,5 @@ class ReconstructionTestsMultinode(ReconstructionTests):
# ray.get([g.remote(i) for i in range(1000)])
# ray.worker.cleanup()
if __name__ == "__main__":
unittest.main(verbosity=2)


@@ -23,7 +23,6 @@ def make_linear_network(w_name=None, b_name=None):
class LossActor(object):
def __init__(self, use_loss=True):
# Uses a separate graph for each network.
with tf.Graph().as_default():
@@ -32,10 +31,8 @@ class LossActor(object):
loss, init, _, _ = make_linear_network()
sess = tf.Session()
# Additional code for setting and getting the weights.
weights = ray.experimental.TensorFlowVariables(loss if use_loss
else None,
sess,
input_variables=var)
weights = ray.experimental.TensorFlowVariables(
loss if use_loss else None, sess, input_variables=var)
# Return all of the data needed to use the network.
self.values = [weights, init, sess]
sess.run(init)
@@ -49,7 +46,6 @@ class LossActor(object):
class NetActor(object):
def __init__(self):
# Uses a separate graph for each network.
with tf.Graph().as_default():
@@ -71,7 +67,6 @@ class NetActor(object):
class TrainActor(object):
def __init__(self):
# Almost the same as above, but now returns the placeholders and
# gradient.
@@ -82,16 +77,17 @@ class TrainActor(object):
optimizer = tf.train.GradientDescentOptimizer(0.9)
grads = optimizer.compute_gradients(loss)
train = optimizer.apply_gradients(grads)
self.values = [loss, variables, init, sess, grads, train,
[x_data, y_data]]
self.values = [
loss, variables, init, sess, grads, train, [x_data, y_data]
]
sess.run(init)
def training_step(self, weights):
_, variables, _, sess, grads, _, placeholders = self.values
variables.set_weights(weights)
return sess.run([grad[0] for grad in grads],
feed_dict=dict(zip(placeholders,
[[1] * 100, [2] * 100])))
return sess.run(
[grad[0] for grad in grads],
feed_dict=dict(zip(placeholders, [[1] * 100, [2] * 100])))
def get_weights(self):
return self.values[1].get_weights()
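ray.experimental.TensorFlowVariables is the glue in these actors: get_weights extracts the graph's variables as plain values and set_weights writes them back, so weights can travel between actors as ordinary task arguments. A hedged round-trip sketch, assuming get_weights returns a mapping keyed by variable name and reusing the NetActor wrapper above:

net = ray.remote(NetActor).remote()

# Ship the weights out of the actor and straight back in.
weights = ray.get(net.get_weights.remote())
echoed = ray.get(net.set_and_get_weights.remote(net.get_weights.remote()))
# The round trip should cover exactly the same set of variables.
assert set(weights) == set(echoed)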
@@ -216,8 +212,8 @@ class TensorFlowTest(unittest.TestCase):
net2 = ray.remote(NetActor).remote()
weights2 = ray.get(net2.get_weights.remote())
new_weights2 = ray.get(net2.set_and_get_weights.remote(
net2.get_weights.remote()))
new_weights2 = ray.get(
net2.set_and_get_weights.remote(net2.get_weights.remote()))
self.assertEqual(weights2, new_weights2)
def testVariablesControlDependencies(self):
@@ -247,22 +243,26 @@ class TensorFlowTest(unittest.TestCase):
net_values = TrainActor().values
loss, variables, _, sess, grads, train, placeholders = net_values
before_acc = sess.run(loss, feed_dict=dict(zip(placeholders,
[[2] * 100,
[4] * 100])))
before_acc = sess.run(
loss, feed_dict=dict(zip(placeholders, [[2] * 100, [4] * 100])))
for _ in range(3):
gradients_list = ray.get(
[net.training_step.remote(variables.get_weights())
for _ in range(2)])
mean_grads = [sum([gradients[i] for gradients in gradients_list]) /
len(gradients_list) for i
in range(len(gradients_list[0]))]
feed_dict = {grad[0]: mean_grad for (grad, mean_grad)
in zip(grads, mean_grads)}
gradients_list = ray.get([
net.training_step.remote(variables.get_weights())
for _ in range(2)
])
mean_grads = [
sum([gradients[i]
for gradients in gradients_list]) / len(gradients_list)
for i in range(len(gradients_list[0]))
]
feed_dict = {
grad[0]: mean_grad
for (grad, mean_grad) in zip(grads, mean_grads)
}
sess.run(train, feed_dict=feed_dict)
after_acc = sess.run(loss, feed_dict=dict(zip(placeholders,
[[2] * 100, [4] * 100])))
after_acc = sess.run(
loss, feed_dict=dict(zip(placeholders, [[2] * 100, [4] * 100])))
self.assertTrue(before_acc < after_acc)
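The reformatted list comprehension above is doing synchronous data-parallel SGD: it averages the two actors' gradients element-wise, one entry per trainable variable. The arithmetic, pulled out into a self-contained sketch:

import numpy as np

# Two workers' gradients, one array per trainable variable.
gradients_list = [
    [np.array([1.0, 2.0]), np.array([3.0])],
    [np.array([3.0, 4.0]), np.array([5.0])],
]
mean_grads = [
    sum(gradients[i] for gradients in gradients_list) / len(gradients_list)
    for i in range(len(gradients_list[0]))
]
# mean_grads == [array([2., 3.]), array([4.])]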


@@ -57,12 +57,11 @@ def test_put_api(ray_start):
# Test putting object IDs.
x_id = ray.put(0)
for obj in [[x_id], (x_id,), {x_id: x_id}]:
for obj in [[x_id], (x_id, ), {x_id: x_id}]:
assert ray.get(ray.put(obj)) == obj
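The (x_id, ) spelling is just yapf's spacing for a one-element tuple; the substance of the test is that object IDs nested in containers survive put/get unresolved. A short sketch of what that means in practice:

x_id = ray.put(0)
# Putting a container of object IDs and getting it back yields an equal
# container of object IDs; a second get resolves the inner value.
container = ray.get(ray.put([x_id]))
assert ray.get(container[0]) == 0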
def test_actor_api(ray_start):
@ray.remote
class Foo(object):
def __init__(self, val):