mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
Lint Python files with Yapf (#1872)
This commit is contained in:
parent
a3ddde398c
commit
74162d1492
97 changed files with 3927 additions and 3139 deletions
|
@ -42,7 +42,7 @@ blank_line_before_nested_class_or_def=False
|
|||
# 'key1': 'value1',
|
||||
# 'key2': 'value2',
|
||||
# })
|
||||
coalesce_brackets=False
|
||||
coalesce_brackets=True
|
||||
|
||||
# The column limit.
|
||||
column_limit=79
|
||||
|
@ -90,7 +90,7 @@ i18n_function_call=
|
|||
# 'key2': value1 +
|
||||
# value2,
|
||||
# }
|
||||
indent_dictionary_value=False
|
||||
indent_dictionary_value=True
|
||||
|
||||
# The number of columns to use for indentation.
|
||||
indent_width=4
|
||||
|
@ -187,4 +187,3 @@ split_penalty_logical_operator=300
|
|||
|
||||
# Use the Tab character for indentation.
|
||||
use_tabs=False
|
||||
|
||||
|
|
|
@ -38,10 +38,12 @@ matrix:
|
|||
- export PATH="$HOME/miniconda/bin:$PATH"
|
||||
- cd doc
|
||||
- pip install -q -r requirements-doc.txt
|
||||
- pip install yapf
|
||||
- sphinx-build -W -b html -d _build/doctrees source _build/html
|
||||
- cd ..
|
||||
# Run Python linting.
|
||||
- flake8 --exclude=python/ray/core/src/common/flatbuffers_ep-prefix/,python/ray/core/generated/,src/common/format/,doc/source/conf.py,python/ray/cloudpickle/
|
||||
- .travis/yapf.sh
|
||||
|
||||
- os: linux
|
||||
dist: trusty
|
||||
|
|
27
.travis/yapf.sh
Executable file
27
.travis/yapf.sh
Executable file
|
@ -0,0 +1,27 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
# Cause the script to exit if a single command fails
|
||||
set -e
|
||||
|
||||
ROOT_DIR=$(cd "$(dirname "${BASH_SOURCE:-$0}")"; pwd)
|
||||
|
||||
pushd $ROOT_DIR/../test
|
||||
find . -name '*.py' -type f -exec yapf --style=pep8 -i -r {} \;
|
||||
popd
|
||||
|
||||
pushd $ROOT_DIR/../python
|
||||
find . -name '*.py' -type f -not -path './ray/dataframe/*' -not -path './ray/rllib/*' -not -path './ray/cloudpickle/*' -exec yapf --style=pep8 -i -r {} \;
|
||||
popd
|
||||
|
||||
CHANGED_FILES=(`git diff --name-only`)
|
||||
if [ "$CHANGED_FILES" ]; then
|
||||
echo 'Reformatted staged files. Please review and stage the changes.'
|
||||
echo
|
||||
echo 'Files updated:'
|
||||
for file in ${CHANGED_FILES[@]}; do
|
||||
echo " $file"
|
||||
done
|
||||
exit 1
|
||||
else
|
||||
exit 0
|
||||
fi
|
|
@ -12,8 +12,8 @@ if "pyarrow" in sys.modules:
|
|||
|
||||
# Add the directory containing pyarrow to the Python path so that we find the
|
||||
# pyarrow version packaged with ray and not a pre-existing pyarrow.
|
||||
pyarrow_path = os.path.join(os.path.abspath(os.path.dirname(__file__)),
|
||||
"pyarrow_files")
|
||||
pyarrow_path = os.path.join(
|
||||
os.path.abspath(os.path.dirname(__file__)), "pyarrow_files")
|
||||
sys.path.insert(0, pyarrow_path)
|
||||
|
||||
# See https://github.com/ray-project/ray/issues/131.
|
||||
|
@ -27,29 +27,29 @@ If you are using Anaconda, try fixing this problem by running:
|
|||
try:
|
||||
import pyarrow # noqa: F401
|
||||
except ImportError as e:
|
||||
if ((hasattr(e, "msg") and isinstance(e.msg, str) and
|
||||
("libstdc++" in e.msg or "CXX" in e.msg))):
|
||||
if ((hasattr(e, "msg") and isinstance(e.msg, str)
|
||||
and ("libstdc++" in e.msg or "CXX" in e.msg))):
|
||||
# This code path should be taken with Python 3.
|
||||
e.msg += helpful_message
|
||||
elif (hasattr(e, "message") and isinstance(e.message, str) and
|
||||
("libstdc++" in e.message or "CXX" in e.message)):
|
||||
elif (hasattr(e, "message") and isinstance(e.message, str)
|
||||
and ("libstdc++" in e.message or "CXX" in e.message)):
|
||||
# This code path should be taken with Python 2.
|
||||
condition = (hasattr(e, "args") and isinstance(e.args, tuple) and
|
||||
len(e.args) == 1 and isinstance(e.args[0], str))
|
||||
condition = (hasattr(e, "args") and isinstance(e.args, tuple)
|
||||
and len(e.args) == 1 and isinstance(e.args[0], str))
|
||||
if condition:
|
||||
e.args = (e.args[0] + helpful_message,)
|
||||
e.args = (e.args[0] + helpful_message, )
|
||||
else:
|
||||
if not hasattr(e, "args"):
|
||||
e.args = ()
|
||||
elif not isinstance(e.args, tuple):
|
||||
e.args = (e.args,)
|
||||
e.args += (helpful_message,)
|
||||
e.args = (e.args, )
|
||||
e.args += (helpful_message, )
|
||||
raise
|
||||
|
||||
from ray.local_scheduler import _config # noqa: E402
|
||||
from ray.worker import (error_info, init, connect, disconnect,
|
||||
get, put, wait, remote, log_event, log_span,
|
||||
flush_log, get_gpu_ids, get_webui_url,
|
||||
from ray.worker import (error_info, init, connect, disconnect, get, put, wait,
|
||||
remote, log_event, log_span, flush_log, get_gpu_ids,
|
||||
get_webui_url,
|
||||
register_custom_serializer) # noqa: E402
|
||||
from ray.worker import (SCRIPT_MODE, WORKER_MODE, PYTHON_MODE,
|
||||
SILENT_MODE) # noqa: E402
|
||||
|
@ -63,11 +63,13 @@ from ray.actor import method # noqa: E402
|
|||
# Fix this.
|
||||
__version__ = "0.4.0"
|
||||
|
||||
__all__ = ["error_info", "init", "connect", "disconnect", "get", "put", "wait",
|
||||
__all__ = [
|
||||
"error_info", "init", "connect", "disconnect", "get", "put", "wait",
|
||||
"remote", "log_event", "log_span", "flush_log", "actor", "method",
|
||||
"get_gpu_ids", "get_webui_url", "register_custom_serializer",
|
||||
"SCRIPT_MODE", "WORKER_MODE", "PYTHON_MODE", "SILENT_MODE",
|
||||
"global_state", "_config", "__version__"]
|
||||
"SCRIPT_MODE", "WORKER_MODE", "PYTHON_MODE", "SILENT_MODE", "global_state",
|
||||
"_config", "__version__"
|
||||
]
|
||||
|
||||
import ctypes # noqa: E402
|
||||
# Windows only
|
||||
|
|
|
@ -121,16 +121,17 @@ def save_and_log_checkpoint(worker, actor):
|
|||
try:
|
||||
actor.__ray_checkpoint__()
|
||||
except Exception:
|
||||
traceback_str = ray.utils.format_error_message(
|
||||
traceback.format_exc())
|
||||
traceback_str = ray.utils.format_error_message(traceback.format_exc())
|
||||
# Log the error message.
|
||||
ray.utils.push_error_to_driver(
|
||||
worker.redis_client,
|
||||
"checkpoint",
|
||||
traceback_str,
|
||||
driver_id=worker.task_driver_id.id(),
|
||||
data={"actor_class": actor.__class__.__name__,
|
||||
"function_name": actor.__ray_checkpoint__.__name__})
|
||||
data={
|
||||
"actor_class": actor.__class__.__name__,
|
||||
"function_name": actor.__ray_checkpoint__.__name__
|
||||
})
|
||||
|
||||
|
||||
def restore_and_log_checkpoint(worker, actor):
|
||||
|
@ -144,8 +145,7 @@ def restore_and_log_checkpoint(worker, actor):
|
|||
try:
|
||||
checkpoint_resumed = actor.__ray_checkpoint_restore__()
|
||||
except Exception:
|
||||
traceback_str = ray.utils.format_error_message(
|
||||
traceback.format_exc())
|
||||
traceback_str = ray.utils.format_error_message(traceback.format_exc())
|
||||
# Log the error message.
|
||||
ray.utils.push_error_to_driver(
|
||||
worker.redis_client,
|
||||
|
@ -154,8 +154,8 @@ def restore_and_log_checkpoint(worker, actor):
|
|||
driver_id=worker.task_driver_id.id(),
|
||||
data={
|
||||
"actor_class": actor.__class__.__name__,
|
||||
"function_name":
|
||||
actor.__ray_checkpoint_restore__.__name__})
|
||||
"function_name": actor.__ray_checkpoint_restore__.__name__
|
||||
})
|
||||
return checkpoint_resumed
|
||||
|
||||
|
||||
|
@ -197,15 +197,15 @@ def make_actor_method_executor(worker, method_name, method, actor_imported):
|
|||
return
|
||||
|
||||
# Determine whether we should checkpoint the actor.
|
||||
checkpointing_on = (actor_imported and
|
||||
worker.actor_checkpoint_interval > 0)
|
||||
checkpointing_on = (actor_imported
|
||||
and worker.actor_checkpoint_interval > 0)
|
||||
# We should checkpoint the actor if user checkpointing is on, we've
|
||||
# executed checkpoint_interval tasks since the last checkpoint, and the
|
||||
# method we're about to execute is not a checkpoint.
|
||||
save_checkpoint = (checkpointing_on and
|
||||
(worker.actor_task_counter %
|
||||
worker.actor_checkpoint_interval == 0 and
|
||||
method_name != "__ray_checkpoint__"))
|
||||
save_checkpoint = (
|
||||
checkpointing_on and
|
||||
(worker.actor_task_counter % worker.actor_checkpoint_interval == 0
|
||||
and method_name != "__ray_checkpoint__"))
|
||||
|
||||
# Execute the assigned method and save a checkpoint if necessary.
|
||||
try:
|
||||
|
@ -238,14 +238,14 @@ def fetch_and_register_actor(actor_class_key, resources, worker):
|
|||
worker: The worker to use.
|
||||
"""
|
||||
actor_id_str = worker.actor_id
|
||||
(driver_id, class_id, class_name,
|
||||
module, pickled_class, checkpoint_interval,
|
||||
actor_method_names,
|
||||
(driver_id, class_id, class_name, module, pickled_class,
|
||||
checkpoint_interval, actor_method_names,
|
||||
actor_method_num_return_vals) = worker.redis_client.hmget(
|
||||
actor_class_key, ["driver_id", "class_id", "class_name", "module",
|
||||
"class", "checkpoint_interval",
|
||||
"actor_method_names",
|
||||
"actor_method_num_return_vals"])
|
||||
actor_class_key, [
|
||||
"driver_id", "class_id", "class_name", "module", "class",
|
||||
"checkpoint_interval", "actor_method_names",
|
||||
"actor_method_num_return_vals"
|
||||
])
|
||||
|
||||
actor_name = class_name.decode("ascii")
|
||||
module = module.decode("ascii")
|
||||
|
@ -259,12 +259,14 @@ def fetch_and_register_actor(actor_class_key, resources, worker):
|
|||
# error messages and to prevent the driver from hanging).
|
||||
class TemporaryActor(object):
|
||||
pass
|
||||
|
||||
worker.actors[actor_id_str] = TemporaryActor()
|
||||
worker.actor_checkpoint_interval = checkpoint_interval
|
||||
|
||||
def temporary_actor_method(*xs):
|
||||
raise Exception("The actor with name {} failed to be imported, and so "
|
||||
"cannot execute this method".format(actor_name))
|
||||
|
||||
# Register the actor method signatures.
|
||||
register_actor_signatures(worker, driver_id, class_id, class_name,
|
||||
actor_method_names, actor_method_num_return_vals)
|
||||
|
@ -272,7 +274,8 @@ def fetch_and_register_actor(actor_class_key, resources, worker):
|
|||
for actor_method_name in actor_method_names:
|
||||
function_id = compute_actor_method_function_id(class_name,
|
||||
actor_method_name).id()
|
||||
temporary_executor = make_actor_method_executor(worker,
|
||||
temporary_executor = make_actor_method_executor(
|
||||
worker,
|
||||
actor_method_name,
|
||||
temporary_actor_method,
|
||||
actor_imported=False)
|
||||
|
@ -288,8 +291,11 @@ def fetch_and_register_actor(actor_class_key, resources, worker):
|
|||
# traceback and notify the scheduler of the failure.
|
||||
traceback_str = ray.utils.format_error_message(traceback.format_exc())
|
||||
# Log the error message.
|
||||
push_error_to_driver(worker.redis_client, "register_actor_signatures",
|
||||
traceback_str, driver_id,
|
||||
push_error_to_driver(
|
||||
worker.redis_client,
|
||||
"register_actor_signatures",
|
||||
traceback_str,
|
||||
driver_id,
|
||||
data={"actor_id": actor_id_str})
|
||||
# TODO(rkn): In the future, it might make sense to have the worker exit
|
||||
# here. However, currently that would lead to hanging if someone calls
|
||||
|
@ -298,16 +304,17 @@ def fetch_and_register_actor(actor_class_key, resources, worker):
|
|||
# TODO(pcm): Why is the below line necessary?
|
||||
unpickled_class.__module__ = module
|
||||
worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class)
|
||||
actor_methods = inspect.getmembers(
|
||||
unpickled_class, predicate=(lambda x: (inspect.isfunction(x) or
|
||||
inspect.ismethod(x) or
|
||||
is_cython(x))))
|
||||
|
||||
def pred(x):
|
||||
return (inspect.isfunction(x) or inspect.ismethod(x)
|
||||
or is_cython(x))
|
||||
|
||||
actor_methods = inspect.getmembers(unpickled_class, predicate=pred)
|
||||
for actor_method_name, actor_method in actor_methods:
|
||||
function_id = compute_actor_method_function_id(
|
||||
class_name, actor_method_name).id()
|
||||
executor = make_actor_method_executor(worker, actor_method_name,
|
||||
actor_method,
|
||||
actor_imported=True)
|
||||
executor = make_actor_method_executor(
|
||||
worker, actor_method_name, actor_method, actor_imported=True)
|
||||
worker.functions[driver_id][function_id] = (actor_method_name,
|
||||
executor)
|
||||
# We do not set worker.function_properties[driver_id][function_id]
|
||||
|
@ -315,7 +322,10 @@ def fetch_and_register_actor(actor_class_key, resources, worker):
|
|||
# for the actor.
|
||||
|
||||
|
||||
def register_actor_signatures(worker, driver_id, class_id, class_name,
|
||||
def register_actor_signatures(worker,
|
||||
driver_id,
|
||||
class_id,
|
||||
class_name,
|
||||
actor_method_names,
|
||||
actor_method_num_return_vals,
|
||||
actor_creation_resources=None,
|
||||
|
@ -346,7 +356,8 @@ def register_actor_signatures(worker, driver_id, class_id, class_name,
|
|||
# The extra return value is an actor dummy object.
|
||||
# In the cases where actor_method_cpus is None, that value should
|
||||
# never be used.
|
||||
FunctionProperties(num_return_vals=num_return_vals + 1,
|
||||
FunctionProperties(
|
||||
num_return_vals=num_return_vals + 1,
|
||||
resources={"CPU": actor_method_cpus},
|
||||
max_calls=0))
|
||||
|
||||
|
@ -355,7 +366,8 @@ def register_actor_signatures(worker, driver_id, class_id, class_name,
|
|||
function_id = compute_actor_creation_function_id(class_id)
|
||||
worker.function_properties[driver_id][function_id.id()] = (
|
||||
# The extra return value is an actor dummy object.
|
||||
FunctionProperties(num_return_vals=0 + 1,
|
||||
FunctionProperties(
|
||||
num_return_vals=0 + 1,
|
||||
resources=actor_creation_resources,
|
||||
max_calls=0))
|
||||
|
||||
|
@ -380,8 +392,8 @@ def publish_actor_class_to_key(key, actor_class_info, worker):
|
|||
|
||||
|
||||
def export_actor_class(class_id, Class, actor_method_names,
|
||||
actor_method_num_return_vals,
|
||||
checkpoint_interval, worker):
|
||||
actor_method_num_return_vals, checkpoint_interval,
|
||||
worker):
|
||||
key = b"ActorClass:" + class_id
|
||||
actor_class_info = {
|
||||
"class_name": Class.__name__,
|
||||
|
@ -389,8 +401,9 @@ def export_actor_class(class_id, Class, actor_method_names,
|
|||
"class": pickle.dumps(Class),
|
||||
"checkpoint_interval": checkpoint_interval,
|
||||
"actor_method_names": json.dumps(list(actor_method_names)),
|
||||
"actor_method_num_return_vals": json.dumps(
|
||||
actor_method_num_return_vals)}
|
||||
"actor_method_num_return_vals":
|
||||
json.dumps(actor_method_num_return_vals)
|
||||
}
|
||||
|
||||
if worker.mode is None:
|
||||
# This means that 'ray.init()' has not been called yet and so we must
|
||||
|
@ -433,7 +446,11 @@ def export_actor(actor_id, class_id, class_name, actor_method_names,
|
|||
|
||||
driver_id = worker.task_driver_id.id()
|
||||
register_actor_signatures(
|
||||
worker, driver_id, class_id, class_name, actor_method_names,
|
||||
worker,
|
||||
driver_id,
|
||||
class_id,
|
||||
class_name,
|
||||
actor_method_names,
|
||||
actor_method_num_return_vals,
|
||||
actor_creation_resources=actor_creation_resources,
|
||||
actor_method_cpus=actor_method_cpus)
|
||||
|
@ -466,12 +483,14 @@ class ActorMethod(object):
|
|||
def __call__(self, *args, **kwargs):
|
||||
raise Exception("Actor methods cannot be called directly. Instead "
|
||||
"of running 'object.{}()', try "
|
||||
"'object.{}.remote()'."
|
||||
.format(self._method_name, self._method_name))
|
||||
"'object.{}.remote()'.".format(self._method_name,
|
||||
self._method_name))
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
return self._actor._actor_method_call(
|
||||
self._method_name, args=args, kwargs=kwargs,
|
||||
self._method_name,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
dependency=self._actor._ray_actor_cursor)
|
||||
|
||||
|
||||
|
@ -481,12 +500,13 @@ class ActorHandleWrapper(object):
|
|||
This is essentially just a dictionary, but it is used so that the recipient
|
||||
can tell that an argument is an ActorHandle.
|
||||
"""
|
||||
|
||||
def __init__(self, actor_id, class_id, actor_handle_id, actor_cursor,
|
||||
actor_counter, actor_method_names,
|
||||
actor_method_num_return_vals, method_signatures,
|
||||
checkpoint_interval, class_name,
|
||||
actor_creation_dummy_object_id,
|
||||
actor_creation_resources, actor_method_cpus):
|
||||
actor_creation_dummy_object_id, actor_creation_resources,
|
||||
actor_method_cpus):
|
||||
# TODO(rkn): Some of these fields are probably not necessary. We should
|
||||
# strip out the unnecessary fields to keep actor handles lightweight.
|
||||
self.actor_id = actor_id
|
||||
|
@ -545,27 +565,20 @@ def unwrap_actor_handle(worker, wrapper):
|
|||
The unwrapped ActorHandle instance.
|
||||
"""
|
||||
driver_id = worker.task_driver_id.id()
|
||||
register_actor_signatures(worker, driver_id, wrapper.class_id,
|
||||
wrapper.class_name, wrapper.actor_method_names,
|
||||
wrapper.actor_method_num_return_vals,
|
||||
wrapper.actor_creation_resources,
|
||||
wrapper.actor_method_cpus)
|
||||
register_actor_signatures(
|
||||
worker, driver_id, wrapper.class_id, wrapper.class_name,
|
||||
wrapper.actor_method_names, wrapper.actor_method_num_return_vals,
|
||||
wrapper.actor_creation_resources, wrapper.actor_method_cpus)
|
||||
|
||||
actor_handle_class = make_actor_handle_class(wrapper.class_name)
|
||||
actor_object = actor_handle_class.__new__(actor_handle_class)
|
||||
actor_object._manual_init(
|
||||
wrapper.actor_id,
|
||||
wrapper.class_id,
|
||||
wrapper.actor_handle_id,
|
||||
wrapper.actor_cursor,
|
||||
wrapper.actor_counter,
|
||||
wrapper.actor_method_names,
|
||||
wrapper.actor_method_num_return_vals,
|
||||
wrapper.method_signatures,
|
||||
wrapper.checkpoint_interval,
|
||||
wrapper.actor_id, wrapper.class_id, wrapper.actor_handle_id,
|
||||
wrapper.actor_cursor, wrapper.actor_counter,
|
||||
wrapper.actor_method_names, wrapper.actor_method_num_return_vals,
|
||||
wrapper.method_signatures, wrapper.checkpoint_interval,
|
||||
wrapper.actor_creation_dummy_object_id,
|
||||
wrapper.actor_creation_resources,
|
||||
wrapper.actor_method_cpus)
|
||||
wrapper.actor_creation_resources, wrapper.actor_method_cpus)
|
||||
return actor_object
|
||||
|
||||
|
||||
|
@ -612,7 +625,10 @@ def make_actor_handle_class(class_name):
|
|||
self._ray_actor_creation_resources = actor_creation_resources
|
||||
self._ray_actor_method_cpus = actor_method_cpus
|
||||
|
||||
def _actor_method_call(self, method_name, args=None, kwargs=None,
|
||||
def _actor_method_call(self,
|
||||
method_name,
|
||||
args=None,
|
||||
kwargs=None,
|
||||
dependency=None):
|
||||
"""Method execution stub for an actor handle.
|
||||
|
||||
|
@ -663,7 +679,9 @@ def make_actor_handle_class(class_name):
|
|||
function_id = compute_actor_method_function_id(
|
||||
self._ray_class_name, method_name)
|
||||
object_ids = ray.worker.global_worker.submit_task(
|
||||
function_id, args, actor_id=self._ray_actor_id,
|
||||
function_id,
|
||||
args,
|
||||
actor_id=self._ray_actor_id,
|
||||
actor_handle_id=self._ray_actor_handle_id,
|
||||
actor_counter=self._ray_actor_counter,
|
||||
is_actor_checkpoint_method=is_actor_checkpoint_method,
|
||||
|
@ -722,8 +740,8 @@ def make_actor_handle_class(class_name):
|
|||
self._ray_actor_handle_id.id() == ray.worker.NIL_ACTOR_ID):
|
||||
# TODO(rkn): Should we be passing in the actor cursor as a
|
||||
# dependency here?
|
||||
self._actor_method_call("__ray_terminate__",
|
||||
args=[self._ray_actor_id.id()])
|
||||
self._actor_method_call(
|
||||
"__ray_terminate__", args=[self._ray_actor_id.id()])
|
||||
|
||||
return ActorHandle
|
||||
|
||||
|
@ -735,7 +753,6 @@ def actor_handle_from_class(Class, class_id, actor_creation_resources,
|
|||
exported = []
|
||||
|
||||
class ActorHandle(actor_handle_class):
|
||||
|
||||
@classmethod
|
||||
def remote(cls, *args, **kwargs):
|
||||
if ray.worker.global_worker.mode is None:
|
||||
|
@ -754,11 +771,13 @@ def actor_handle_from_class(Class, class_id, actor_creation_resources,
|
|||
actor_cursor = None
|
||||
# The number of actor method invocations that we've called so far.
|
||||
actor_counter = 0
|
||||
|
||||
# Get the actor methods of the given class.
|
||||
actor_methods = inspect.getmembers(
|
||||
Class, predicate=(lambda x: (inspect.isfunction(x) or
|
||||
inspect.ismethod(x) or
|
||||
is_cython(x))))
|
||||
def pred(x):
|
||||
return (inspect.isfunction(x) or inspect.ismethod(x)
|
||||
or is_cython(x))
|
||||
|
||||
actor_methods = inspect.getmembers(Class, predicate=pred)
|
||||
# Extract the signatures of each of the methods. This will be used
|
||||
# to catch some errors if the methods are called with inappropriate
|
||||
# arguments.
|
||||
|
@ -773,8 +792,9 @@ def actor_handle_from_class(Class, class_id, actor_creation_resources,
|
|||
method_signatures[k] = signature.extract_signature(
|
||||
v, ignore_first=True)
|
||||
|
||||
actor_method_names = [method_name for method_name, _ in
|
||||
actor_methods]
|
||||
actor_method_names = [
|
||||
method_name for method_name, _ in actor_methods
|
||||
]
|
||||
actor_method_num_return_vals = []
|
||||
for _, method in actor_methods:
|
||||
if hasattr(method, "__ray_num_return_vals__"):
|
||||
|
@ -796,28 +816,27 @@ def actor_handle_from_class(Class, class_id, actor_creation_resources,
|
|||
checkpoint_interval,
|
||||
ray.worker.global_worker)
|
||||
exported.append(0)
|
||||
actor_cursor = export_actor(actor_id, class_id, class_name,
|
||||
actor_method_names,
|
||||
actor_method_num_return_vals,
|
||||
actor_creation_resources,
|
||||
actor_method_cpus,
|
||||
ray.worker.global_worker)
|
||||
actor_cursor = export_actor(
|
||||
actor_id, class_id, class_name, actor_method_names,
|
||||
actor_method_num_return_vals, actor_creation_resources,
|
||||
actor_method_cpus, ray.worker.global_worker)
|
||||
# Increment the actor counter to account for the creation task.
|
||||
actor_counter += 1
|
||||
|
||||
# Instantiate the actor handle.
|
||||
actor_object = cls.__new__(cls)
|
||||
actor_object._manual_init(actor_id, class_id, actor_handle_id,
|
||||
actor_cursor, actor_counter,
|
||||
actor_method_names,
|
||||
actor_method_num_return_vals,
|
||||
method_signatures, checkpoint_interval,
|
||||
actor_cursor, actor_creation_resources,
|
||||
actor_object._manual_init(
|
||||
actor_id, class_id, actor_handle_id, actor_cursor,
|
||||
actor_counter, actor_method_names,
|
||||
actor_method_num_return_vals, method_signatures,
|
||||
checkpoint_interval, actor_cursor, actor_creation_resources,
|
||||
actor_method_cpus)
|
||||
|
||||
# Call __init__ as a remote function.
|
||||
if "__init__" in actor_object._ray_actor_method_names:
|
||||
actor_object._actor_method_call("__init__", args=args,
|
||||
actor_object._actor_method_call(
|
||||
"__init__",
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
dependency=actor_cursor)
|
||||
else:
|
||||
|
|
|
@ -51,25 +51,32 @@ CLUSTER_CONFIG_SCHEMA = {
|
|||
"idle_timeout_minutes": (int, OPTIONAL),
|
||||
|
||||
# Cloud-provider specific configuration.
|
||||
"provider": ({
|
||||
"provider": (
|
||||
{
|
||||
"type": (str, REQUIRED), # e.g. aws
|
||||
"region": (str, OPTIONAL), # e.g. us-east-1
|
||||
"availability_zone": (str, OPTIONAL), # e.g. us-east-1a
|
||||
"module": (str, OPTIONAL), # module, if using external node provider
|
||||
}, REQUIRED),
|
||||
"module": (str,
|
||||
OPTIONAL), # module, if using external node provider
|
||||
},
|
||||
REQUIRED),
|
||||
|
||||
# How Ray will authenticate with newly launched nodes.
|
||||
"auth": ({
|
||||
"auth": (
|
||||
{
|
||||
"ssh_user": (str, REQUIRED), # e.g. ubuntu
|
||||
"ssh_private_key": (str, OPTIONAL),
|
||||
}, REQUIRED),
|
||||
},
|
||||
REQUIRED),
|
||||
|
||||
# Docker configuration. If this is specified, all setup and start commands
|
||||
# will be executed in the container.
|
||||
"docker": ({
|
||||
"docker": (
|
||||
{
|
||||
"image": (str, OPTIONAL), # e.g. tensorflow/tensorflow:1.5.0-py3
|
||||
"container_name": (str, OPTIONAL), # e.g., ray_docker
|
||||
}, OPTIONAL),
|
||||
},
|
||||
OPTIONAL),
|
||||
|
||||
# Provider-specific config for the head node, e.g. instance type.
|
||||
"head_node": (dict, OPTIONAL),
|
||||
|
@ -137,9 +144,9 @@ class LoadMetrics(object):
|
|||
for unwanted_key in unwanted:
|
||||
del mapping[unwanted_key]
|
||||
if unwanted:
|
||||
print(
|
||||
"Removed {} stale ip mappings: {} not in {}".format(
|
||||
print("Removed {} stale ip mappings: {} not in {}".format(
|
||||
len(unwanted), unwanted, active_ips))
|
||||
|
||||
prune(self.last_used_time_by_ip)
|
||||
prune(self.static_resources_by_ip)
|
||||
prune(self.dynamic_resources_by_ip)
|
||||
|
@ -148,10 +155,8 @@ class LoadMetrics(object):
|
|||
return self._info()["NumNodesUsed"]
|
||||
|
||||
def debug_string(self):
|
||||
return " - {}".format(
|
||||
"\n - ".join(
|
||||
["{}: {}".format(k, v)
|
||||
for k, v in sorted(self._info().items())]))
|
||||
return " - {}".format("\n - ".join(
|
||||
["{}: {}".format(k, v) for k, v in sorted(self._info().items())]))
|
||||
|
||||
def _info(self):
|
||||
nodes_used = 0.0
|
||||
|
@ -176,14 +181,19 @@ class LoadMetrics(object):
|
|||
nodes_used += max_frac
|
||||
idle_times = [now - t for t in self.last_used_time_by_ip.values()]
|
||||
return {
|
||||
"ResourceUsage": ", ".join([
|
||||
"ResourceUsage":
|
||||
", ".join([
|
||||
"{}/{} {}".format(
|
||||
round(resources_used[rid], 2),
|
||||
round(resources_total[rid], 2), rid)
|
||||
for rid in sorted(resources_used)]),
|
||||
"NumNodesConnected": len(self.static_resources_by_ip),
|
||||
"NumNodesUsed": round(nodes_used, 2),
|
||||
"NodeIdleSeconds": "Min={} Mean={} Max={}".format(
|
||||
for rid in sorted(resources_used)
|
||||
]),
|
||||
"NumNodesConnected":
|
||||
len(self.static_resources_by_ip),
|
||||
"NumNodesUsed":
|
||||
round(nodes_used, 2),
|
||||
"NodeIdleSeconds":
|
||||
"Min={} Mean={} Max={}".format(
|
||||
int(np.min(idle_times)) if idle_times else -1,
|
||||
int(np.mean(idle_times)) if idle_times else -1,
|
||||
int(np.max(idle_times)) if idle_times else -1),
|
||||
|
@ -208,18 +218,20 @@ class StandardAutoscaler(object):
|
|||
until the target cluster size is met).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, config_path, load_metrics,
|
||||
def __init__(self,
|
||||
config_path,
|
||||
load_metrics,
|
||||
max_concurrent_launches=AUTOSCALER_MAX_CONCURRENT_LAUNCHES,
|
||||
max_failures=AUTOSCALER_MAX_NUM_FAILURES,
|
||||
process_runner=subprocess, verbose_updates=False,
|
||||
process_runner=subprocess,
|
||||
verbose_updates=False,
|
||||
node_updater_cls=NodeUpdaterProcess,
|
||||
update_interval_s=AUTOSCALER_UPDATE_INTERVAL_S):
|
||||
self.config_path = config_path
|
||||
self.reload_config(errors_fatal=True)
|
||||
self.load_metrics = load_metrics
|
||||
self.provider = get_node_provider(
|
||||
self.config["provider"], self.config["cluster_name"])
|
||||
self.provider = get_node_provider(self.config["provider"],
|
||||
self.config["cluster_name"])
|
||||
|
||||
self.max_failures = max_failures
|
||||
self.max_concurrent_launches = max_concurrent_launches
|
||||
|
@ -245,8 +257,7 @@ class StandardAutoscaler(object):
|
|||
self.reload_config(errors_fatal=False)
|
||||
self._update()
|
||||
except Exception as e:
|
||||
print(
|
||||
"StandardAutoscaler: Error during autoscaling: {}",
|
||||
print("StandardAutoscaler: Error during autoscaling: {}",
|
||||
traceback.format_exc())
|
||||
self.num_failures += 1
|
||||
if self.num_failures > self.max_failures:
|
||||
|
@ -274,14 +285,12 @@ class StandardAutoscaler(object):
|
|||
if node_ip in last_used and last_used[node_ip] < horizon and \
|
||||
len(nodes) - num_terminated > self.config["min_workers"]:
|
||||
num_terminated += 1
|
||||
print(
|
||||
"StandardAutoscaler: Terminating idle node: "
|
||||
print("StandardAutoscaler: Terminating idle node: "
|
||||
"{}".format(node_id))
|
||||
self.provider.terminate_node(node_id)
|
||||
elif not self.launch_config_ok(node_id):
|
||||
num_terminated += 1
|
||||
print(
|
||||
"StandardAutoscaler: Terminating outdated node: "
|
||||
print("StandardAutoscaler: Terminating outdated node: "
|
||||
"{}".format(node_id))
|
||||
self.provider.terminate_node(node_id)
|
||||
if num_terminated > 0:
|
||||
|
@ -292,8 +301,7 @@ class StandardAutoscaler(object):
|
|||
num_terminated = 0
|
||||
while len(nodes) > self.config["max_workers"]:
|
||||
num_terminated += 1
|
||||
print(
|
||||
"StandardAutoscaler: Terminating unneeded node: "
|
||||
print("StandardAutoscaler: Terminating unneeded node: "
|
||||
"{}".format(nodes[-1]))
|
||||
self.provider.terminate_node(nodes[-1])
|
||||
nodes = nodes[:-1]
|
||||
|
@ -339,13 +347,13 @@ class StandardAutoscaler(object):
|
|||
with open(self.config_path) as f:
|
||||
new_config = yaml.load(f.read())
|
||||
validate_config(new_config)
|
||||
new_launch_hash = hash_launch_conf(
|
||||
new_config["worker_nodes"], new_config["auth"])
|
||||
new_runtime_hash = hash_runtime_conf(
|
||||
new_config["file_mounts"],
|
||||
[new_config["setup_commands"],
|
||||
new_launch_hash = hash_launch_conf(new_config["worker_nodes"],
|
||||
new_config["auth"])
|
||||
new_runtime_hash = hash_runtime_conf(new_config["file_mounts"], [
|
||||
new_config["setup_commands"],
|
||||
new_config["worker_setup_commands"],
|
||||
new_config["worker_start_ray_commands"]])
|
||||
new_config["worker_start_ray_commands"]
|
||||
])
|
||||
self.config = new_config
|
||||
self.launch_hash = new_launch_hash
|
||||
self.runtime_hash = new_runtime_hash
|
||||
|
@ -353,16 +361,14 @@ class StandardAutoscaler(object):
|
|||
if errors_fatal:
|
||||
raise e
|
||||
else:
|
||||
print(
|
||||
"StandardAutoscaler: Error parsing config: {}",
|
||||
print("StandardAutoscaler: Error parsing config: {}",
|
||||
traceback.format_exc())
|
||||
|
||||
def target_num_workers(self):
|
||||
target_frac = self.config["target_utilization_fraction"]
|
||||
cur_used = self.load_metrics.approx_workers_used()
|
||||
ideal_num_workers = int(np.ceil(cur_used / float(target_frac)))
|
||||
return min(
|
||||
self.config["max_workers"],
|
||||
return min(self.config["max_workers"],
|
||||
max(self.config["min_workers"], ideal_num_workers))
|
||||
|
||||
def launch_config_ok(self, node_id):
|
||||
|
@ -393,8 +399,7 @@ class StandardAutoscaler(object):
|
|||
node_id,
|
||||
self.config["provider"],
|
||||
self.config["auth"],
|
||||
self.config["cluster_name"],
|
||||
{},
|
||||
self.config["cluster_name"], {},
|
||||
with_head_node_ip(self.config["worker_start_ray_commands"]),
|
||||
self.runtime_hash,
|
||||
redirect_output=not self.verbose_updates,
|
||||
|
@ -409,12 +414,10 @@ class StandardAutoscaler(object):
|
|||
return
|
||||
if self.config.get("no_restart", False) and \
|
||||
self.num_successful_updates.get(node_id, 0) > 0:
|
||||
init_commands = (
|
||||
self.config["setup_commands"] +
|
||||
init_commands = (self.config["setup_commands"] +
|
||||
self.config["worker_setup_commands"])
|
||||
else:
|
||||
init_commands = (
|
||||
self.config["setup_commands"] +
|
||||
init_commands = (self.config["setup_commands"] +
|
||||
self.config["worker_setup_commands"] +
|
||||
self.config["worker_start_ray_commands"])
|
||||
updater = self.node_updater_cls(
|
||||
|
@ -445,14 +448,12 @@ class StandardAutoscaler(object):
|
|||
print("StandardAutoscaler: Launching {} new nodes".format(count))
|
||||
num_before = len(self.workers())
|
||||
self.provider.create_node(
|
||||
self.config["worker_nodes"],
|
||||
{
|
||||
self.config["worker_nodes"], {
|
||||
TAG_NAME: "ray-{}-worker".format(self.config["cluster_name"]),
|
||||
TAG_RAY_NODE_TYPE: "Worker",
|
||||
TAG_RAY_NODE_STATUS: "Uninitialized",
|
||||
TAG_RAY_LAUNCH_CONFIG: self.launch_hash,
|
||||
},
|
||||
count)
|
||||
}, count)
|
||||
# TODO(ekl) be less conservative in this check
|
||||
assert len(self.workers()) > num_before, \
|
||||
"Num nodes failed to increase after creating a new node"
|
||||
|
@ -472,8 +473,8 @@ class StandardAutoscaler(object):
|
|||
suffix += " ({} failed to update)".format(
|
||||
len(self.num_failed_updates))
|
||||
return "StandardAutoscaler [{}]: {}/{} target nodes{}\n{}".format(
|
||||
datetime.now(), len(nodes), self.target_num_workers(),
|
||||
suffix, self.load_metrics.debug_string())
|
||||
datetime.now(), len(nodes), self.target_num_workers(), suffix,
|
||||
self.load_metrics.debug_string())
|
||||
|
||||
|
||||
def typename(v):
|
||||
|
@ -507,8 +508,7 @@ def check_extraneous(config, schema):
|
|||
raise ValueError("Config {} is not a dictionary".format(config))
|
||||
for k in config:
|
||||
if k not in schema:
|
||||
raise ValueError(
|
||||
"Unexpected config key `{}` not in {}".format(
|
||||
raise ValueError("Unexpected config key `{}` not in {}".format(
|
||||
k, list(schema.keys())))
|
||||
v, kreq = schema[k]
|
||||
if v is None:
|
||||
|
@ -517,7 +517,8 @@ def check_extraneous(config, schema):
|
|||
if not isinstance(config[k], v):
|
||||
raise ValueError(
|
||||
"Config key `{}` has wrong type {}, expected {}".format(
|
||||
k, type(config[k]).__name__, v.__name__))
|
||||
k,
|
||||
type(config[k]).__name__, v.__name__))
|
||||
else:
|
||||
check_extraneous(config[k], v)
|
||||
|
||||
|
|
|
@ -25,11 +25,9 @@ assert StrictVersion(boto3.__version__) >= StrictVersion("1.4.8"), \
|
|||
def key_pair(i, region):
|
||||
"""Returns the ith default (aws_key_pair_name, key_pair_path)."""
|
||||
if i == 0:
|
||||
return (
|
||||
"{}_{}".format(RAY, region),
|
||||
return ("{}_{}".format(RAY, region),
|
||||
os.path.expanduser("~/.ssh/{}_{}.pem".format(RAY, region)))
|
||||
return (
|
||||
"{}_{}_{}".format(RAY, i, region),
|
||||
return ("{}_{}_{}".format(RAY, i, region),
|
||||
os.path.expanduser("~/.ssh/{}_{}_{}.pem".format(RAY, i, region)))
|
||||
|
||||
|
||||
|
@ -83,7 +81,9 @@ def _configure_iam_role(config):
|
|||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Principal": {"Service": "ec2.amazonaws.com"},
|
||||
"Principal": {
|
||||
"Service": "ec2.amazonaws.com"
|
||||
},
|
||||
"Action": "sts:AssumeRole",
|
||||
},
|
||||
],
|
||||
|
@ -97,8 +97,7 @@ def _configure_iam_role(config):
|
|||
profile.add_role(RoleName=role.name)
|
||||
time.sleep(15) # wait for propagation
|
||||
|
||||
print("Role not specified for head node, using {}".format(
|
||||
profile.arn))
|
||||
print("Role not specified for head node, using {}".format(profile.arn))
|
||||
config["head_node"]["IamInstanceProfile"] = {"Arn": profile.arn}
|
||||
|
||||
return config
|
||||
|
@ -146,8 +145,10 @@ def _configure_key_pair(config):
|
|||
def _configure_subnet(config):
|
||||
ec2 = _resource("ec2", config)
|
||||
subnets = sorted(
|
||||
[s for s in ec2.subnets.all()
|
||||
if s.state == "available" and s.map_public_ip_on_launch],
|
||||
[
|
||||
s for s in ec2.subnets.all()
|
||||
if s.state == "available" and s.map_public_ip_on_launch
|
||||
],
|
||||
reverse=True, # sort from Z-A
|
||||
key=lambda subnet: subnet.availability_zone)
|
||||
if not subnets:
|
||||
|
@ -157,9 +158,9 @@ def _configure_subnet(config):
|
|||
"and trying this again. Note that the subnet must map public IPs "
|
||||
"on instance launch.")
|
||||
if "availability_zone" in config["provider"]:
|
||||
default_subnet = next((s for s in subnets
|
||||
if s.availability_zone ==
|
||||
config["provider"]["availability_zone"]),
|
||||
default_subnet = next((
|
||||
s for s in subnets
|
||||
if s.availability_zone == config["provider"]["availability_zone"]),
|
||||
None)
|
||||
if not default_subnet:
|
||||
raise Exception(
|
||||
|
@ -209,11 +210,21 @@ def _configure_security_group(config):
|
|||
|
||||
if not security_group.ip_permissions:
|
||||
security_group.authorize_ingress(
|
||||
IpPermissions=[
|
||||
{"FromPort": -1, "ToPort": -1, "IpProtocol": "-1",
|
||||
"UserIdGroupPairs": [{"GroupId": security_group.id}]},
|
||||
{"FromPort": 22, "ToPort": 22, "IpProtocol": "TCP",
|
||||
"IpRanges": [{"CidrIp": "0.0.0.0/0"}]}])
|
||||
IpPermissions=[{
|
||||
"FromPort": -1,
|
||||
"ToPort": -1,
|
||||
"IpProtocol": "-1",
|
||||
"UserIdGroupPairs": [{
|
||||
"GroupId": security_group.id
|
||||
}]
|
||||
}, {
|
||||
"FromPort": 22,
|
||||
"ToPort": 22,
|
||||
"IpProtocol": "TCP",
|
||||
"IpRanges": [{
|
||||
"CidrIp": "0.0.0.0/0"
|
||||
}]
|
||||
}])
|
||||
|
||||
if "SecurityGroupIds" not in config["head_node"]:
|
||||
print("SecurityGroupIds not specified for head node, using {}".format(
|
||||
|
@ -231,8 +242,10 @@ def _configure_security_group(config):
|
|||
def _get_subnet_or_die(config, subnet_id):
|
||||
ec2 = _resource("ec2", config)
|
||||
subnet = list(
|
||||
ec2.subnets.filter(Filters=[
|
||||
{"Name": "subnet-id", "Values": [subnet_id]}]))
|
||||
ec2.subnets.filter(Filters=[{
|
||||
"Name": "subnet-id",
|
||||
"Values": [subnet_id]
|
||||
}]))
|
||||
assert len(subnet) == 1, "Subnet not found"
|
||||
subnet = subnet[0]
|
||||
return subnet
|
||||
|
@ -241,8 +254,10 @@ def _get_subnet_or_die(config, subnet_id):
|
|||
def _get_security_group(config, vpc_id, group_name):
|
||||
ec2 = _resource("ec2", config)
|
||||
existing_groups = list(
|
||||
ec2.security_groups.filter(Filters=[
|
||||
{"Name": "vpc-id", "Values": [vpc_id]}]))
|
||||
ec2.security_groups.filter(Filters=[{
|
||||
"Name": "vpc-id",
|
||||
"Values": [vpc_id]
|
||||
}]))
|
||||
for sg in existing_groups:
|
||||
if sg.group_name == group_name:
|
||||
return sg
|
||||
|
@ -270,8 +285,10 @@ def _get_instance_profile(profile_name, config):
|
|||
|
||||
def _get_key(key_name, config):
|
||||
ec2 = _resource("ec2", config)
|
||||
for key in ec2.key_pairs.filter(
|
||||
Filters=[{"Name": "key-name", "Values": [key_name]}]):
|
||||
for key in ec2.key_pairs.filter(Filters=[{
|
||||
"Name": "key-name",
|
||||
"Values": [key_name]
|
||||
}]):
|
||||
if key.name == key_name:
|
||||
return key
|
||||
|
||||
|
|
|
@ -84,7 +84,8 @@ class AWSNodeProvider(NodeProvider):
|
|||
tag_pairs = []
|
||||
for k, v in tags.items():
|
||||
tag_pairs.append({
|
||||
"Key": k, "Value": v,
|
||||
"Key": k,
|
||||
"Value": v,
|
||||
})
|
||||
node.create_tags(Tags=tag_pairs)
|
||||
|
||||
|
@ -95,20 +96,20 @@ class AWSNodeProvider(NodeProvider):
|
|||
"Value": self.cluster_name,
|
||||
}]
|
||||
for k, v in tags.items():
|
||||
tag_pairs.append(
|
||||
{
|
||||
tag_pairs.append({
|
||||
"Key": k,
|
||||
"Value": v,
|
||||
})
|
||||
conf.update({
|
||||
"MinCount": 1,
|
||||
"MaxCount": count,
|
||||
"TagSpecifications": conf.get("TagSpecifications", []) + [
|
||||
{
|
||||
"MinCount":
|
||||
1,
|
||||
"MaxCount":
|
||||
count,
|
||||
"TagSpecifications":
|
||||
conf.get("TagSpecifications", []) + [{
|
||||
"ResourceType": "instance",
|
||||
"Tags": tag_pairs,
|
||||
}
|
||||
]
|
||||
}]
|
||||
})
|
||||
self.ec2.create_instances(**conf)
|
||||
|
||||
|
|
|
@ -23,9 +23,8 @@ from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, TAG_RAY_LAUNCH_CONFIG, \
|
|||
from ray.autoscaler.updater import NodeUpdaterProcess
|
||||
|
||||
|
||||
def create_or_update_cluster(
|
||||
config_file, override_min_workers, override_max_workers,
|
||||
no_restart, yes):
|
||||
def create_or_update_cluster(config_file, override_min_workers,
|
||||
override_max_workers, no_restart, yes):
|
||||
"""Create or updates an autoscaling Ray cluster from a config json."""
|
||||
|
||||
config = yaml.load(open(config_file).read())
|
||||
|
@ -39,8 +38,8 @@ def create_or_update_cluster(
|
|||
|
||||
importer = NODE_PROVIDERS.get(config["provider"]["type"])
|
||||
if not importer:
|
||||
raise NotImplementedError(
|
||||
"Unsupported provider {}".format(config["provider"]))
|
||||
raise NotImplementedError("Unsupported provider {}".format(
|
||||
config["provider"]))
|
||||
|
||||
bootstrap_config, _ = importer()
|
||||
config = bootstrap_config(config)
|
||||
|
@ -129,8 +128,10 @@ def get_or_create_head_node(config, no_restart, yes):
|
|||
remote_config_file.write(json.dumps(remote_config))
|
||||
remote_config_file.flush()
|
||||
config["file_mounts"].update({
|
||||
remote_key_path: config["auth"]["ssh_private_key"],
|
||||
"~/ray_bootstrap_config.yaml": remote_config_file.name
|
||||
remote_key_path:
|
||||
config["auth"]["ssh_private_key"],
|
||||
"~/ray_bootstrap_config.yaml":
|
||||
remote_config_file.name
|
||||
})
|
||||
|
||||
if no_restart:
|
||||
|
@ -160,28 +161,22 @@ def get_or_create_head_node(config, no_restart, yes):
|
|||
print("Error: updating {} failed".format(
|
||||
provider.external_ip(head_node)))
|
||||
sys.exit(1)
|
||||
print(
|
||||
"Head node up-to-date, IP address is: {}".format(
|
||||
print("Head node up-to-date, IP address is: {}".format(
|
||||
provider.external_ip(head_node)))
|
||||
|
||||
monitor_str = "tail -f /tmp/raylogs/monitor-*"
|
||||
for s in init_commands:
|
||||
if ("ray start" in s and "docker exec" in s and
|
||||
"--autoscaling-config" in s):
|
||||
if ("ray start" in s and "docker exec" in s
|
||||
and "--autoscaling-config" in s):
|
||||
monitor_str = "docker exec {} /bin/sh -c {}".format(
|
||||
config["docker"]["container_name"],
|
||||
quote(monitor_str))
|
||||
print(
|
||||
"To monitor auto-scaling activity, you can run:\n\n"
|
||||
" ssh -i {} {}@{} {}\n".format(
|
||||
config["auth"]["ssh_private_key"],
|
||||
config["docker"]["container_name"], quote(monitor_str))
|
||||
print("To monitor auto-scaling activity, you can run:\n\n"
|
||||
" ssh -i {} {}@{} {}\n".format(config["auth"]["ssh_private_key"],
|
||||
config["auth"]["ssh_user"],
|
||||
provider.external_ip(head_node),
|
||||
quote(monitor_str)))
|
||||
print(
|
||||
"To login to the cluster, run:\n\n"
|
||||
" ssh -i {} {}@{}\n".format(
|
||||
config["auth"]["ssh_private_key"],
|
||||
print("To login to the cluster, run:\n\n"
|
||||
" ssh -i {} {}@{}\n".format(config["auth"]["ssh_private_key"],
|
||||
config["auth"]["ssh_user"],
|
||||
provider.external_ip(head_node)))
|
||||
|
||||
|
|
|
@ -22,24 +22,21 @@ def dockerize_if_needed(config):
|
|||
assert cname, "Must provide container name!"
|
||||
docker_mounts = {dst: dst for dst in config["file_mounts"]}
|
||||
config["setup_commands"] = (
|
||||
docker_install_cmds() +
|
||||
docker_start_cmds(
|
||||
config["auth"]["ssh_user"], docker_image,
|
||||
docker_mounts, cname) +
|
||||
with_docker_exec(
|
||||
config["setup_commands"], container_name=cname))
|
||||
docker_install_cmds() + docker_start_cmds(
|
||||
config["auth"]["ssh_user"], docker_image, docker_mounts, cname) +
|
||||
with_docker_exec(config["setup_commands"], container_name=cname))
|
||||
|
||||
config["head_setup_commands"] = with_docker_exec(
|
||||
config["head_setup_commands"], container_name=cname)
|
||||
config["head_start_ray_commands"] = (
|
||||
docker_autoscaler_setup(cname) +
|
||||
with_docker_exec(
|
||||
docker_autoscaler_setup(cname) + with_docker_exec(
|
||||
config["head_start_ray_commands"], container_name=cname))
|
||||
|
||||
config["worker_setup_commands"] = with_docker_exec(
|
||||
config["worker_setup_commands"], container_name=cname)
|
||||
config["worker_start_ray_commands"] = with_docker_exec(
|
||||
config["worker_start_ray_commands"], container_name=cname,
|
||||
config["worker_start_ray_commands"],
|
||||
container_name=cname,
|
||||
env_vars=["RAY_HEAD_IP"])
|
||||
|
||||
return config
|
||||
|
@ -50,18 +47,21 @@ def with_docker_exec(cmds, container_name, env_vars=None):
|
|||
if env_vars:
|
||||
env_str = " ".join(
|
||||
["-e {env}=${env}".format(env=env) for env in env_vars])
|
||||
return ["docker exec {} {} /bin/sh -c {} ".format(
|
||||
env_str, container_name, quote(cmd)) for cmd in cmds]
|
||||
return [
|
||||
"docker exec {} {} /bin/sh -c {} ".format(env_str, container_name,
|
||||
quote(cmd)) for cmd in cmds
|
||||
]
|
||||
|
||||
|
||||
def docker_install_cmds():
|
||||
return [aptwait_cmd() + " && sudo apt-get update",
|
||||
aptwait_cmd() + " && sudo apt-get install -y docker.io"]
|
||||
return [
|
||||
aptwait_cmd() + " && sudo apt-get update",
|
||||
aptwait_cmd() + " && sudo apt-get install -y docker.io"
|
||||
]
|
||||
|
||||
|
||||
def aptwait_cmd():
|
||||
return (
|
||||
"while sudo fuser"
|
||||
return ("while sudo fuser"
|
||||
" /var/{lib/{dpkg,apt/lists},cache/apt/archives}/lock"
|
||||
" >/dev/null 2>&1; "
|
||||
"do echo 'Waiting for release of dpkg/apt locks'; sleep 5; done")
|
||||
|
@ -77,10 +77,12 @@ def docker_start_cmds(user, image, mount, cname):
|
|||
|
||||
# create flags
|
||||
# ports for the redis, object manager, and tune client
|
||||
port_flags = " ".join(["-p {port}:{port}".format(port=port)
|
||||
for port in ["6379", "8076", "4321"]])
|
||||
mount_flags = " ".join(["-v {src}:{dest}".format(src=k, dest=v)
|
||||
for k, v in mount.items()])
|
||||
port_flags = " ".join([
|
||||
"-p {port}:{port}".format(port=port)
|
||||
for port in ["6379", "8076", "4321"]
|
||||
])
|
||||
mount_flags = " ".join(
|
||||
["-v {src}:{dest}".format(src=k, dest=v) for k, v in mount.items()])
|
||||
|
||||
# for click, used in ray cli
|
||||
env_vars = {"LC_ALL": "C.UTF-8", "LANG": "C.UTF-8"}
|
||||
|
@ -88,9 +90,10 @@ def docker_start_cmds(user, image, mount, cname):
|
|||
["-e {name}={val}".format(name=k, val=v) for k, v in env_vars.items()])
|
||||
|
||||
# docker run command
|
||||
docker_run = ["docker", "run", "--rm", "--name {}".format(cname),
|
||||
"-d", "-it", port_flags, mount_flags, env_flags,
|
||||
"--net=host", image, "bash"]
|
||||
docker_run = [
|
||||
"docker", "run", "--rm", "--name {}".format(cname), "-d", "-it",
|
||||
port_flags, mount_flags, env_flags, "--net=host", image, "bash"
|
||||
]
|
||||
cmds.append(" ".join(docker_run))
|
||||
docker_update = []
|
||||
docker_update.append("apt-get -y update")
|
||||
|
@ -107,7 +110,8 @@ def docker_autoscaler_setup(cname):
|
|||
base_path = os.path.basename(path)
|
||||
cmds.append("docker cp {path} {cname}:{dpath}".format(
|
||||
path=path, dpath=base_path, cname=cname))
|
||||
cmds.extend(with_docker_exec(
|
||||
cmds.extend(
|
||||
with_docker_exec(
|
||||
["cp {} {}".format("/" + base_path, path)],
|
||||
container_name=cname))
|
||||
return cmds
|
||||
|
|
|
@ -15,14 +15,15 @@ def import_aws():
|
|||
|
||||
def load_aws_config():
|
||||
import ray.autoscaler.aws as ray_aws
|
||||
return os.path.join(os.path.dirname(
|
||||
ray_aws.__file__), "example-full.yaml")
|
||||
return os.path.join(os.path.dirname(ray_aws.__file__), "example-full.yaml")
|
||||
|
||||
|
||||
def import_external():
|
||||
"""Mock a normal provider importer."""
|
||||
|
||||
def return_it_back(config):
|
||||
return config
|
||||
|
||||
return return_it_back, None
|
||||
|
||||
|
||||
|
@ -55,8 +56,7 @@ def load_class(path):
|
|||
class_data = path.split(".")
|
||||
if len(class_data) < 2:
|
||||
raise ValueError(
|
||||
"You need to pass a valid path like mymodule.provider_class"
|
||||
)
|
||||
"You need to pass a valid path like mymodule.provider_class")
|
||||
module_path = ".".join(class_data[:-1])
|
||||
class_str = class_data[-1]
|
||||
module = importlib.import_module(module_path)
|
||||
|
@ -71,8 +71,8 @@ def get_node_provider(provider_config, cluster_name):
|
|||
importer = NODE_PROVIDERS.get(provider_config["type"])
|
||||
|
||||
if importer is None:
|
||||
raise NotImplementedError(
|
||||
"Unsupported node provider: {}".format(provider_config["type"]))
|
||||
raise NotImplementedError("Unsupported node provider: {}".format(
|
||||
provider_config["type"]))
|
||||
_, provider_cls = importer()
|
||||
return provider_cls(provider_config, cluster_name)
|
||||
|
||||
|
@ -82,8 +82,8 @@ def get_default_config(provider_config):
|
|||
return {}
|
||||
load_config = DEFAULT_CONFIGS.get(provider_config["type"])
|
||||
if load_config is None:
|
||||
raise NotImplementedError(
|
||||
"Unsupported node provider: {}".format(provider_config["type"]))
|
||||
raise NotImplementedError("Unsupported node provider: {}".format(
|
||||
provider_config["type"]))
|
||||
path_to_default = load_config()
|
||||
with open(path_to_default) as f:
|
||||
defaults = yaml.load(f)
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
"""The Ray autoscaler uses tags to associate metadata with instances."""
|
||||
|
||||
# Tag for the name of the node
|
||||
|
|
|
@ -26,9 +26,15 @@ def pretty_cmd(cmd_str):
|
|||
class NodeUpdater(object):
|
||||
"""A process for syncing files and running init commands on a node."""
|
||||
|
||||
def __init__(
|
||||
self, node_id, provider_config, auth_config, cluster_name,
|
||||
file_mounts, setup_cmds, runtime_hash, redirect_output=True,
|
||||
def __init__(self,
|
||||
node_id,
|
||||
provider_config,
|
||||
auth_config,
|
||||
cluster_name,
|
||||
file_mounts,
|
||||
setup_cmds,
|
||||
runtime_hash,
|
||||
redirect_output=True,
|
||||
process_runner=subprocess):
|
||||
self.daemon = True
|
||||
self.process_runner = process_runner
|
||||
|
@ -66,13 +72,12 @@ class NodeUpdater(object):
|
|||
"NodeUpdater: Error updating {}"
|
||||
"See {} for remote logs.".format(error_str, self.output_name),
|
||||
file=self.stdout)
|
||||
self.provider.set_node_tags(
|
||||
self.node_id, {TAG_RAY_NODE_STATUS: "UpdateFailed"})
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "UpdateFailed"})
|
||||
if self.logfile is not None:
|
||||
print(
|
||||
"----- BEGIN REMOTE LOGS -----\n" +
|
||||
open(self.logfile.name).read() +
|
||||
"\n----- END REMOTE LOGS -----")
|
||||
print("----- BEGIN REMOTE LOGS -----\n" + open(
|
||||
self.logfile.name).read() + "\n----- END REMOTE LOGS -----"
|
||||
)
|
||||
raise e
|
||||
self.provider.set_node_tags(
|
||||
self.node_id, {
|
||||
|
@ -85,8 +90,8 @@ class NodeUpdater(object):
|
|||
file=self.stdout)
|
||||
|
||||
def do_update(self):
|
||||
self.provider.set_node_tags(
|
||||
self.node_id, {TAG_RAY_NODE_STATUS: "WaitingForSSH"})
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "WaitingForSSH"})
|
||||
deadline = time.time() + NODE_START_WAIT_S
|
||||
|
||||
# Wait for external IP
|
||||
|
@ -114,7 +119,8 @@ class NodeUpdater(object):
|
|||
raise Exception("Node not running yet...")
|
||||
self.ssh_cmd(
|
||||
"uptime",
|
||||
connect_timeout=5, redirect=open("/dev/null", "w"))
|
||||
connect_timeout=5,
|
||||
redirect=open("/dev/null", "w"))
|
||||
ssh_ok = True
|
||||
except Exception as e:
|
||||
retry_str = str(e)
|
||||
|
@ -130,8 +136,8 @@ class NodeUpdater(object):
|
|||
assert ssh_ok, "Unable to SSH to node"
|
||||
|
||||
# Rsync file mounts
|
||||
self.provider.set_node_tags(
|
||||
self.node_id, {TAG_RAY_NODE_STATUS: "SyncingFiles"})
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "SyncingFiles"})
|
||||
for remote_path, local_path in self.file_mounts.items():
|
||||
print(
|
||||
"NodeUpdater: Syncing {} to {}...".format(
|
||||
|
@ -143,18 +149,20 @@ class NodeUpdater(object):
|
|||
local_path += "/"
|
||||
if not remote_path.endswith("/"):
|
||||
remote_path += "/"
|
||||
self.ssh_cmd(
|
||||
"mkdir -p {}".format(os.path.dirname(remote_path)))
|
||||
self.process_runner.check_call([
|
||||
self.ssh_cmd("mkdir -p {}".format(os.path.dirname(remote_path)))
|
||||
self.process_runner.check_call(
|
||||
[
|
||||
"rsync", "-e", "ssh -i {} ".format(self.ssh_private_key) +
|
||||
"-o ConnectTimeout=120s -o StrictHostKeyChecking=no",
|
||||
"--delete", "-avz", "{}".format(local_path),
|
||||
"{}@{}:{}".format(self.ssh_user, self.ssh_ip, remote_path)
|
||||
], stdout=self.stdout, stderr=self.stderr)
|
||||
],
|
||||
stdout=self.stdout,
|
||||
stderr=self.stderr)
|
||||
|
||||
# Run init commands
|
||||
self.provider.set_node_tags(
|
||||
self.node_id, {TAG_RAY_NODE_STATUS: "SettingUp"})
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "SettingUp"})
|
||||
for cmd in self.setup_cmds:
|
||||
self.ssh_cmd(cmd, verbose=True)
|
||||
|
||||
|
@ -165,13 +173,16 @@ class NodeUpdater(object):
|
|||
pretty_cmd(cmd), self.ssh_ip),
|
||||
file=self.stdout)
|
||||
force_interactive = "set -i && source ~/.bashrc && "
|
||||
self.process_runner.check_call([
|
||||
self.process_runner.check_call(
|
||||
[
|
||||
"ssh", "-o", "ConnectTimeout={}s".format(connect_timeout),
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-i", self.ssh_private_key,
|
||||
"{}@{}".format(self.ssh_user, self.ssh_ip),
|
||||
"bash --login -c {}".format(pipes.quote(force_interactive + cmd))
|
||||
], stdout=redirect or self.stdout, stderr=redirect or self.stderr)
|
||||
"-i", self.ssh_private_key, "{}@{}".format(
|
||||
self.ssh_user, self.ssh_ip), "bash --login -c {}".format(
|
||||
pipes.quote(force_interactive + cmd))
|
||||
],
|
||||
stdout=redirect or self.stdout,
|
||||
stderr=redirect or self.stderr)
|
||||
|
||||
|
||||
class NodeUpdaterProcess(NodeUpdater, Process):
|
||||
|
|
|
@ -25,7 +25,7 @@ OBJECT_CHANNEL_PREFIX = "OC:"
|
|||
def integerToAsciiHex(num, numbytes):
|
||||
retstr = b""
|
||||
# Support 32 and 64 bit architecture.
|
||||
assert(numbytes == 4 or numbytes == 8)
|
||||
assert (numbytes == 4 or numbytes == 8)
|
||||
for i in range(numbytes):
|
||||
curbyte = num & 0xff
|
||||
if sys.version_info >= (3, 0):
|
||||
|
@ -50,7 +50,6 @@ def get_next_message(pubsub_client, timeout_seconds=10):
|
|||
|
||||
|
||||
class TestGlobalStateStore(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
redis_port, _ = ray.services.start_redis_instance()
|
||||
self.redis = redis.StrictRedis(host="localhost", port=redis_port, db=0)
|
||||
|
@ -192,16 +191,16 @@ class TestGlobalStateStore(unittest.TestCase):
|
|||
# notifications.
|
||||
def check_object_notification(notification_message, object_id,
|
||||
object_size, manager_ids):
|
||||
notification_object = (SubscribeToNotificationsReply
|
||||
.GetRootAsSubscribeToNotificationsReply(
|
||||
notification_object = (SubscribeToNotificationsReply.
|
||||
GetRootAsSubscribeToNotificationsReply(
|
||||
notification_message, 0))
|
||||
self.assertEqual(notification_object.ObjectId(), object_id)
|
||||
self.assertEqual(notification_object.ObjectSize(), object_size)
|
||||
self.assertEqual(notification_object.ManagerIdsLength(),
|
||||
len(manager_ids))
|
||||
for i in range(len(manager_ids)):
|
||||
self.assertEqual(notification_object.ManagerIds(i),
|
||||
manager_ids[i])
|
||||
self.assertEqual(
|
||||
notification_object.ManagerIds(i), manager_ids[i])
|
||||
|
||||
data_size = 0xf1f0
|
||||
p = self.redis.pubsub()
|
||||
|
@ -215,9 +214,8 @@ class TestGlobalStateStore(unittest.TestCase):
|
|||
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
|
||||
"manager_id1", "object_id1")
|
||||
# Verify that the notification is correct.
|
||||
check_object_notification(get_next_message(p)["data"],
|
||||
b"object_id1",
|
||||
data_size,
|
||||
check_object_notification(
|
||||
get_next_message(p)["data"], b"object_id1", data_size,
|
||||
[b"manager_id2"])
|
||||
|
||||
# Request a notification for an object that isn't there. Then add the
|
||||
|
@ -232,26 +230,22 @@ class TestGlobalStateStore(unittest.TestCase):
|
|||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3",
|
||||
data_size, "hash1", "manager_id3")
|
||||
# Verify that the notification is correct.
|
||||
check_object_notification(get_next_message(p)["data"],
|
||||
b"object_id3",
|
||||
data_size,
|
||||
check_object_notification(
|
||||
get_next_message(p)["data"], b"object_id3", data_size,
|
||||
[b"manager_id1"])
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2",
|
||||
data_size, "hash1", "manager_id3")
|
||||
# Verify that the notification is correct.
|
||||
check_object_notification(get_next_message(p)["data"],
|
||||
b"object_id2",
|
||||
data_size,
|
||||
check_object_notification(
|
||||
get_next_message(p)["data"], b"object_id2", data_size,
|
||||
[b"manager_id3"])
|
||||
# Request notifications for object_id3 again.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
|
||||
"manager_id1", "object_id3")
|
||||
# Verify that the notification is correct.
|
||||
check_object_notification(get_next_message(p)["data"],
|
||||
b"object_id3",
|
||||
data_size,
|
||||
[b"manager_id1", b"manager_id2",
|
||||
b"manager_id3"])
|
||||
check_object_notification(
|
||||
get_next_message(p)["data"], b"object_id3", data_size,
|
||||
[b"manager_id1", b"manager_id2", b"manager_id3"])
|
||||
|
||||
def testResultTableAddAndLookup(self):
|
||||
def check_result_table_entry(message, task_id, is_put):
|
||||
|
@ -349,8 +343,7 @@ class TestGlobalStateStore(unittest.TestCase):
|
|||
# update happens, and the response is still the same task.
|
||||
task_args = [task_args[0]] + task_args
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
|
||||
"task_id",
|
||||
*task_args[:3])
|
||||
"task_id", *task_args[:3])
|
||||
check_task_reply(response, task_args[1:], updated=True)
|
||||
# Check that the task entry is still the same.
|
||||
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET",
|
||||
|
@ -362,8 +355,7 @@ class TestGlobalStateStore(unittest.TestCase):
|
|||
# task.
|
||||
task_args[1] = TASK_STATUS_QUEUED
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
|
||||
"task_id",
|
||||
*task_args[:3])
|
||||
"task_id", *task_args[:3])
|
||||
check_task_reply(response, task_args[1:], updated=True)
|
||||
# Check that the update happened.
|
||||
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET",
|
||||
|
@ -375,8 +367,7 @@ class TestGlobalStateStore(unittest.TestCase):
|
|||
new_task_args = task_args[:]
|
||||
new_task_args[1] = TASK_STATUS_WAITING
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
|
||||
"task_id",
|
||||
*new_task_args[:3])
|
||||
"task_id", *new_task_args[:3])
|
||||
check_task_reply(response, task_args[1:], updated=False)
|
||||
# Check that the update did not happen.
|
||||
get_response2 = self.redis.execute_command("RAY.TASK_TABLE_GET",
|
||||
|
@ -388,8 +379,7 @@ class TestGlobalStateStore(unittest.TestCase):
|
|||
task_args = new_task_args
|
||||
task_args[0] = TASK_STATUS_SCHEDULED | TASK_STATUS_QUEUED
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
|
||||
"task_id",
|
||||
*task_args[:3])
|
||||
"task_id", *task_args[:3])
|
||||
check_task_reply(response, task_args[1:], updated=True)
|
||||
|
||||
# If the test value is a bitmask that does not match the current value,
|
||||
|
@ -399,8 +389,7 @@ class TestGlobalStateStore(unittest.TestCase):
|
|||
new_task_args[0] = TASK_STATUS_SCHEDULED
|
||||
old_response = response
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
|
||||
"task_id",
|
||||
*new_task_args[:3])
|
||||
"task_id", *new_task_args[:3])
|
||||
check_task_reply(response, task_args[1:], updated=False)
|
||||
# Check that the update did not happen.
|
||||
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET",
|
||||
|
@ -409,8 +398,10 @@ class TestGlobalStateStore(unittest.TestCase):
|
|||
check_task_reply(get_response, task_args[1:])
|
||||
|
||||
def check_task_subscription(self, p, scheduling_state, local_scheduler_id):
|
||||
task_args = [b"task_id", scheduling_state,
|
||||
local_scheduler_id.encode("ascii"), b"", 0, b"task_spec"]
|
||||
task_args = [
|
||||
b"task_id", scheduling_state,
|
||||
local_scheduler_id.encode("ascii"), b"", 0, b"task_spec"
|
||||
]
|
||||
self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args)
|
||||
# Receive the data.
|
||||
message = get_next_message(p)["data"]
|
||||
|
@ -418,8 +409,7 @@ class TestGlobalStateStore(unittest.TestCase):
|
|||
notification_object = TaskReply.GetRootAsTaskReply(message, 0)
|
||||
self.assertEqual(notification_object.TaskId(), task_args[0])
|
||||
self.assertEqual(notification_object.State(), task_args[1])
|
||||
self.assertEqual(notification_object.LocalSchedulerId(),
|
||||
task_args[2])
|
||||
self.assertEqual(notification_object.LocalSchedulerId(), task_args[2])
|
||||
self.assertEqual(notification_object.ExecutionDependencies(),
|
||||
task_args[3])
|
||||
self.assertEqual(notification_object.TaskSpec(), task_args[-1])
|
||||
|
|
|
@ -30,19 +30,23 @@ def random_task_id():
|
|||
|
||||
BASE_SIMPLE_OBJECTS = [
|
||||
0, 1, 100000, 0.0, 0.5, 0.9, 100000.1, (), [], {}, "", 990 * "h", u"",
|
||||
990 * u"h"]
|
||||
990 * u"h"
|
||||
]
|
||||
|
||||
if sys.version_info < (3, 0):
|
||||
BASE_SIMPLE_OBJECTS += [long(0), long(1), long(100000), long(1 << 100)] # noqa: E501,F821
|
||||
BASE_SIMPLE_OBJECTS += [
|
||||
long(0), # noqa: E501,F821
|
||||
long(1), # noqa: E501,F821
|
||||
long(100000), # noqa: E501,F821
|
||||
long(1 << 100) # noqa: E501,F821
|
||||
]
|
||||
|
||||
LIST_SIMPLE_OBJECTS = [[obj] for obj in BASE_SIMPLE_OBJECTS]
|
||||
TUPLE_SIMPLE_OBJECTS = [(obj,) for obj in BASE_SIMPLE_OBJECTS]
|
||||
TUPLE_SIMPLE_OBJECTS = [(obj, ) for obj in BASE_SIMPLE_OBJECTS]
|
||||
DICT_SIMPLE_OBJECTS = [{(): obj} for obj in BASE_SIMPLE_OBJECTS]
|
||||
|
||||
SIMPLE_OBJECTS = (BASE_SIMPLE_OBJECTS +
|
||||
LIST_SIMPLE_OBJECTS +
|
||||
TUPLE_SIMPLE_OBJECTS +
|
||||
DICT_SIMPLE_OBJECTS)
|
||||
SIMPLE_OBJECTS = (BASE_SIMPLE_OBJECTS + LIST_SIMPLE_OBJECTS +
|
||||
TUPLE_SIMPLE_OBJECTS + DICT_SIMPLE_OBJECTS)
|
||||
|
||||
# Create some complex objects that cannot be serialized by value in tasks.
|
||||
|
||||
|
@ -55,21 +59,20 @@ class Foo(object):
|
|||
pass
|
||||
|
||||
|
||||
BASE_COMPLEX_OBJECTS = [999 * "h", 999 * u"h", lst, Foo(),
|
||||
10 * [10 * [10 * [1]]]]
|
||||
BASE_COMPLEX_OBJECTS = [
|
||||
999 * "h", 999 * u"h", lst,
|
||||
Foo(), 10 * [10 * [10 * [1]]]
|
||||
]
|
||||
|
||||
LIST_COMPLEX_OBJECTS = [[obj] for obj in BASE_COMPLEX_OBJECTS]
|
||||
TUPLE_COMPLEX_OBJECTS = [(obj,) for obj in BASE_COMPLEX_OBJECTS]
|
||||
TUPLE_COMPLEX_OBJECTS = [(obj, ) for obj in BASE_COMPLEX_OBJECTS]
|
||||
DICT_COMPLEX_OBJECTS = [{(): obj} for obj in BASE_COMPLEX_OBJECTS]
|
||||
|
||||
COMPLEX_OBJECTS = (BASE_COMPLEX_OBJECTS +
|
||||
LIST_COMPLEX_OBJECTS +
|
||||
TUPLE_COMPLEX_OBJECTS +
|
||||
DICT_COMPLEX_OBJECTS)
|
||||
COMPLEX_OBJECTS = (BASE_COMPLEX_OBJECTS + LIST_COMPLEX_OBJECTS +
|
||||
TUPLE_COMPLEX_OBJECTS + DICT_COMPLEX_OBJECTS)
|
||||
|
||||
|
||||
class TestSerialization(unittest.TestCase):
|
||||
|
||||
def test_serialize_by_value(self):
|
||||
|
||||
for val in SIMPLE_OBJECTS:
|
||||
|
@ -79,7 +82,6 @@ class TestSerialization(unittest.TestCase):
|
|||
|
||||
|
||||
class TestObjectID(unittest.TestCase):
|
||||
|
||||
def test_create_object_id(self):
|
||||
random_object_id()
|
||||
|
||||
|
@ -95,6 +97,7 @@ class TestObjectID(unittest.TestCase):
|
|||
def h():
|
||||
object_ids[0]
|
||||
return 1
|
||||
|
||||
# Make sure that object IDs cannot be pickled (including functions that
|
||||
# close over object IDs).
|
||||
self.assertRaises(Exception, lambda: pickle.dumps(object_ids[0]))
|
||||
|
@ -113,10 +116,12 @@ class TestObjectID(unittest.TestCase):
|
|||
self.assertNotEqual(x1, y1)
|
||||
|
||||
random_strings = [np.random.bytes(ID_SIZE) for _ in range(256)]
|
||||
object_ids1 = [local_scheduler.ObjectID(random_strings[i])
|
||||
for i in range(256)]
|
||||
object_ids2 = [local_scheduler.ObjectID(random_strings[i])
|
||||
for i in range(256)]
|
||||
object_ids1 = [
|
||||
local_scheduler.ObjectID(random_strings[i]) for i in range(256)
|
||||
]
|
||||
object_ids2 = [
|
||||
local_scheduler.ObjectID(random_strings[i]) for i in range(256)
|
||||
]
|
||||
self.assertEqual(len(set(object_ids1)), 256)
|
||||
self.assertEqual(len(set(object_ids1 + object_ids2)), 256)
|
||||
self.assertEqual(set(object_ids1), set(object_ids2))
|
||||
|
@ -129,7 +134,6 @@ class TestObjectID(unittest.TestCase):
|
|||
|
||||
|
||||
class TestTask(unittest.TestCase):
|
||||
|
||||
def check_task(self, task, function_id, num_return_vals, args):
|
||||
self.assertEqual(function_id.id(), task.function_id().id())
|
||||
retrieved_args = task.arguments()
|
||||
|
@ -148,30 +152,16 @@ class TestTask(unittest.TestCase):
|
|||
parent_id = random_task_id()
|
||||
function_id = random_function_id()
|
||||
object_ids = [random_object_id() for _ in range(256)]
|
||||
args_list = [
|
||||
[],
|
||||
1 * [1],
|
||||
10 * [1],
|
||||
100 * [1],
|
||||
1000 * [1],
|
||||
1 * ["a"],
|
||||
10 * ["a"],
|
||||
100 * ["a"],
|
||||
1000 * ["a"],
|
||||
[1, 1.3, 2, 1 << 100, "hi", u"hi", [1, 2]],
|
||||
object_ids[:1],
|
||||
object_ids[:2],
|
||||
object_ids[:3],
|
||||
object_ids[:4],
|
||||
object_ids[:5],
|
||||
object_ids[:10],
|
||||
object_ids[:100],
|
||||
object_ids[:256],
|
||||
[1, object_ids[0]],
|
||||
[object_ids[0], "a"],
|
||||
[1, object_ids[0], "a"],
|
||||
[object_ids[0], 1, object_ids[1], "a"],
|
||||
object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
|
||||
args_list = [[], 1 * [1], 10 * [1], 100 * [1], 1000 * [1], 1 * ["a"],
|
||||
10 * ["a"], 100 * ["a"], 1000 * ["a"], [
|
||||
1, 1.3, 2, 1 << 100, "hi", u"hi", [1, 2]
|
||||
], object_ids[:1], object_ids[:2], object_ids[:3],
|
||||
object_ids[:4], object_ids[:5], object_ids[:10],
|
||||
object_ids[:100], object_ids[:256], [1, object_ids[0]], [
|
||||
object_ids[0], "a"
|
||||
], [1, object_ids[0], "a"], [
|
||||
object_ids[0], 1, object_ids[1], "a"
|
||||
], object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
|
||||
object_ids + 100 * ["a"] + object_ids]
|
||||
for args in args_list:
|
||||
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
|
||||
|
|
|
@ -5,5 +5,7 @@ from __future__ import print_function
|
|||
from .tfutils import TensorFlowVariables
|
||||
from .features import flush_redis_unsafe, flush_task_and_object_metadata_unsafe
|
||||
|
||||
__all__ = ["TensorFlowVariables", "flush_redis_unsafe",
|
||||
"flush_task_and_object_metadata_unsafe"]
|
||||
__all__ = [
|
||||
"TensorFlowVariables", "flush_redis_unsafe",
|
||||
"flush_task_and_object_metadata_unsafe"
|
||||
]
|
||||
|
|
|
@ -8,6 +8,8 @@ from .core import (BLOCK_SIZE, DistArray, assemble, zeros, ones, copy, eye,
|
|||
triu, tril, blockwise_dot, dot, transpose, add, subtract,
|
||||
numpy_to_dist, subblocks)
|
||||
|
||||
__all__ = ["random", "linalg", "BLOCK_SIZE", "DistArray", "assemble", "zeros",
|
||||
"ones", "copy", "eye", "triu", "tril", "blockwise_dot", "dot",
|
||||
"transpose", "add", "subtract", "numpy_to_dist", "subblocks"]
|
||||
__all__ = [
|
||||
"random", "linalg", "BLOCK_SIZE", "DistArray", "assemble", "zeros", "ones",
|
||||
"copy", "eye", "triu", "tril", "blockwise_dot", "dot", "transpose", "add",
|
||||
"subtract", "numpy_to_dist", "subblocks"
|
||||
]
|
||||
|
|
|
@ -13,8 +13,9 @@ class DistArray(object):
|
|||
def __init__(self, shape, objectids=None):
|
||||
self.shape = shape
|
||||
self.ndim = len(shape)
|
||||
self.num_blocks = [int(np.ceil(1.0 * a / BLOCK_SIZE))
|
||||
for a in self.shape]
|
||||
self.num_blocks = [
|
||||
int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in self.shape
|
||||
]
|
||||
if objectids is not None:
|
||||
self.objectids = objectids
|
||||
else:
|
||||
|
@ -56,7 +57,7 @@ class DistArray(object):
|
|||
|
||||
def assemble(self):
|
||||
"""Assemble an array from a distributed array of object IDs."""
|
||||
first_block = ray.get(self.objectids[(0,) * self.ndim])
|
||||
first_block = ray.get(self.objectids[(0, ) * self.ndim])
|
||||
dtype = first_block.dtype
|
||||
result = np.zeros(self.shape, dtype=dtype)
|
||||
for index in np.ndindex(*self.num_blocks):
|
||||
|
@ -85,8 +86,8 @@ def numpy_to_dist(a):
|
|||
for index in np.ndindex(*result.num_blocks):
|
||||
lower = DistArray.compute_block_lower(index, a.shape)
|
||||
upper = DistArray.compute_block_upper(index, a.shape)
|
||||
result.objectids[index] = ray.put(a[[slice(l, u) for (l, u)
|
||||
in zip(lower, upper)]])
|
||||
result.objectids[index] = ray.put(
|
||||
a[[slice(l, u) for (l, u) in zip(lower, upper)]])
|
||||
return result
|
||||
|
||||
|
||||
|
@ -126,12 +127,11 @@ def eye(dim1, dim2=-1, dtype_name="float"):
|
|||
for (i, j) in np.ndindex(*result.num_blocks):
|
||||
block_shape = DistArray.compute_block_shape([i, j], shape)
|
||||
if i == j:
|
||||
result.objectids[i, j] = ra.eye.remote(block_shape[0],
|
||||
block_shape[1],
|
||||
dtype_name=dtype_name)
|
||||
result.objectids[i, j] = ra.eye.remote(
|
||||
block_shape[0], block_shape[1], dtype_name=dtype_name)
|
||||
else:
|
||||
result.objectids[i, j] = ra.zeros.remote(block_shape,
|
||||
dtype_name=dtype_name)
|
||||
result.objectids[i, j] = ra.zeros.remote(
|
||||
block_shape, dtype_name=dtype_name)
|
||||
return result
|
||||
|
||||
|
||||
|
@ -190,8 +190,8 @@ def dot(a, b):
|
|||
"b.ndim = {}.".format(b.ndim))
|
||||
if a.shape[1] != b.shape[0]:
|
||||
raise Exception("dot expects a.shape[1] to equal b.shape[0], but "
|
||||
"a.shape = {} and b.shape = {}.".format(a.shape,
|
||||
b.shape))
|
||||
"a.shape = {} and b.shape = {}.".format(
|
||||
a.shape, b.shape))
|
||||
shape = [a.shape[0], b.shape[1]]
|
||||
result = DistArray(shape)
|
||||
for (i, j) in np.ndindex(*result.num_blocks):
|
||||
|
@ -227,8 +227,8 @@ def subblocks(a, *ranges):
|
|||
"the {}th range is {}.".format(i, ranges[i]))
|
||||
if ranges[i][0] < 0:
|
||||
raise Exception("Values in the ranges passed to sub_blocks must "
|
||||
"be at least 0, but the {}th range is {}."
|
||||
.format(i, ranges[i]))
|
||||
"be at least 0, but the {}th range is {}.".format(
|
||||
i, ranges[i]))
|
||||
if ranges[i][-1] >= a.num_blocks[i]:
|
||||
raise Exception("Values in the ranges passed to sub_blocks must "
|
||||
"be less than the relevant number of blocks, but "
|
||||
|
@ -240,8 +240,8 @@ def subblocks(a, *ranges):
|
|||
for i in range(a.ndim)]
|
||||
result = DistArray(shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objectids[index] = a.objectids[tuple([ranges[i][index[i]]
|
||||
for i in range(a.ndim)])]
|
||||
result.objectids[index] = a.objectids[tuple(
|
||||
[ranges[i][index[i]] for i in range(a.ndim)])]
|
||||
return result
|
||||
|
||||
|
||||
|
@ -249,8 +249,8 @@ def subblocks(a, *ranges):
|
|||
def transpose(a):
|
||||
if a.ndim != 2:
|
||||
raise Exception("transpose expects its argument to be 2-dimensional, "
|
||||
"but a.ndim = {}, a.shape = {}.".format(a.ndim,
|
||||
a.shape))
|
||||
"but a.ndim = {}, a.shape = {}.".format(
|
||||
a.ndim, a.shape))
|
||||
result = DistArray([a.shape[1], a.shape[0]])
|
||||
for i in range(result.num_blocks[0]):
|
||||
for j in range(result.num_blocks[1]):
|
||||
|
@ -263,8 +263,8 @@ def transpose(a):
|
|||
def add(x1, x2):
|
||||
if x1.shape != x2.shape:
|
||||
raise Exception("add expects arguments `x1` and `x2` to have the same "
|
||||
"shape, but x1.shape = {}, and x2.shape = {}."
|
||||
.format(x1.shape, x2.shape))
|
||||
"shape, but x1.shape = {}, and x2.shape = {}.".format(
|
||||
x1.shape, x2.shape))
|
||||
result = DistArray(x1.shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objectids[index] = ra.add.remote(x1.objectids[index],
|
||||
|
|
|
@ -76,9 +76,10 @@ def tsqr(a):
|
|||
lower = [a.shape[1], 0]
|
||||
upper = [2 * a.shape[1], core.BLOCK_SIZE]
|
||||
ith_index //= 2
|
||||
q_block_current = ra.dot.remote(
|
||||
q_block_current, ra.subarray.remote(q_tree[ith_index, j],
|
||||
lower, upper))
|
||||
q_block_current = ra.dot.remote(q_block_current,
|
||||
ra.subarray.remote(
|
||||
q_tree[ith_index, j], lower,
|
||||
upper))
|
||||
q_result.objectids[i] = q_block_current
|
||||
r = current_rs[0]
|
||||
return q_result, ray.get(r)
|
||||
|
@ -196,8 +197,8 @@ def qr(a):
|
|||
if a.shape[0] > a.shape[1]:
|
||||
# in this case, R needs to be square
|
||||
R_shape = ray.get(ra.shape.remote(R))
|
||||
eye_temp = ra.eye.remote(R_shape[1], R_shape[0],
|
||||
dtype_name=result_dtype)
|
||||
eye_temp = ra.eye.remote(
|
||||
R_shape[1], R_shape[0], dtype_name=result_dtype)
|
||||
r_res.objectids[i, i] = ra.dot.remote(eye_temp, R)
|
||||
else:
|
||||
r_res.objectids[i, i] = R
|
||||
|
@ -220,10 +221,11 @@ def qr(a):
|
|||
for i in range(len(Ts))[::-1]:
|
||||
y_col_block = core.subblocks.remote(y_res, [], [i])
|
||||
q = core.subtract.remote(
|
||||
q, core.dot.remote(
|
||||
y_col_block,
|
||||
q,
|
||||
core.dot.remote(y_col_block,
|
||||
core.dot.remote(
|
||||
Ts[i],
|
||||
core.dot.remote(core.transpose.remote(y_col_block), q))))
|
||||
core.dot.remote(
|
||||
core.transpose.remote(y_col_block), q))))
|
||||
|
||||
return ray.get(q), r_res
|
||||
|
|
|
@ -8,6 +8,8 @@ from .core import (zeros, zeros_like, ones, eye, dot, vstack, hstack, subarray,
|
|||
copy, tril, triu, diag, transpose, add, subtract, sum,
|
||||
shape, sum_list)
|
||||
|
||||
__all__ = ["random", "linalg", "zeros", "zeros_like", "ones", "eye", "dot",
|
||||
"vstack", "hstack", "subarray", "copy", "tril", "triu", "diag",
|
||||
"transpose", "add", "subtract", "sum", "shape", "sum_list"]
|
||||
__all__ = [
|
||||
"random", "linalg", "zeros", "zeros_like", "ones", "eye", "dot", "vstack",
|
||||
"hstack", "subarray", "copy", "tril", "triu", "diag", "transpose", "add",
|
||||
"subtract", "sum", "shape", "sum_list"
|
||||
]
|
||||
|
|
|
@ -5,10 +5,11 @@ from __future__ import print_function
|
|||
import numpy as np
|
||||
import ray
|
||||
|
||||
__all__ = ["matrix_power", "solve", "tensorsolve", "tensorinv", "inv",
|
||||
"cholesky", "eigvals", "eigvalsh", "pinv", "slogdet", "det",
|
||||
"svd", "eig", "eigh", "lstsq", "norm", "qr", "cond", "matrix_rank",
|
||||
"multi_dot"]
|
||||
__all__ = [
|
||||
"matrix_power", "solve", "tensorsolve", "tensorinv", "inv", "cholesky",
|
||||
"eigvals", "eigvalsh", "pinv", "slogdet", "det", "svd", "eig", "eigh",
|
||||
"lstsq", "norm", "qr", "cond", "matrix_rank", "multi_dot"
|
||||
]
|
||||
|
||||
|
||||
@ray.remote
|
||||
|
|
|
@ -59,6 +59,7 @@ class GlobalState(object):
|
|||
Attributes:
|
||||
redis_client: The redis client used to query the redis server.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Create a GlobalState object."""
|
||||
# The redis server storing metadata, such as function table, client
|
||||
|
@ -82,7 +83,9 @@ class GlobalState(object):
|
|||
raise Exception("The ray.global_state API cannot be used before "
|
||||
"ray.init has been called.")
|
||||
|
||||
def _initialize_global_state(self, redis_ip_address, redis_port,
|
||||
def _initialize_global_state(self,
|
||||
redis_ip_address,
|
||||
redis_port,
|
||||
timeout=20):
|
||||
"""Initialize the GlobalState object by connecting to Redis.
|
||||
|
||||
|
@ -97,8 +100,8 @@ class GlobalState(object):
|
|||
timeout: The maximum amount of time (in seconds) that we should
|
||||
wait for the keys in Redis to be populated.
|
||||
"""
|
||||
self.redis_client = redis.StrictRedis(host=redis_ip_address,
|
||||
port=redis_port)
|
||||
self.redis_client = redis.StrictRedis(
|
||||
host=redis_ip_address, port=redis_port)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
|
@ -118,8 +121,8 @@ class GlobalState(object):
|
|||
"{}.".format(num_redis_shards))
|
||||
|
||||
# Attempt to get all of the Redis shards.
|
||||
ip_address_ports = self.redis_client.lrange("RedisShards", start=0,
|
||||
end=-1)
|
||||
ip_address_ports = self.redis_client.lrange(
|
||||
"RedisShards", start=0, end=-1)
|
||||
if len(ip_address_ports) != num_redis_shards:
|
||||
print("Waiting longer for RedisShards to be populated.")
|
||||
time.sleep(1)
|
||||
|
@ -132,15 +135,15 @@ class GlobalState(object):
|
|||
if time.time() - start_time >= timeout:
|
||||
raise Exception("Timed out while attempting to initialize the "
|
||||
"global state. num_redis_shards = {}, "
|
||||
"ip_address_ports = {}"
|
||||
.format(num_redis_shards, ip_address_ports))
|
||||
"ip_address_ports = {}".format(
|
||||
num_redis_shards, ip_address_ports))
|
||||
|
||||
# Get the rest of the information.
|
||||
self.redis_clients = []
|
||||
for ip_address_port in ip_address_ports:
|
||||
shard_address, shard_port = ip_address_port.split(b":")
|
||||
self.redis_clients.append(redis.StrictRedis(host=shard_address,
|
||||
port=shard_port))
|
||||
self.redis_clients.append(
|
||||
redis.StrictRedis(host=shard_address, port=shard_port))
|
||||
|
||||
def _execute_command(self, key, *args):
|
||||
"""Execute a Redis command on the appropriate Redis shard based on key.
|
||||
|
@ -152,8 +155,8 @@ class GlobalState(object):
|
|||
Returns:
|
||||
The value returned by the Redis command.
|
||||
"""
|
||||
client = self.redis_clients[key.redis_shard_hash() %
|
||||
len(self.redis_clients)]
|
||||
client = self.redis_clients[key.redis_shard_hash() % len(
|
||||
self.redis_clients)]
|
||||
return client.execute_command(*args)
|
||||
|
||||
def _keys(self, pattern):
|
||||
|
@ -189,8 +192,9 @@ class GlobalState(object):
|
|||
"RAY.OBJECT_TABLE_LOOKUP",
|
||||
object_id.id())
|
||||
if object_locations is not None:
|
||||
manager_ids = [binary_to_hex(manager_id)
|
||||
for manager_id in object_locations]
|
||||
manager_ids = [
|
||||
binary_to_hex(manager_id) for manager_id in object_locations
|
||||
]
|
||||
else:
|
||||
manager_ids = None
|
||||
|
||||
|
@ -199,11 +203,13 @@ class GlobalState(object):
|
|||
result_table_message = ResultTableReply.GetRootAsResultTableReply(
|
||||
result_table_response, 0)
|
||||
|
||||
result = {"ManagerIDs": manager_ids,
|
||||
result = {
|
||||
"ManagerIDs": manager_ids,
|
||||
"TaskID": binary_to_hex(result_table_message.TaskId()),
|
||||
"IsPut": bool(result_table_message.IsPut()),
|
||||
"DataSize": result_table_message.DataSize(),
|
||||
"Hash": binary_to_hex(result_table_message.Hash())}
|
||||
"Hash": binary_to_hex(result_table_message.Hash())
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
@ -227,9 +233,10 @@ class GlobalState(object):
|
|||
object_info_keys = self._keys(OBJECT_INFO_PREFIX + "*")
|
||||
object_location_keys = self._keys(OBJECT_LOCATION_PREFIX + "*")
|
||||
object_ids_binary = set(
|
||||
[key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys] +
|
||||
[key[len(OBJECT_LOCATION_PREFIX):]
|
||||
for key in object_location_keys])
|
||||
[key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys] + [
|
||||
key[len(OBJECT_LOCATION_PREFIX):]
|
||||
for key in object_location_keys
|
||||
])
|
||||
results = {}
|
||||
for object_id_binary in object_ids_binary:
|
||||
results[binary_to_object_id(object_id_binary)] = (
|
||||
|
@ -254,26 +261,37 @@ class GlobalState(object):
|
|||
if task_table_response is None:
|
||||
raise Exception("There is no entry for task ID {} in the task "
|
||||
"table.".format(binary_to_hex(task_id.id())))
|
||||
task_table_message = TaskReply.GetRootAsTaskReply(task_table_response,
|
||||
0)
|
||||
task_table_message = TaskReply.GetRootAsTaskReply(
|
||||
task_table_response, 0)
|
||||
task_spec = task_table_message.TaskSpec()
|
||||
task_spec = ray.local_scheduler.task_from_string(task_spec)
|
||||
|
||||
task_spec_info = {
|
||||
"DriverID": binary_to_hex(task_spec.driver_id().id()),
|
||||
"TaskID": binary_to_hex(task_spec.task_id().id()),
|
||||
"ParentTaskID": binary_to_hex(task_spec.parent_task_id().id()),
|
||||
"ParentCounter": task_spec.parent_counter(),
|
||||
"ActorID": binary_to_hex(task_spec.actor_id().id()),
|
||||
"DriverID":
|
||||
binary_to_hex(task_spec.driver_id().id()),
|
||||
"TaskID":
|
||||
binary_to_hex(task_spec.task_id().id()),
|
||||
"ParentTaskID":
|
||||
binary_to_hex(task_spec.parent_task_id().id()),
|
||||
"ParentCounter":
|
||||
task_spec.parent_counter(),
|
||||
"ActorID":
|
||||
binary_to_hex(task_spec.actor_id().id()),
|
||||
"ActorCreationID":
|
||||
binary_to_hex(task_spec.actor_creation_id().id()),
|
||||
"ActorCreationDummyObjectID":
|
||||
binary_to_hex(task_spec.actor_creation_dummy_object_id().id()),
|
||||
"ActorCounter": task_spec.actor_counter(),
|
||||
"FunctionID": binary_to_hex(task_spec.function_id().id()),
|
||||
"Args": task_spec.arguments(),
|
||||
"ReturnObjectIDs": task_spec.returns(),
|
||||
"RequiredResources": task_spec.required_resources()}
|
||||
"ActorCounter":
|
||||
task_spec.actor_counter(),
|
||||
"FunctionID":
|
||||
binary_to_hex(task_spec.function_id().id()),
|
||||
"Args":
|
||||
task_spec.arguments(),
|
||||
"ReturnObjectIDs":
|
||||
task_spec.returns(),
|
||||
"RequiredResources":
|
||||
task_spec.required_resources()
|
||||
}
|
||||
|
||||
execution_dependencies_message = (
|
||||
TaskExecutionDependencies.GetRootAsTaskExecutionDependencies(
|
||||
|
@ -282,21 +300,27 @@ class GlobalState(object):
|
|||
ray.local_scheduler.ObjectID(
|
||||
execution_dependencies_message.ExecutionDependencies(i))
|
||||
for i in range(
|
||||
execution_dependencies_message.ExecutionDependenciesLength())]
|
||||
execution_dependencies_message.ExecutionDependenciesLength())
|
||||
]
|
||||
|
||||
# TODO(rkn): The return fields ExecutionDependenciesString and
|
||||
# ExecutionDependencies are redundant, so we should remove
|
||||
# ExecutionDependencies. However, it is currently used in monitor.py.
|
||||
|
||||
return {"State": task_table_message.State(),
|
||||
"LocalSchedulerID": binary_to_hex(
|
||||
task_table_message.LocalSchedulerId()),
|
||||
return {
|
||||
"State":
|
||||
task_table_message.State(),
|
||||
"LocalSchedulerID":
|
||||
binary_to_hex(task_table_message.LocalSchedulerId()),
|
||||
"ExecutionDependenciesString":
|
||||
task_table_message.ExecutionDependencies(),
|
||||
"ExecutionDependencies": execution_dependencies,
|
||||
"ExecutionDependencies":
|
||||
execution_dependencies,
|
||||
"SpillbackCount":
|
||||
task_table_message.SpillbackCount(),
|
||||
"TaskSpec": task_spec_info}
|
||||
"TaskSpec":
|
||||
task_spec_info
|
||||
}
|
||||
|
||||
def task_table(self, task_id=None):
|
||||
"""Fetch and parse the task table information for one or more task IDs.
|
||||
|
@ -337,7 +361,8 @@ class GlobalState(object):
|
|||
function_info_parsed = {
|
||||
"DriverID": binary_to_hex(info[b"driver_id"]),
|
||||
"Module": decode(info[b"module"]),
|
||||
"Name": decode(info[b"name"])}
|
||||
"Name": decode(info[b"name"])
|
||||
}
|
||||
results[binary_to_hex(info[b"function_id"])] = function_info_parsed
|
||||
return results
|
||||
|
||||
|
@ -469,21 +494,17 @@ class GlobalState(object):
|
|||
if start is None and end is None:
|
||||
if fwd:
|
||||
event_list = self.redis_client.zrange(
|
||||
event_log_set,
|
||||
**params)
|
||||
event_log_set, **params)
|
||||
else:
|
||||
event_list = self.redis_client.zrevrange(
|
||||
event_log_set,
|
||||
**params)
|
||||
event_log_set, **params)
|
||||
else:
|
||||
if fwd:
|
||||
event_list = self.redis_client.zrangebyscore(
|
||||
event_log_set,
|
||||
**params)
|
||||
event_log_set, **params)
|
||||
else:
|
||||
event_list = self.redis_client.zrevrangebyscore(
|
||||
event_log_set,
|
||||
**params)
|
||||
event_log_set, **params)
|
||||
|
||||
for (event, score) in event_list:
|
||||
event_dict = json.loads(event.decode())
|
||||
|
@ -503,11 +524,11 @@ class GlobalState(object):
|
|||
task_info[task_id]["get_task_start"] = event[0]
|
||||
if event[1] == "ray:get_task" and event[2] == 2:
|
||||
task_info[task_id]["get_task_end"] = event[0]
|
||||
if (event[1] == "ray:import_remote_function" and
|
||||
event[2] == 1):
|
||||
if (event[1] == "ray:import_remote_function"
|
||||
and event[2] == 1):
|
||||
task_info[task_id]["import_remote_start"] = event[0]
|
||||
if (event[1] == "ray:import_remote_function" and
|
||||
event[2] == 2):
|
||||
if (event[1] == "ray:import_remote_function"
|
||||
and event[2] == 2):
|
||||
task_info[task_id]["import_remote_end"] = event[0]
|
||||
if event[1] == "ray:acquire_lock" and event[2] == 1:
|
||||
task_info[task_id]["acquire_lock_start"] = event[0]
|
||||
|
@ -547,7 +568,6 @@ class GlobalState(object):
|
|||
breakdowns=True,
|
||||
task_dep=True,
|
||||
obj_dep=True):
|
||||
|
||||
"""Dump task profiling information to a file.
|
||||
|
||||
This information can be viewed as a timeline of profiling information
|
||||
|
@ -604,18 +624,20 @@ class GlobalState(object):
|
|||
# modify it in place since we will use the original values later.
|
||||
total_info = copy.copy(task_table[task_id]["TaskSpec"])
|
||||
total_info["Args"] = [
|
||||
oid.hex() if isinstance(oid, ray.local_scheduler.ObjectID)
|
||||
else oid for oid in task_t_info["TaskSpec"]["Args"]]
|
||||
oid.hex()
|
||||
if isinstance(oid, ray.local_scheduler.ObjectID) else oid
|
||||
for oid in task_t_info["TaskSpec"]["Args"]
|
||||
]
|
||||
total_info["ReturnObjectIDs"] = [
|
||||
oid.hex() for oid
|
||||
in task_t_info["TaskSpec"]["ReturnObjectIDs"]]
|
||||
oid.hex() for oid in task_t_info["TaskSpec"]["ReturnObjectIDs"]
|
||||
]
|
||||
total_info["LocalSchedulerID"] = task_t_info["LocalSchedulerID"]
|
||||
total_info["get_arguments"] = (info["get_arguments_end"] -
|
||||
info["get_arguments_start"])
|
||||
total_info["execute"] = (info["execute_end"] -
|
||||
info["execute_start"])
|
||||
total_info["store_outputs"] = (info["store_outputs_end"] -
|
||||
info["store_outputs_start"])
|
||||
total_info["get_arguments"] = (
|
||||
info["get_arguments_end"] - info["get_arguments_start"])
|
||||
total_info["execute"] = (
|
||||
info["execute_end"] - info["execute_start"])
|
||||
total_info["store_outputs"] = (
|
||||
info["store_outputs_end"] - info["store_outputs_start"])
|
||||
total_info["function_name"] = info["function_name"]
|
||||
total_info["worker_id"] = info["worker_id"]
|
||||
|
||||
|
@ -627,49 +649,78 @@ class GlobalState(object):
|
|||
if breakdowns:
|
||||
if "get_arguments_end" in info:
|
||||
get_args_trace = {
|
||||
"cat": "get_arguments",
|
||||
"pid": "Node " + worker["node_ip_address"],
|
||||
"tid": info["worker_id"],
|
||||
"id": task_id,
|
||||
"ts": micros_rel(info["get_arguments_start"]),
|
||||
"ph": "X",
|
||||
"name": info["function_name"] + ":get_arguments",
|
||||
"args": total_info,
|
||||
"dur": micros(info["get_arguments_end"] -
|
||||
"cat":
|
||||
"get_arguments",
|
||||
"pid":
|
||||
"Node " + worker["node_ip_address"],
|
||||
"tid":
|
||||
info["worker_id"],
|
||||
"id":
|
||||
task_id,
|
||||
"ts":
|
||||
micros_rel(info["get_arguments_start"]),
|
||||
"ph":
|
||||
"X",
|
||||
"name":
|
||||
info["function_name"] + ":get_arguments",
|
||||
"args":
|
||||
total_info,
|
||||
"dur":
|
||||
micros(info["get_arguments_end"] -
|
||||
info["get_arguments_start"]),
|
||||
"cname": "rail_idle"
|
||||
"cname":
|
||||
"rail_idle"
|
||||
}
|
||||
full_trace.append(get_args_trace)
|
||||
|
||||
if "store_outputs_end" in info:
|
||||
outputs_trace = {
|
||||
"cat": "store_outputs",
|
||||
"pid": "Node " + worker["node_ip_address"],
|
||||
"tid": info["worker_id"],
|
||||
"id": task_id,
|
||||
"ts": micros_rel(info["store_outputs_start"]),
|
||||
"ph": "X",
|
||||
"name": info["function_name"] + ":store_outputs",
|
||||
"args": total_info,
|
||||
"dur": micros(info["store_outputs_end"] -
|
||||
"cat":
|
||||
"store_outputs",
|
||||
"pid":
|
||||
"Node " + worker["node_ip_address"],
|
||||
"tid":
|
||||
info["worker_id"],
|
||||
"id":
|
||||
task_id,
|
||||
"ts":
|
||||
micros_rel(info["store_outputs_start"]),
|
||||
"ph":
|
||||
"X",
|
||||
"name":
|
||||
info["function_name"] + ":store_outputs",
|
||||
"args":
|
||||
total_info,
|
||||
"dur":
|
||||
micros(info["store_outputs_end"] -
|
||||
info["store_outputs_start"]),
|
||||
"cname": "thread_state_runnable"
|
||||
"cname":
|
||||
"thread_state_runnable"
|
||||
}
|
||||
full_trace.append(outputs_trace)
|
||||
|
||||
if "execute_end" in info:
|
||||
execute_trace = {
|
||||
"cat": "execute",
|
||||
"pid": "Node " + worker["node_ip_address"],
|
||||
"tid": info["worker_id"],
|
||||
"id": task_id,
|
||||
"ts": micros_rel(info["execute_start"]),
|
||||
"ph": "X",
|
||||
"name": info["function_name"] + ":execute",
|
||||
"args": total_info,
|
||||
"dur": micros(info["execute_end"] -
|
||||
info["execute_start"]),
|
||||
"cname": "rail_animation"
|
||||
"cat":
|
||||
"execute",
|
||||
"pid":
|
||||
"Node " + worker["node_ip_address"],
|
||||
"tid":
|
||||
info["worker_id"],
|
||||
"id":
|
||||
task_id,
|
||||
"ts":
|
||||
micros_rel(info["execute_start"]),
|
||||
"ph":
|
||||
"X",
|
||||
"name":
|
||||
info["function_name"] + ":execute",
|
||||
"args":
|
||||
total_info,
|
||||
"dur":
|
||||
micros(info["execute_end"] - info["execute_start"]),
|
||||
"cname":
|
||||
"rail_animation"
|
||||
}
|
||||
full_trace.append(execute_trace)
|
||||
|
||||
|
@ -680,15 +731,20 @@ class GlobalState(object):
|
|||
parent_profile = task_info.get(
|
||||
task_table[task_id]["TaskSpec"]["ParentTaskID"])
|
||||
parent = {
|
||||
"cat": "submit_task",
|
||||
"pid": "Node " + parent_worker["node_ip_address"],
|
||||
"tid": parent_info["worker_id"],
|
||||
"ts": micros_rel(
|
||||
parent_profile and
|
||||
parent_profile["get_arguments_start"] or
|
||||
start_time),
|
||||
"ph": "s",
|
||||
"name": "SubmitTask",
|
||||
"cat":
|
||||
"submit_task",
|
||||
"pid":
|
||||
"Node " + parent_worker["node_ip_address"],
|
||||
"tid":
|
||||
parent_info["worker_id"],
|
||||
"ts":
|
||||
micros_rel(parent_profile
|
||||
and parent_profile["get_arguments_start"]
|
||||
or start_time),
|
||||
"ph":
|
||||
"s",
|
||||
"name":
|
||||
"SubmitTask",
|
||||
"args": {},
|
||||
"id": (parent_info["worker_id"] +
|
||||
str(micros(min(parent_times))))
|
||||
|
@ -696,32 +752,50 @@ class GlobalState(object):
|
|||
full_trace.append(parent)
|
||||
|
||||
task_trace = {
|
||||
"cat": "submit_task",
|
||||
"pid": "Node " + worker["node_ip_address"],
|
||||
"tid": info["worker_id"],
|
||||
"ts": micros_rel(info["get_arguments_start"]),
|
||||
"ph": "f",
|
||||
"name": "SubmitTask",
|
||||
"cat":
|
||||
"submit_task",
|
||||
"pid":
|
||||
"Node " + worker["node_ip_address"],
|
||||
"tid":
|
||||
info["worker_id"],
|
||||
"ts":
|
||||
micros_rel(info["get_arguments_start"]),
|
||||
"ph":
|
||||
"f",
|
||||
"name":
|
||||
"SubmitTask",
|
||||
"args": {},
|
||||
"id": (info["worker_id"] +
|
||||
str(micros(min(parent_times)))),
|
||||
"bp": "e",
|
||||
"cname": "olive"
|
||||
"id":
|
||||
(info["worker_id"] + str(micros(min(parent_times)))),
|
||||
"bp":
|
||||
"e",
|
||||
"cname":
|
||||
"olive"
|
||||
}
|
||||
full_trace.append(task_trace)
|
||||
|
||||
task = {
|
||||
"cat": "task",
|
||||
"pid": "Node " + worker["node_ip_address"],
|
||||
"tid": info["worker_id"],
|
||||
"id": task_id,
|
||||
"ts": micros_rel(info["get_arguments_start"]),
|
||||
"ph": "X",
|
||||
"name": info["function_name"],
|
||||
"args": total_info,
|
||||
"dur": micros(info["store_outputs_end"] -
|
||||
"cat":
|
||||
"task",
|
||||
"pid":
|
||||
"Node " + worker["node_ip_address"],
|
||||
"tid":
|
||||
info["worker_id"],
|
||||
"id":
|
||||
task_id,
|
||||
"ts":
|
||||
micros_rel(info["get_arguments_start"]),
|
||||
"ph":
|
||||
"X",
|
||||
"name":
|
||||
info["function_name"],
|
||||
"args":
|
||||
total_info,
|
||||
"dur":
|
||||
micros(info["store_outputs_end"] -
|
||||
info["get_arguments_start"]),
|
||||
"cname": "thread_state_runnable"
|
||||
"cname":
|
||||
"thread_state_runnable"
|
||||
}
|
||||
full_trace.append(task)
|
||||
|
||||
|
@ -732,15 +806,20 @@ class GlobalState(object):
|
|||
parent_profile = task_info.get(
|
||||
task_table[task_id]["TaskSpec"]["ParentTaskID"])
|
||||
parent = {
|
||||
"cat": "submit_task",
|
||||
"pid": "Node " + parent_worker["node_ip_address"],
|
||||
"tid": parent_info["worker_id"],
|
||||
"ts": micros_rel(
|
||||
parent_profile and
|
||||
parent_profile["get_arguments_start"] or
|
||||
start_time),
|
||||
"ph": "s",
|
||||
"name": "SubmitTask",
|
||||
"cat":
|
||||
"submit_task",
|
||||
"pid":
|
||||
"Node " + parent_worker["node_ip_address"],
|
||||
"tid":
|
||||
parent_info["worker_id"],
|
||||
"ts":
|
||||
micros_rel(parent_profile
|
||||
and parent_profile["get_arguments_start"]
|
||||
or start_time),
|
||||
"ph":
|
||||
"s",
|
||||
"name":
|
||||
"SubmitTask",
|
||||
"args": {},
|
||||
"id": (parent_info["worker_id"] +
|
||||
str(micros(min(parent_times))))
|
||||
|
@ -748,16 +827,23 @@ class GlobalState(object):
|
|||
full_trace.append(parent)
|
||||
|
||||
task_trace = {
|
||||
"cat": "submit_task",
|
||||
"pid": "Node " + worker["node_ip_address"],
|
||||
"tid": info["worker_id"],
|
||||
"ts": micros_rel(info["get_arguments_start"]),
|
||||
"ph": "f",
|
||||
"name": "SubmitTask",
|
||||
"cat":
|
||||
"submit_task",
|
||||
"pid":
|
||||
"Node " + worker["node_ip_address"],
|
||||
"tid":
|
||||
info["worker_id"],
|
||||
"ts":
|
||||
micros_rel(info["get_arguments_start"]),
|
||||
"ph":
|
||||
"f",
|
||||
"name":
|
||||
"SubmitTask",
|
||||
"args": {},
|
||||
"id": (info["worker_id"] +
|
||||
str(micros(min(parent_times)))),
|
||||
"bp": "e"
|
||||
"id":
|
||||
(info["worker_id"] + str(micros(min(parent_times)))),
|
||||
"bp":
|
||||
"e"
|
||||
}
|
||||
full_trace.append(task_trace)
|
||||
|
||||
|
@ -775,8 +861,8 @@ class GlobalState(object):
|
|||
seen_obj[arg] += 1
|
||||
owner_task = self._object_table(arg)["TaskID"]
|
||||
if owner_task in task_info:
|
||||
owner_worker = (workers[
|
||||
task_info[owner_task]["worker_id"]])
|
||||
owner_worker = (workers[task_info[owner_task][
|
||||
"worker_id"]])
|
||||
# Adding/subtracting 2 to the time associated
|
||||
# with the beginning/ending of the flow event
|
||||
# is necessary to make the flow events show up
|
||||
|
@ -790,18 +876,26 @@ class GlobalState(object):
|
|||
# duration event that it's associated with, and
|
||||
# the flow event therefore always gets drawn.
|
||||
owner = {
|
||||
"cat": "obj_dependency",
|
||||
"cat":
|
||||
"obj_dependency",
|
||||
"pid": ("Node " +
|
||||
owner_worker["node_ip_address"]),
|
||||
"tid": task_info[owner_task]["worker_id"],
|
||||
"ts": micros_rel(task_info[
|
||||
owner_task]["store_outputs_end"]) - 2,
|
||||
"ph": "s",
|
||||
"name": "ObjectDependency",
|
||||
"tid":
|
||||
task_info[owner_task]["worker_id"],
|
||||
"ts":
|
||||
micros_rel(task_info[owner_task]
|
||||
["store_outputs_end"]) - 2,
|
||||
"ph":
|
||||
"s",
|
||||
"name":
|
||||
"ObjectDependency",
|
||||
"args": {},
|
||||
"bp": "e",
|
||||
"cname": "cq_build_attempt_failed",
|
||||
"id": "obj" + str(arg) + str(seen_obj[arg])
|
||||
"bp":
|
||||
"e",
|
||||
"cname":
|
||||
"cq_build_attempt_failed",
|
||||
"id":
|
||||
"obj" + str(arg) + str(seen_obj[arg])
|
||||
}
|
||||
full_trace.append(owner)
|
||||
|
||||
|
@ -809,8 +903,8 @@ class GlobalState(object):
|
|||
"cat": "obj_dependency",
|
||||
"pid": "Node " + worker["node_ip_address"],
|
||||
"tid": info["worker_id"],
|
||||
"ts": micros_rel(
|
||||
info["get_arguments_start"]) + 2,
|
||||
"ts":
|
||||
micros_rel(info["get_arguments_start"]) + 2,
|
||||
"ph": "f",
|
||||
"name": "ObjectDependency",
|
||||
"args": {},
|
||||
|
@ -852,14 +946,10 @@ class GlobalState(object):
|
|||
"""
|
||||
|
||||
keys = [
|
||||
"acquire_lock_start",
|
||||
"acquire_lock_end",
|
||||
"get_arguments_start",
|
||||
"get_arguments_end",
|
||||
"execute_start",
|
||||
"execute_end",
|
||||
"store_outputs_start",
|
||||
"store_outputs_end"]
|
||||
"acquire_lock_start", "acquire_lock_end", "get_arguments_start",
|
||||
"get_arguments_end", "execute_start", "execute_end",
|
||||
"store_outputs_start", "store_outputs_end"
|
||||
]
|
||||
|
||||
latest_timestamp = 0
|
||||
for key in keys:
|
||||
|
@ -877,8 +967,8 @@ class GlobalState(object):
|
|||
local_schedulers = []
|
||||
for ip_address, client_list in clients.items():
|
||||
for client in client_list:
|
||||
if (client["ClientType"] == "local_scheduler" and
|
||||
not client["Deleted"]):
|
||||
if (client["ClientType"] == "local_scheduler"
|
||||
and not client["Deleted"]):
|
||||
local_schedulers.append(client)
|
||||
return local_schedulers
|
||||
|
||||
|
@ -893,8 +983,7 @@ class GlobalState(object):
|
|||
|
||||
workers_data[worker_id] = {
|
||||
"local_scheduler_socket":
|
||||
(worker_info[b"local_scheduler_socket"]
|
||||
.decode("ascii")),
|
||||
(worker_info[b"local_scheduler_socket"].decode("ascii")),
|
||||
"node_ip_address": (worker_info[b"node_ip_address"]
|
||||
.decode("ascii")),
|
||||
"plasma_manager_socket": (worker_info[b"plasma_manager_socket"]
|
||||
|
@ -923,7 +1012,8 @@ class GlobalState(object):
|
|||
"local_scheduler_id":
|
||||
binary_to_hex(info[b"local_scheduler_id"]),
|
||||
"num_gpus": int(info[b"num_gpus"]),
|
||||
"removed": decode(info[b"removed"]) == "True"}
|
||||
"removed": decode(info[b"removed"]) == "True"
|
||||
}
|
||||
return actor_info
|
||||
|
||||
def _job_length(self):
|
||||
|
@ -932,21 +1022,16 @@ class GlobalState(object):
|
|||
overall_largest = 0
|
||||
num_tasks = 0
|
||||
for event_log_set in event_log_sets:
|
||||
fwd_range = self.redis_client.zrange(event_log_set,
|
||||
start=0,
|
||||
end=0,
|
||||
withscores=True)
|
||||
fwd_range = self.redis_client.zrange(
|
||||
event_log_set, start=0, end=0, withscores=True)
|
||||
overall_smallest = min(overall_smallest, fwd_range[0][1])
|
||||
|
||||
rev_range = self.redis_client.zrevrange(event_log_set,
|
||||
start=0,
|
||||
end=0,
|
||||
withscores=True)
|
||||
rev_range = self.redis_client.zrevrange(
|
||||
event_log_set, start=0, end=0, withscores=True)
|
||||
overall_largest = max(overall_largest, rev_range[0][1])
|
||||
|
||||
num_tasks += self.redis_client.zcount(event_log_set,
|
||||
min=0,
|
||||
max=time.time())
|
||||
num_tasks += self.redis_client.zcount(
|
||||
event_log_set, min=0, max=time.time())
|
||||
if num_tasks is 0:
|
||||
return 0, 0, 0
|
||||
return overall_smallest, overall_largest, num_tasks
|
||||
|
@ -966,8 +1051,10 @@ class GlobalState(object):
|
|||
|
||||
for local_scheduler in local_schedulers:
|
||||
for key, value in local_scheduler.items():
|
||||
if key not in ["ClientType", "Deleted", "DBClientID",
|
||||
"AuxAddress", "LocalSchedulerSocketName"]:
|
||||
if key not in [
|
||||
"ClientType", "Deleted", "DBClientID", "AuxAddress",
|
||||
"LocalSchedulerSocketName"
|
||||
]:
|
||||
resources[key] += value
|
||||
|
||||
return dict(resources)
|
||||
|
|
|
@ -27,6 +27,7 @@ class TensorFlowVariables(object):
|
|||
placeholders (Dict[str, tf.placeholders]): Placeholders for weights.
|
||||
assignment_nodes (Dict[str, tf.Tensor]): Nodes that assign weights.
|
||||
"""
|
||||
|
||||
def __init__(self, loss, sess=None, input_variables=None):
|
||||
"""Creates TensorFlowVariables containing extracted variables.
|
||||
|
||||
|
@ -74,8 +75,10 @@ class TensorFlowVariables(object):
|
|||
if "Variable" in tf_obj.node_def.op:
|
||||
variable_names.append(tf_obj.node_def.name)
|
||||
self.variables = OrderedDict()
|
||||
variable_list = [v for v in tf.global_variables()
|
||||
if v.op.node_def.name in variable_names]
|
||||
variable_list = [
|
||||
v for v in tf.global_variables()
|
||||
if v.op.node_def.name in variable_names
|
||||
]
|
||||
if input_variables is not None:
|
||||
variable_list += input_variables
|
||||
for v in variable_list:
|
||||
|
@ -86,7 +89,8 @@ class TensorFlowVariables(object):
|
|||
|
||||
# Create new placeholders to put in custom weights.
|
||||
for k, var in self.variables.items():
|
||||
self.placeholders[k] = tf.placeholder(var.value().dtype,
|
||||
self.placeholders[k] = tf.placeholder(
|
||||
var.value().dtype,
|
||||
var.get_shape().as_list(),
|
||||
name="Placeholder_" + k)
|
||||
self.assignment_nodes[k] = var.assign(self.placeholders[k])
|
||||
|
@ -105,8 +109,9 @@ class TensorFlowVariables(object):
|
|||
Returns:
|
||||
The length of all flattened variables concatenated.
|
||||
"""
|
||||
return sum([np.prod(v.get_shape().as_list())
|
||||
for v in self.variables.values()])
|
||||
return sum([
|
||||
np.prod(v.get_shape().as_list()) for v in self.variables.values()
|
||||
])
|
||||
|
||||
def _check_sess(self):
|
||||
"""Checks if the session is set, and if not throw an error message."""
|
||||
|
@ -122,8 +127,10 @@ class TensorFlowVariables(object):
|
|||
1D Array containing the flattened weights.
|
||||
"""
|
||||
self._check_sess()
|
||||
return np.concatenate([v.eval(session=self.sess).flatten()
|
||||
for v in self.variables.values()])
|
||||
return np.concatenate([
|
||||
v.eval(session=self.sess).flatten()
|
||||
for v in self.variables.values()
|
||||
])
|
||||
|
||||
def set_flat(self, new_weights):
|
||||
"""Sets the weights to new_weights, converting from a flat array.
|
||||
|
@ -138,9 +145,11 @@ class TensorFlowVariables(object):
|
|||
self._check_sess()
|
||||
shapes = [v.get_shape().as_list() for v in self.variables.values()]
|
||||
arrays = unflatten(new_weights, shapes)
|
||||
placeholders = [self.placeholders[k] for k, v
|
||||
in self.variables.items()]
|
||||
self.sess.run(list(self.assignment_nodes.values()),
|
||||
placeholders = [
|
||||
self.placeholders[k] for k, v in self.variables.items()
|
||||
]
|
||||
self.sess.run(
|
||||
list(self.assignment_nodes.values()),
|
||||
feed_dict=dict(zip(placeholders, arrays)))
|
||||
|
||||
def get_weights(self):
|
||||
|
@ -150,8 +159,10 @@ class TensorFlowVariables(object):
|
|||
Dictionary mapping variable names to their weights.
|
||||
"""
|
||||
self._check_sess()
|
||||
return {k: v.eval(session=self.sess) for k, v
|
||||
in self.variables.items()}
|
||||
return {
|
||||
k: v.eval(session=self.sess)
|
||||
for k, v in self.variables.items()
|
||||
}
|
||||
|
||||
def set_weights(self, new_weights):
|
||||
"""Sets the weights to new_weights.
|
||||
|
@ -165,15 +176,19 @@ class TensorFlowVariables(object):
|
|||
weights.
|
||||
"""
|
||||
self._check_sess()
|
||||
assign_list = [self.assignment_nodes[name]
|
||||
for name in new_weights.keys()
|
||||
if name in self.assignment_nodes]
|
||||
assign_list = [
|
||||
self.assignment_nodes[name] for name in new_weights.keys()
|
||||
if name in self.assignment_nodes
|
||||
]
|
||||
assert assign_list, ("No variables in the input matched those in the "
|
||||
"network. Possible cause: Two networks were "
|
||||
"defined in the same TensorFlow graph. To fix "
|
||||
"this, place each network definition in its own "
|
||||
"tf.Graph.")
|
||||
self.sess.run(assign_list,
|
||||
feed_dict={self.placeholders[name]: value
|
||||
self.sess.run(
|
||||
assign_list,
|
||||
feed_dict={
|
||||
self.placeholders[name]: value
|
||||
for (name, value) in new_weights.items()
|
||||
if name in self.placeholders})
|
||||
if name in self.placeholders
|
||||
})
|
||||
|
|
|
@ -29,9 +29,9 @@ class _EventRecursionContextManager(object):
|
|||
total_time_value = "% total time"
|
||||
total_tasks_value = "% total tasks"
|
||||
|
||||
|
||||
# Function that returns instances of sliders and handles associated events.
|
||||
|
||||
|
||||
def get_sliders(update):
|
||||
# Start_box value indicates the desired start point of queried window.
|
||||
start_box = widgets.FloatText(
|
||||
|
@ -60,18 +60,14 @@ def get_sliders(update):
|
|||
|
||||
# Indicates the number of tasks that the user wants to be returned. Is
|
||||
# disabled when the breakdown_opt value is set to total_time_value.
|
||||
num_tasks_box = widgets.IntText(
|
||||
description="Num Tasks:",
|
||||
disabled=False
|
||||
)
|
||||
num_tasks_box = widgets.IntText(description="Num Tasks:", disabled=False)
|
||||
|
||||
# Dropdown bar that lets the user choose between modifying % of total
|
||||
# time or total number of tasks.
|
||||
breakdown_opt = widgets.Dropdown(
|
||||
options=[total_time_value, total_tasks_value],
|
||||
value=total_tasks_value,
|
||||
description="Selection Options:"
|
||||
)
|
||||
description="Selection Options:")
|
||||
|
||||
# Display box for layout.
|
||||
total_time_box = widgets.VBox([start_box, end_box])
|
||||
|
@ -105,9 +101,9 @@ def get_sliders(update):
|
|||
if event == INIT_EVENT:
|
||||
if breakdown_opt.value == total_tasks_value:
|
||||
num_tasks_box.value = -min(10000, num_tasks)
|
||||
range_slider.value = (int(100 -
|
||||
(100. * -num_tasks_box.value) /
|
||||
num_tasks), 100)
|
||||
range_slider.value = (int(
|
||||
100 - (100. * -num_tasks_box.value) / num_tasks),
|
||||
100)
|
||||
else:
|
||||
low, high = map(lambda x: x / 100., range_slider.value)
|
||||
start_box.value = round(diff * low, 2)
|
||||
|
@ -120,8 +116,8 @@ def get_sliders(update):
|
|||
elif start_box.value < 0:
|
||||
start_box.value = 0
|
||||
low, high = range_slider.value
|
||||
range_slider.value = (int((start_box.value * 100.) /
|
||||
diff), high)
|
||||
range_slider.value = (int((start_box.value * 100.) / diff),
|
||||
high)
|
||||
|
||||
# Event was triggered by a change in the end_box value.
|
||||
elif event["owner"] == end_box:
|
||||
|
@ -130,8 +126,8 @@ def get_sliders(update):
|
|||
elif end_box.value > diff:
|
||||
end_box.value = diff
|
||||
low, high = range_slider.value
|
||||
range_slider.value = (low, int((end_box.value * 100.) /
|
||||
diff))
|
||||
range_slider.value = (low,
|
||||
int((end_box.value * 100.) / diff))
|
||||
|
||||
# Event was triggered by a change in the breakdown options
|
||||
# toggle.
|
||||
|
@ -145,9 +141,9 @@ def get_sliders(update):
|
|||
# Make CSS display go back to the default settings.
|
||||
num_tasks_box.layout.display = None
|
||||
num_tasks_box.value = min(10000, num_tasks)
|
||||
range_slider.value = (int(100 -
|
||||
(100. * num_tasks_box.value) /
|
||||
num_tasks), 100)
|
||||
range_slider.value = (int(
|
||||
100 - (100. * num_tasks_box.value) / num_tasks),
|
||||
100)
|
||||
else:
|
||||
start_box.disabled = False
|
||||
end_box.disabled = False
|
||||
|
@ -156,10 +152,9 @@ def get_sliders(update):
|
|||
# Make CSS display go back to the default settings.
|
||||
total_time_box.layout.display = None
|
||||
num_tasks_box.layout.display = 'none'
|
||||
range_slider.value = (int((start_box.value * 100.) /
|
||||
diff),
|
||||
int((end_box.value * 100.) /
|
||||
diff))
|
||||
range_slider.value = (
|
||||
int((start_box.value * 100.) / diff),
|
||||
int((end_box.value * 100.) / diff))
|
||||
|
||||
# Event was triggered by a change in the range_slider
|
||||
# value.
|
||||
|
@ -170,8 +165,8 @@ def get_sliders(update):
|
|||
new_low, new_high = event["new"]
|
||||
if old_low != new_low:
|
||||
range_slider.value = (new_low, 100)
|
||||
num_tasks_box.value = (-(100. - new_low) /
|
||||
100. * num_tasks)
|
||||
num_tasks_box.value = (
|
||||
-(100. - new_low) / 100. * num_tasks)
|
||||
else:
|
||||
range_slider.value = (0, new_high)
|
||||
num_tasks_box.value = new_high / 100. * num_tasks
|
||||
|
@ -183,14 +178,12 @@ def get_sliders(update):
|
|||
# value.
|
||||
elif event["owner"] == num_tasks_box:
|
||||
if num_tasks_box.value > 0:
|
||||
range_slider.value = (0, int(100 *
|
||||
float(num_tasks_box.value) /
|
||||
num_tasks))
|
||||
range_slider.value = (
|
||||
0, int(
|
||||
100 * float(num_tasks_box.value) / num_tasks))
|
||||
elif num_tasks_box.value < 0:
|
||||
range_slider.value = (100 +
|
||||
int(100 *
|
||||
float(num_tasks_box.value) /
|
||||
num_tasks), 100)
|
||||
range_slider.value = (100 + int(
|
||||
100 * float(num_tasks_box.value) / num_tasks), 100)
|
||||
|
||||
if not update:
|
||||
return
|
||||
|
@ -205,22 +198,19 @@ def get_sliders(update):
|
|||
# box values.
|
||||
# (Querying based on the % total amount of time.)
|
||||
if breakdown_opt.value == total_time_value:
|
||||
tasks = _truncated_task_profiles(start=(smallest +
|
||||
diff * low),
|
||||
end=(smallest +
|
||||
diff * high))
|
||||
tasks = _truncated_task_profiles(
|
||||
start=(smallest + diff * low),
|
||||
end=(smallest + diff * high))
|
||||
|
||||
# (Querying based on % of total number of tasks that were
|
||||
# run.)
|
||||
elif breakdown_opt.value == total_tasks_value:
|
||||
if range_slider.value[0] == 0:
|
||||
tasks = _truncated_task_profiles(num_tasks=(int(
|
||||
num_tasks * high)),
|
||||
fwd=True)
|
||||
tasks = _truncated_task_profiles(
|
||||
num_tasks=(int(num_tasks * high)), fwd=True)
|
||||
else:
|
||||
tasks = _truncated_task_profiles(num_tasks=(int(
|
||||
num_tasks *
|
||||
(high - low))),
|
||||
tasks = _truncated_task_profiles(
|
||||
num_tasks=(int(num_tasks * (high - low))),
|
||||
fwd=False)
|
||||
|
||||
update(smallest, largest, num_tasks, tasks)
|
||||
|
@ -237,8 +227,8 @@ def get_sliders(update):
|
|||
update_wrapper(INIT_EVENT)
|
||||
|
||||
# Display sliders and search boxes
|
||||
display(breakdown_opt, widgets.HBox([range_slider, total_time_box,
|
||||
num_tasks_box]))
|
||||
display(breakdown_opt,
|
||||
widgets.HBox([range_slider, total_time_box, num_tasks_box]))
|
||||
|
||||
# Return the sliders and text boxes
|
||||
return start_box, end_box, range_slider, breakdown_opt
|
||||
|
@ -249,8 +239,7 @@ def object_search_bar():
|
|||
value="",
|
||||
placeholder="Object ID",
|
||||
description="Search for an object:",
|
||||
disabled=False
|
||||
)
|
||||
disabled=False)
|
||||
display(object_search)
|
||||
|
||||
def handle_submit(sender):
|
||||
|
@ -265,8 +254,7 @@ def task_search_bar():
|
|||
value="",
|
||||
placeholder="Task ID",
|
||||
description="Search for a task:",
|
||||
disabled=False
|
||||
)
|
||||
disabled=False)
|
||||
display(task_search)
|
||||
|
||||
def handle_submit(sender):
|
||||
|
@ -284,12 +272,10 @@ MAX_TASKS_TO_VISUALIZE = 10000
|
|||
def _truncated_task_profiles(start=None, end=None, num_tasks=None, fwd=True):
|
||||
if num_tasks is None:
|
||||
num_tasks = MAX_TASKS_TO_VISUALIZE
|
||||
print(
|
||||
"Warning: at most {} tasks will be fetched within this "
|
||||
print("Warning: at most {} tasks will be fetched within this "
|
||||
"time range.".format(MAX_TASKS_TO_VISUALIZE))
|
||||
elif num_tasks > MAX_TASKS_TO_VISUALIZE:
|
||||
print(
|
||||
"Warning: too many tasks to visualize, "
|
||||
print("Warning: too many tasks to visualize, "
|
||||
"fetching only the first {} of {}.".format(
|
||||
MAX_TASKS_TO_VISUALIZE, num_tasks))
|
||||
num_tasks = MAX_TASKS_TO_VISUALIZE
|
||||
|
@ -299,9 +285,8 @@ def _truncated_task_profiles(start=None, end=None, num_tasks=None, fwd=True):
|
|||
# Helper function that guarantees unique and writeable temp files.
|
||||
# Prevents clashes in task trace files when multiple notebooks are running.
|
||||
def _get_temp_file_path(**kwargs):
|
||||
temp_file = tempfile.NamedTemporaryFile(delete=False,
|
||||
dir=os.getcwd(),
|
||||
**kwargs)
|
||||
temp_file = tempfile.NamedTemporaryFile(
|
||||
delete=False, dir=os.getcwd(), **kwargs)
|
||||
temp_file_path = temp_file.name
|
||||
temp_file.close()
|
||||
return os.path.relpath(temp_file_path)
|
||||
|
@ -319,22 +304,16 @@ def task_timeline():
|
|||
disabled=False,
|
||||
)
|
||||
obj_dep = widgets.Checkbox(
|
||||
value=True,
|
||||
disabled=False,
|
||||
layout=widgets.Layout(width='20px')
|
||||
)
|
||||
value=True, disabled=False, layout=widgets.Layout(width='20px'))
|
||||
task_dep = widgets.Checkbox(
|
||||
value=True,
|
||||
disabled=False,
|
||||
layout=widgets.Layout(width='20px')
|
||||
)
|
||||
value=True, disabled=False, layout=widgets.Layout(width='20px'))
|
||||
# Labels to bypass width limitation for descriptions.
|
||||
label_tasks = widgets.Label(value='Task submissions',
|
||||
layout=widgets.Layout(width='110px'))
|
||||
label_objects = widgets.Label(value='Object dependencies',
|
||||
layout=widgets.Layout(width='130px'))
|
||||
label_options = widgets.Label(value='View options:',
|
||||
layout=widgets.Layout(width='100px'))
|
||||
label_tasks = widgets.Label(
|
||||
value='Task submissions', layout=widgets.Layout(width='110px'))
|
||||
label_objects = widgets.Label(
|
||||
value='Object dependencies', layout=widgets.Layout(width='130px'))
|
||||
label_options = widgets.Label(
|
||||
value='View options:', layout=widgets.Layout(width='100px'))
|
||||
start_box, end_box, range_slider, time_opt = get_sliders(False)
|
||||
display(widgets.HBox([task_dep, label_tasks, obj_dep, label_objects]))
|
||||
display(widgets.HBox([label_options, breakdown_opt]))
|
||||
|
@ -344,7 +323,8 @@ def task_timeline():
|
|||
# current working directory if it is not present.
|
||||
if not os.path.exists("trace_viewer_full.html"):
|
||||
shutil.copy(
|
||||
os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"../core/src/catapult_files/trace_viewer_full.html"),
|
||||
"trace_viewer_full.html")
|
||||
|
||||
|
@ -357,8 +337,8 @@ def task_timeline():
|
|||
elif breakdown_opt.value == breakdown_task:
|
||||
breakdown = True
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unexpected breakdown value '{}'".format(breakdown_opt.value))
|
||||
raise ValueError("Unexpected breakdown value '{}'".format(
|
||||
breakdown_opt.value))
|
||||
|
||||
low, high = map(lambda x: x / 100., range_slider.value)
|
||||
|
||||
|
@ -366,26 +346,24 @@ def task_timeline():
|
|||
diff = largest - smallest
|
||||
|
||||
if time_opt.value == total_time_value:
|
||||
tasks = _truncated_task_profiles(start=smallest + diff * low,
|
||||
end=smallest + diff * high)
|
||||
tasks = _truncated_task_profiles(
|
||||
start=smallest + diff * low, end=smallest + diff * high)
|
||||
elif time_opt.value == total_tasks_value:
|
||||
if range_slider.value[0] == 0:
|
||||
tasks = _truncated_task_profiles(num_tasks=int(
|
||||
num_tasks * high),
|
||||
fwd=True)
|
||||
tasks = _truncated_task_profiles(
|
||||
num_tasks=int(num_tasks * high), fwd=True)
|
||||
else:
|
||||
tasks = _truncated_task_profiles(num_tasks=int(
|
||||
num_tasks * (high - low)),
|
||||
fwd=False)
|
||||
tasks = _truncated_task_profiles(
|
||||
num_tasks=int(num_tasks * (high - low)), fwd=False)
|
||||
else:
|
||||
raise ValueError("Unexpected time value '{}'".format(
|
||||
time_opt.value))
|
||||
# Write trace to a JSON file
|
||||
print("Collected profiles for {} tasks.".format(len(tasks)))
|
||||
print(
|
||||
"Dumping task profile data to {}, "
|
||||
print("Dumping task profile data to {}, "
|
||||
"this might take a while...".format(json_tmp))
|
||||
ray.global_state.dump_catapult_trace(json_tmp,
|
||||
ray.global_state.dump_catapult_trace(
|
||||
json_tmp,
|
||||
tasks,
|
||||
breakdowns=breakdown,
|
||||
obj_dep=obj_dep.value,
|
||||
|
@ -415,8 +393,7 @@ def task_timeline():
|
|||
|
||||
# Display the task trace within the Jupyter notebook
|
||||
clear_output(wait=True)
|
||||
print(
|
||||
"To view fullscreen, open chrome://tracing in Google Chrome "
|
||||
print("To view fullscreen, open chrome://tracing in Google Chrome "
|
||||
"and load `{}`".format(json_tmp))
|
||||
display(IFrame(html_file_path, 900, 800))
|
||||
|
||||
|
@ -432,36 +409,41 @@ def task_completion_time_distribution():
|
|||
output_notebook(resources=CDN)
|
||||
|
||||
# Create the Bokeh plot
|
||||
p = figure(title="Task Completion Time Distribution",
|
||||
p = figure(
|
||||
title="Task Completion Time Distribution",
|
||||
tools=["save", "hover", "wheel_zoom", "box_zoom", "pan"],
|
||||
background_fill_color="#FFFFFF",
|
||||
x_range=(0, 1),
|
||||
y_range=(0, 1))
|
||||
|
||||
# Create the data source that the plot pulls from
|
||||
source = ColumnDataSource(data={
|
||||
"top": [],
|
||||
"left": [],
|
||||
"right": []
|
||||
})
|
||||
source = ColumnDataSource(data={"top": [], "left": [], "right": []})
|
||||
|
||||
# Plot the histogram rectangles
|
||||
p.quad(top="top", bottom=0, left="left", right="right", source=source,
|
||||
fill_color="#B3B3B3", line_color="#033649")
|
||||
p.quad(
|
||||
top="top",
|
||||
bottom=0,
|
||||
left="left",
|
||||
right="right",
|
||||
source=source,
|
||||
fill_color="#B3B3B3",
|
||||
line_color="#033649")
|
||||
|
||||
# Label the plot axes
|
||||
p.xaxis.axis_label = "Duration in seconds"
|
||||
p.yaxis.axis_label = "Number of tasks"
|
||||
|
||||
handle = show(gridplot(p, ncols=1,
|
||||
handle = show(
|
||||
gridplot(
|
||||
p,
|
||||
ncols=1,
|
||||
plot_width=500,
|
||||
plot_height=500,
|
||||
toolbar_location="below"), notebook_handle=True)
|
||||
toolbar_location="below"),
|
||||
notebook_handle=True)
|
||||
|
||||
# Function to update the plot
|
||||
def task_completion_time_update(abs_earliest,
|
||||
abs_latest,
|
||||
abs_num_tasks,
|
||||
def task_completion_time_update(abs_earliest, abs_latest, abs_num_tasks,
|
||||
tasks):
|
||||
if len(tasks) == 0:
|
||||
return
|
||||
|
@ -469,8 +451,8 @@ def task_completion_time_distribution():
|
|||
# Create the distribution to plot
|
||||
distr = []
|
||||
for task_id, data in tasks.items():
|
||||
distr.append(data["store_outputs_end"] -
|
||||
data["get_arguments_start"])
|
||||
distr.append(
|
||||
data["store_outputs_end"] - data["get_arguments_start"])
|
||||
|
||||
# Create a histogram from the distribution
|
||||
top, bin_edges = np.histogram(distr, bins="auto")
|
||||
|
@ -480,8 +462,8 @@ def task_completion_time_distribution():
|
|||
source.data = {"top": top, "left": left, "right": right}
|
||||
|
||||
# Set the x and y ranges
|
||||
x_range = (min(left) if len(left) else 0,
|
||||
max(right) if len(right) else 1)
|
||||
x_range = (min(left) if len(left) else 0, max(right)
|
||||
if len(right) else 1)
|
||||
y_range = (0, max(top) + 1 if len(top) else 1)
|
||||
|
||||
x_range = helpers._get_range(x_range)
|
||||
|
@ -517,8 +499,7 @@ def compute_utilizations(abs_earliest,
|
|||
latest_time = 0
|
||||
for task_id, data in tasks.items():
|
||||
latest_time = max((latest_time, data["store_outputs_end"]))
|
||||
earliest_time = min((earliest_time,
|
||||
data["get_arguments_start"]))
|
||||
earliest_time = min((earliest_time, data["get_arguments_start"]))
|
||||
|
||||
# Add some epsilon to latest_time to ensure that the end time of the
|
||||
# last task falls __within__ a bucket, and not on the edge
|
||||
|
@ -533,37 +514,37 @@ def compute_utilizations(abs_earliest,
|
|||
task_start_time = data["get_arguments_start"]
|
||||
task_end_time = data["store_outputs_end"]
|
||||
|
||||
start_bucket = int((task_start_time - earliest_time) /
|
||||
bucket_time_length)
|
||||
end_bucket = int((task_end_time - earliest_time) /
|
||||
bucket_time_length)
|
||||
start_bucket = int(
|
||||
(task_start_time - earliest_time) / bucket_time_length)
|
||||
end_bucket = int((task_end_time - earliest_time) / bucket_time_length)
|
||||
# Walk over each time bucket that this task intersects, adding the
|
||||
# amount of time that the task intersects within each bucket
|
||||
for bucket_idx in range(start_bucket, end_bucket + 1):
|
||||
bucket_start_time = ((earliest_time + bucket_idx) *
|
||||
bucket_time_length)
|
||||
bucket_end_time = ((earliest_time + (bucket_idx + 1)) *
|
||||
bucket_time_length)
|
||||
bucket_start_time = ((
|
||||
earliest_time + bucket_idx) * bucket_time_length)
|
||||
bucket_end_time = ((earliest_time +
|
||||
(bucket_idx + 1)) * bucket_time_length)
|
||||
|
||||
task_start_time_within_bucket = max(task_start_time,
|
||||
bucket_start_time)
|
||||
task_end_time_within_bucket = min(task_end_time,
|
||||
bucket_end_time)
|
||||
task_cpu_time_within_bucket = (task_end_time_within_bucket -
|
||||
task_start_time_within_bucket)
|
||||
task_end_time_within_bucket = min(task_end_time, bucket_end_time)
|
||||
task_cpu_time_within_bucket = (
|
||||
task_end_time_within_bucket - task_start_time_within_bucket)
|
||||
|
||||
if bucket_idx > -1 and bucket_idx < num_buckets:
|
||||
cpu_time[bucket_idx] += task_cpu_time_within_bucket
|
||||
|
||||
# Cpu_utilization is the average cpu utilization of the bucket, which
|
||||
# is just cpu_time divided by bucket_time_length.
|
||||
cpu_utilization = list(map(lambda x: x / float(bucket_time_length),
|
||||
cpu_time))
|
||||
cpu_utilization = list(
|
||||
map(lambda x: x / float(bucket_time_length), cpu_time))
|
||||
|
||||
# Generate histogram bucket edges. Subtract out abs_earliest to get
|
||||
# relative time.
|
||||
all_edges = [earliest_time - abs_earliest + i * bucket_time_length
|
||||
for i in range(num_buckets + 1)]
|
||||
all_edges = [
|
||||
earliest_time - abs_earliest + i * bucket_time_length
|
||||
for i in range(num_buckets + 1)
|
||||
]
|
||||
# Left edges are all but the rightmost edge, right edges are all but
|
||||
# the leftmost edge.
|
||||
left_edges = all_edges[:-1]
|
||||
|
@ -591,22 +572,20 @@ def cpu_usage():
|
|||
# Update the plot based on the sliders
|
||||
def plot_utilization():
|
||||
# Create the Bokeh plot
|
||||
time_series_fig = figure(title="CPU Utilization",
|
||||
tools=["save", "hover", "wheel_zoom",
|
||||
"box_zoom", "pan"],
|
||||
time_series_fig = figure(
|
||||
title="CPU Utilization",
|
||||
tools=["save", "hover", "wheel_zoom", "box_zoom", "pan"],
|
||||
background_fill_color="#FFFFFF",
|
||||
x_range=[0, 1],
|
||||
y_range=[0, 1])
|
||||
|
||||
# Create the data source that the plot will pull from
|
||||
time_series_source = ColumnDataSource(data=dict(
|
||||
left=[],
|
||||
right=[],
|
||||
top=[]
|
||||
))
|
||||
time_series_source = ColumnDataSource(
|
||||
data=dict(left=[], right=[], top=[]))
|
||||
|
||||
# Plot the rectangles representing the distribution
|
||||
time_series_fig.quad(left="left",
|
||||
time_series_fig.quad(
|
||||
left="left",
|
||||
right="right",
|
||||
top="top",
|
||||
bottom=0,
|
||||
|
@ -618,27 +597,28 @@ def cpu_usage():
|
|||
time_series_fig.xaxis.axis_label = "Time in seconds"
|
||||
time_series_fig.yaxis.axis_label = "Number of CPUs used"
|
||||
|
||||
handle = show(gridplot(time_series_fig,
|
||||
handle = show(
|
||||
gridplot(
|
||||
time_series_fig,
|
||||
ncols=1,
|
||||
plot_width=500,
|
||||
plot_height=500,
|
||||
toolbar_location="below"), notebook_handle=True)
|
||||
toolbar_location="below"),
|
||||
notebook_handle=True)
|
||||
|
||||
def update_plot(abs_earliest, abs_latest, abs_num_tasks, tasks):
|
||||
num_buckets = 100
|
||||
left, right, top = compute_utilizations(abs_earliest,
|
||||
abs_latest,
|
||||
abs_num_tasks,
|
||||
tasks,
|
||||
num_buckets)
|
||||
left, right, top = compute_utilizations(
|
||||
abs_earliest, abs_latest, abs_num_tasks, tasks, num_buckets)
|
||||
|
||||
time_series_source.data = {"left": left,
|
||||
time_series_source.data = {
|
||||
"left": left,
|
||||
"right": right,
|
||||
"top": top}
|
||||
"top": top
|
||||
}
|
||||
|
||||
x_range = (max(0, min(left))
|
||||
if len(left) else 0,
|
||||
max(right) if len(right) else 1)
|
||||
x_range = (max(0, min(left)) if len(left) else 0, max(right)
|
||||
if len(right) else 1)
|
||||
y_range = (0, max(top) + 1 if len(top) else 1)
|
||||
|
||||
# Define the axis ranges
|
||||
|
@ -654,6 +634,7 @@ def cpu_usage():
|
|||
push_notebook(handle=handle)
|
||||
|
||||
get_sliders(update_plot)
|
||||
|
||||
plot_utilization()
|
||||
|
||||
|
||||
|
@ -672,27 +653,26 @@ def cluster_usage():
|
|||
output_notebook(resources=CDN)
|
||||
|
||||
# Initial values
|
||||
source = ColumnDataSource(data={"node_ip_address": ['127.0.0.1'],
|
||||
source = ColumnDataSource(
|
||||
data={
|
||||
"node_ip_address": ['127.0.0.1'],
|
||||
"time": ['0.5'],
|
||||
"num_tasks": ['1'],
|
||||
"length": [1]})
|
||||
"length": [1]
|
||||
})
|
||||
|
||||
# Define the color schema
|
||||
colors = ["#75968f",
|
||||
"#a5bab7",
|
||||
"#c9d9d3",
|
||||
"#e2e2e2",
|
||||
"#dfccce",
|
||||
"#ddb7b1",
|
||||
"#cc7878",
|
||||
"#933b41",
|
||||
"#550b1d"]
|
||||
colors = [
|
||||
"#75968f", "#a5bab7", "#c9d9d3", "#e2e2e2", "#dfccce", "#ddb7b1",
|
||||
"#cc7878", "#933b41", "#550b1d"
|
||||
]
|
||||
mapper = LinearColorMapper(palette=colors, low=0, high=2)
|
||||
|
||||
TOOLS = "hover, save, xpan, box_zoom, reset, xwheel_zoom"
|
||||
|
||||
# Create the plot
|
||||
p = figure(title="Cluster Usage",
|
||||
p = figure(
|
||||
title="Cluster Usage",
|
||||
y_range=list(set(source.data['node_ip_address'])),
|
||||
x_axis_location="above",
|
||||
plot_width=900,
|
||||
|
@ -709,13 +689,21 @@ def cluster_usage():
|
|||
p.xaxis.major_label_orientation = np.pi / 3
|
||||
|
||||
# Plot rectangles
|
||||
p.rect(x="time", y="node_ip_address", width="length", height=1,
|
||||
p.rect(
|
||||
x="time",
|
||||
y="node_ip_address",
|
||||
width="length",
|
||||
height=1,
|
||||
source=source,
|
||||
fill_color={"field": "num_tasks", "transform": mapper},
|
||||
fill_color={
|
||||
"field": "num_tasks",
|
||||
"transform": mapper
|
||||
},
|
||||
line_color=None)
|
||||
|
||||
# Add legend to the side of the plot
|
||||
color_bar = ColorBar(color_mapper=mapper,
|
||||
color_bar = ColorBar(
|
||||
color_mapper=mapper,
|
||||
major_label_text_font_size="8pt",
|
||||
ticker=BasicTicker(desired_num_ticks=len(colors)),
|
||||
label_standoff=6,
|
||||
|
@ -724,11 +712,10 @@ def cluster_usage():
|
|||
p.add_layout(color_bar, "right")
|
||||
|
||||
# Define hover tool
|
||||
p.select_one(HoverTool).tooltips = [
|
||||
("Node IP Address", "@node_ip_address"),
|
||||
("Number of tasks running", "@num_tasks"),
|
||||
("Time", "@time")
|
||||
]
|
||||
p.select_one(HoverTool).tooltips = [("Node IP Address",
|
||||
"@node_ip_address"),
|
||||
("Number of tasks running",
|
||||
"@num_tasks"), ("Time", "@time")]
|
||||
|
||||
# Define the axis labels
|
||||
p.xaxis.axis_label = "Time in seconds"
|
||||
|
@ -764,12 +751,8 @@ def cluster_usage():
|
|||
num_tasks = []
|
||||
|
||||
for node_ip, task_dict in node_to_tasks.items():
|
||||
left, right, top = compute_utilizations(earliest,
|
||||
latest,
|
||||
abs_num_tasks,
|
||||
task_dict,
|
||||
100,
|
||||
True)
|
||||
left, right, top = compute_utilizations(
|
||||
earliest, latest, abs_num_tasks, task_dict, 100, True)
|
||||
for (l, r, t) in zip(left, right, top):
|
||||
nodes.append(node_ip)
|
||||
times.append((l + r) / 2)
|
||||
|
@ -783,10 +766,12 @@ def cluster_usage():
|
|||
mapper.high = max(max(num_tasks), 1)
|
||||
|
||||
# Update plot with new data based on slider and text box values
|
||||
source.data = {"node_ip_address": nodes,
|
||||
source.data = {
|
||||
"node_ip_address": nodes,
|
||||
"time": times,
|
||||
"num_tasks": num_tasks,
|
||||
"length": lengths}
|
||||
"length": lengths
|
||||
}
|
||||
|
||||
push_notebook(handle=handle)
|
||||
|
||||
|
|
|
@ -7,9 +7,12 @@ import subprocess
|
|||
import time
|
||||
|
||||
|
||||
def start_global_scheduler(redis_address, node_ip_address,
|
||||
use_valgrind=False, use_profiler=False,
|
||||
stdout_file=None, stderr_file=None):
|
||||
def start_global_scheduler(redis_address,
|
||||
node_ip_address,
|
||||
use_valgrind=False,
|
||||
use_profiler=False,
|
||||
stdout_file=None,
|
||||
stderr_file=None):
|
||||
"""Start a global scheduler process.
|
||||
|
||||
Args:
|
||||
|
@ -33,21 +36,24 @@ def start_global_scheduler(redis_address, node_ip_address,
|
|||
global_scheduler_executable = os.path.join(
|
||||
os.path.abspath(os.path.dirname(__file__)),
|
||||
"../core/src/global_scheduler/global_scheduler")
|
||||
command = [global_scheduler_executable,
|
||||
"-r", redis_address,
|
||||
"-h", node_ip_address]
|
||||
command = [
|
||||
global_scheduler_executable, "-r", redis_address, "-h", node_ip_address
|
||||
]
|
||||
if use_valgrind:
|
||||
pid = subprocess.Popen(["valgrind",
|
||||
"--track-origins=yes",
|
||||
"--leak-check=full",
|
||||
"--show-leak-kinds=all",
|
||||
"--leak-check-heuristics=stdstring",
|
||||
"--error-exitcode=1"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
pid = subprocess.Popen(
|
||||
[
|
||||
"valgrind", "--track-origins=yes", "--leak-check=full",
|
||||
"--show-leak-kinds=all", "--leak-check-heuristics=stdstring",
|
||||
"--error-exitcode=1"
|
||||
] + command,
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
time.sleep(1.0)
|
||||
elif use_profiler:
|
||||
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
pid = subprocess.Popen(
|
||||
["valgrind", "--tool=callgrind"] + command,
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
time.sleep(1.0)
|
||||
else:
|
||||
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
|
||||
|
|
|
@ -56,7 +56,6 @@ def new_port():
|
|||
|
||||
|
||||
class TestGlobalScheduler(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Start one Redis server and N pairs of (plasma, local_scheduler)
|
||||
self.node_ip_address = "127.0.0.1"
|
||||
|
@ -164,17 +163,19 @@ class TestGlobalScheduler(unittest.TestCase):
|
|||
return db_client_id
|
||||
|
||||
def test_task_default_resources(self):
|
||||
task1 = local_scheduler.Task(random_driver_id(), random_function_id(),
|
||||
[random_object_id()], 0, random_task_id(),
|
||||
0)
|
||||
task1 = local_scheduler.Task(
|
||||
random_driver_id(), random_function_id(), [random_object_id()], 0,
|
||||
random_task_id(), 0)
|
||||
self.assertEqual(task1.required_resources(), {"CPU": 1})
|
||||
task2 = local_scheduler.Task(random_driver_id(), random_function_id(),
|
||||
[random_object_id()], 0, random_task_id(),
|
||||
0, local_scheduler.ObjectID(NIL_ACTOR_ID),
|
||||
task2 = local_scheduler.Task(
|
||||
random_driver_id(), random_function_id(), [random_object_id()], 0,
|
||||
random_task_id(), 0, local_scheduler.ObjectID(NIL_ACTOR_ID),
|
||||
local_scheduler.ObjectID(NIL_OBJECT_ID),
|
||||
local_scheduler.ObjectID(NIL_ACTOR_ID),
|
||||
local_scheduler.ObjectID(NIL_ACTOR_ID),
|
||||
0, 0, [], {"CPU": 1, "GPU": 2})
|
||||
local_scheduler.ObjectID(NIL_ACTOR_ID), 0, 0, [], {
|
||||
"CPU": 1,
|
||||
"GPU": 2
|
||||
})
|
||||
self.assertEqual(task2.required_resources(), {"CPU": 1, "GPU": 2})
|
||||
|
||||
def test_redis_only_single_task(self):
|
||||
|
@ -189,7 +190,7 @@ class TestGlobalScheduler(unittest.TestCase):
|
|||
len(self.state.client_table()[self.node_ip_address]),
|
||||
2 * NUM_CLUSTER_NODES + 1)
|
||||
db_client_id = self.get_plasma_manager_id()
|
||||
assert(db_client_id is not None)
|
||||
assert (db_client_id is not None)
|
||||
|
||||
@unittest.skipIf(
|
||||
os.environ.get('RAY_USE_NEW_GCS', False),
|
||||
|
@ -227,9 +228,10 @@ class TestGlobalScheduler(unittest.TestCase):
|
|||
if len(task_entries) == 1:
|
||||
task_id, task = task_entries.popitem()
|
||||
task_status = task["State"]
|
||||
self.assertTrue(task_status in [state.TASK_STATUS_WAITING,
|
||||
state.TASK_STATUS_SCHEDULED,
|
||||
state.TASK_STATUS_QUEUED])
|
||||
self.assertTrue(task_status in [
|
||||
state.TASK_STATUS_WAITING, state.TASK_STATUS_SCHEDULED,
|
||||
state.TASK_STATUS_QUEUED
|
||||
])
|
||||
if task_status == state.TASK_STATUS_QUEUED:
|
||||
break
|
||||
else:
|
||||
|
@ -258,17 +260,14 @@ class TestGlobalScheduler(unittest.TestCase):
|
|||
data_size = np.random.randint(1 << 12)
|
||||
metadata_size = np.random.randint(1 << 9)
|
||||
plasma_client = self.plasma_clients[0]
|
||||
object_dep, memory_buffer, metadata = create_object(plasma_client,
|
||||
data_size,
|
||||
metadata_size,
|
||||
seal=True)
|
||||
object_dep, memory_buffer, metadata = create_object(
|
||||
plasma_client, data_size, metadata_size, seal=True)
|
||||
if timesync:
|
||||
# Give 10ms for object info handler to fire (long enough to
|
||||
# yield CPU).
|
||||
time.sleep(0.010)
|
||||
task = local_scheduler.Task(
|
||||
random_driver_id(),
|
||||
random_function_id(),
|
||||
random_driver_id(), random_function_id(),
|
||||
[local_scheduler.ObjectID(object_dep.binary())],
|
||||
num_return_vals[0], random_task_id(), 0)
|
||||
self.local_scheduler_clients[0].submit(task)
|
||||
|
@ -281,12 +280,18 @@ class TestGlobalScheduler(unittest.TestCase):
|
|||
self.assertLessEqual(len(task_entries), num_tasks)
|
||||
# First, check if all tasks made it to Redis.
|
||||
if len(task_entries) == num_tasks:
|
||||
task_statuses = [task_entry["State"] for task_entry in
|
||||
task_entries.values()]
|
||||
self.assertTrue(all([status in [state.TASK_STATUS_WAITING,
|
||||
task_statuses = [
|
||||
task_entry["State"]
|
||||
for task_entry in task_entries.values()
|
||||
]
|
||||
self.assertTrue(
|
||||
all([
|
||||
status in [
|
||||
state.TASK_STATUS_WAITING,
|
||||
state.TASK_STATUS_SCHEDULED,
|
||||
state.TASK_STATUS_QUEUED]
|
||||
for status in task_statuses]))
|
||||
state.TASK_STATUS_QUEUED
|
||||
] for status in task_statuses
|
||||
]))
|
||||
num_tasks_done = task_statuses.count(state.TASK_STATUS_QUEUED)
|
||||
num_tasks_scheduled = task_statuses.count(
|
||||
state.TASK_STATUS_SCHEDULED)
|
||||
|
@ -294,12 +299,13 @@ class TestGlobalScheduler(unittest.TestCase):
|
|||
state.TASK_STATUS_WAITING)
|
||||
print("tasks in Redis = {}, tasks waiting = {}, "
|
||||
"tasks scheduled = {}, "
|
||||
"tasks queued = {}, retries left = {}"
|
||||
.format(len(task_entries), num_tasks_waiting,
|
||||
num_tasks_scheduled, num_tasks_done,
|
||||
num_retries))
|
||||
if all([status == state.TASK_STATUS_QUEUED for status in
|
||||
task_statuses]):
|
||||
"tasks queued = {}, retries left = {}".format(
|
||||
len(task_entries), num_tasks_waiting,
|
||||
num_tasks_scheduled, num_tasks_done, num_retries))
|
||||
if all([
|
||||
status == state.TASK_STATUS_QUEUED
|
||||
for status in task_statuses
|
||||
]):
|
||||
# We're done, so pass.
|
||||
break
|
||||
num_retries -= 1
|
||||
|
|
|
@ -7,6 +7,8 @@ from ray.core.src.local_scheduler.liblocal_scheduler_library import (
|
|||
task_to_string, _config, common_error)
|
||||
from .local_scheduler_services import start_local_scheduler
|
||||
|
||||
__all__ = ["Task", "LocalSchedulerClient", "ObjectID", "check_simple_value",
|
||||
"task_from_string", "task_to_string", "start_local_scheduler",
|
||||
"_config", "common_error"]
|
||||
__all__ = [
|
||||
"Task", "LocalSchedulerClient", "ObjectID", "check_simple_value",
|
||||
"task_from_string", "task_to_string", "start_local_scheduler", "_config",
|
||||
"common_error"
|
||||
]
|
||||
|
|
|
@ -68,15 +68,15 @@ def start_local_scheduler(plasma_store_name,
|
|||
"provided.")
|
||||
if use_valgrind and use_profiler:
|
||||
raise Exception("Cannot use valgrind and profiler at the same time.")
|
||||
local_scheduler_executable = os.path.join(os.path.dirname(
|
||||
os.path.abspath(__file__)),
|
||||
local_scheduler_executable = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"../core/src/local_scheduler/local_scheduler")
|
||||
local_scheduler_name = "/tmp/scheduler{}".format(random_name())
|
||||
command = [local_scheduler_executable,
|
||||
"-s", local_scheduler_name,
|
||||
"-p", plasma_store_name,
|
||||
"-h", node_ip_address,
|
||||
"-n", str(num_workers)]
|
||||
command = [
|
||||
local_scheduler_executable, "-s", local_scheduler_name, "-p",
|
||||
plasma_store_name, "-h", node_ip_address, "-n",
|
||||
str(num_workers)
|
||||
]
|
||||
if plasma_manager_name is not None:
|
||||
command += ["-m", plasma_manager_name]
|
||||
if worker_path is not None:
|
||||
|
@ -88,13 +88,10 @@ def start_local_scheduler(plasma_store_name,
|
|||
"--object-store-name={} "
|
||||
"--object-store-manager-name={} "
|
||||
"--local-scheduler-name={} "
|
||||
"--redis-address={}"
|
||||
.format(sys.executable,
|
||||
worker_path,
|
||||
node_ip_address,
|
||||
plasma_store_name,
|
||||
plasma_manager_name,
|
||||
local_scheduler_name,
|
||||
"--redis-address={}".format(
|
||||
sys.executable, worker_path,
|
||||
node_ip_address, plasma_store_name,
|
||||
plasma_manager_name, local_scheduler_name,
|
||||
redis_address))
|
||||
command += ["-w", start_worker_command]
|
||||
if redis_address is not None:
|
||||
|
@ -104,27 +101,31 @@ def start_local_scheduler(plasma_store_name,
|
|||
if static_resources is not None:
|
||||
resource_argument = ""
|
||||
for resource_name, resource_quantity in static_resources.items():
|
||||
assert (isinstance(resource_quantity, int) or
|
||||
isinstance(resource_quantity, float))
|
||||
resource_argument = ",".join(
|
||||
[resource_name + "," + str(resource_quantity)
|
||||
for resource_name, resource_quantity in static_resources.items()])
|
||||
assert (isinstance(resource_quantity, int)
|
||||
or isinstance(resource_quantity, float))
|
||||
resource_argument = ",".join([
|
||||
resource_name + "," + str(resource_quantity)
|
||||
for resource_name, resource_quantity in static_resources.items()
|
||||
])
|
||||
else:
|
||||
resource_argument = "CPU,{}".format(psutil.cpu_count())
|
||||
command += ["-c", resource_argument]
|
||||
|
||||
if use_valgrind:
|
||||
pid = subprocess.Popen(["valgrind",
|
||||
"--track-origins=yes",
|
||||
"--leak-check=full",
|
||||
"--show-leak-kinds=all",
|
||||
"--leak-check-heuristics=stdstring",
|
||||
"--error-exitcode=1"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
pid = subprocess.Popen(
|
||||
[
|
||||
"valgrind", "--track-origins=yes", "--leak-check=full",
|
||||
"--show-leak-kinds=all", "--leak-check-heuristics=stdstring",
|
||||
"--error-exitcode=1"
|
||||
] + command,
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
time.sleep(1.0)
|
||||
elif use_profiler:
|
||||
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
pid = subprocess.Popen(
|
||||
["valgrind", "--tool=callgrind"] + command,
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
time.sleep(1.0)
|
||||
else:
|
||||
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
|
||||
|
|
|
@ -37,7 +37,6 @@ def random_function_id():
|
|||
|
||||
|
||||
class TestLocalSchedulerClient(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Start Plasma store.
|
||||
plasma_store_name, self.p1 = plasma.start_plasma_store()
|
||||
|
@ -74,34 +73,17 @@ class TestLocalSchedulerClient(unittest.TestCase):
|
|||
self.plasma_client.create(pa.plasma.ObjectID(object_id.id()), 0)
|
||||
self.plasma_client.seal(pa.plasma.ObjectID(object_id.id()))
|
||||
# Define some arguments to use for the tasks.
|
||||
args_list = [
|
||||
[],
|
||||
[{}],
|
||||
[()],
|
||||
1 * [1],
|
||||
10 * [1],
|
||||
100 * [1],
|
||||
1000 * [1],
|
||||
1 * ["a"],
|
||||
10 * ["a"],
|
||||
100 * ["a"],
|
||||
1000 * ["a"],
|
||||
[1, 1.3, 1 << 100, "hi", u"hi", [1, 2]],
|
||||
object_ids[:1],
|
||||
object_ids[:2],
|
||||
object_ids[:3],
|
||||
object_ids[:4],
|
||||
object_ids[:5],
|
||||
object_ids[:10],
|
||||
object_ids[:100],
|
||||
object_ids[:256],
|
||||
[1, object_ids[0]],
|
||||
[object_ids[0], "a"],
|
||||
[1, object_ids[0], "a"],
|
||||
[object_ids[0], 1, object_ids[1], "a"],
|
||||
object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
|
||||
object_ids + 100 * ["a"] + object_ids
|
||||
]
|
||||
args_list = [[], [{}], [()], 1 * [1], 10 * [1], 100 * [1], 1000 * [1],
|
||||
1 * ["a"], 10 * ["a"], 100 * ["a"], 1000 * ["a"], [
|
||||
1, 1.3, 1 << 100, "hi", u"hi", [1, 2]
|
||||
], object_ids[:1], object_ids[:2], object_ids[:3],
|
||||
object_ids[:4], object_ids[:5], object_ids[:10],
|
||||
object_ids[:100], object_ids[:256], [1, object_ids[0]], [
|
||||
object_ids[0], "a"
|
||||
], [1, object_ids[0], "a"], [
|
||||
object_ids[0], 1, object_ids[1], "a"
|
||||
], object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
|
||||
object_ids + 100 * ["a"] + object_ids]
|
||||
|
||||
for args in args_list:
|
||||
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
|
||||
|
@ -146,6 +128,7 @@ class TestLocalSchedulerClient(unittest.TestCase):
|
|||
# Launch a thread to get the task.
|
||||
def get_task():
|
||||
self.local_scheduler_client.get_task()
|
||||
|
||||
t = threading.Thread(target=get_task)
|
||||
t.start()
|
||||
# Sleep to give the thread time to call get_task.
|
||||
|
@ -170,6 +153,7 @@ class TestLocalSchedulerClient(unittest.TestCase):
|
|||
# Launch a thread to get the task.
|
||||
def get_task():
|
||||
self.local_scheduler_client.get_task()
|
||||
|
||||
t = threading.Thread(target=get_task)
|
||||
t.start()
|
||||
|
||||
|
|
|
@ -26,11 +26,12 @@ class LogMonitor(object):
|
|||
log_file_handles: A dictionary mapping the name of a log file to a file
|
||||
handle for that file.
|
||||
"""
|
||||
|
||||
def __init__(self, redis_ip_address, redis_port, node_ip_address):
|
||||
"""Initialize the log monitor object."""
|
||||
self.node_ip_address = node_ip_address
|
||||
self.redis_client = redis.StrictRedis(host=redis_ip_address,
|
||||
port=redis_port)
|
||||
self.redis_client = redis.StrictRedis(
|
||||
host=redis_ip_address, port=redis_port)
|
||||
self.log_files = {}
|
||||
self.log_file_handles = {}
|
||||
self.files_to_ignore = set()
|
||||
|
@ -38,9 +39,8 @@ class LogMonitor(object):
|
|||
def update_log_filenames(self):
|
||||
"""Get the most up-to-date list of log files to monitor from Redis."""
|
||||
num_current_log_files = len(self.log_files)
|
||||
new_log_filenames = self.redis_client.lrange(
|
||||
"LOG_FILENAMES:{}".format(self.node_ip_address),
|
||||
num_current_log_files, -1)
|
||||
new_log_filenames = self.redis_client.lrange("LOG_FILENAMES:{}".format(
|
||||
self.node_ip_address), num_current_log_files, -1)
|
||||
for log_filename in new_log_filenames:
|
||||
print("Beginning to track file {}".format(log_filename))
|
||||
assert log_filename not in self.log_files
|
||||
|
@ -78,8 +78,8 @@ class LogMonitor(object):
|
|||
# Try to open this file for the first time.
|
||||
else:
|
||||
try:
|
||||
self.log_file_handles[log_filename] = open(log_filename,
|
||||
"r")
|
||||
self.log_file_handles[log_filename] = open(
|
||||
log_filename, "r")
|
||||
except IOError as e:
|
||||
if e.errno == os.errno.EMFILE:
|
||||
print("Warning: Ignoring {} because there are too "
|
||||
|
@ -106,12 +106,19 @@ class LogMonitor(object):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description=("Parse Redis server for the "
|
||||
parser = argparse.ArgumentParser(
|
||||
description=("Parse Redis server for the "
|
||||
"log monitor to connect "
|
||||
"to."))
|
||||
parser.add_argument("--redis-address", required=True, type=str,
|
||||
parser.add_argument(
|
||||
"--redis-address",
|
||||
required=True,
|
||||
type=str,
|
||||
help="The address to use for Redis.")
|
||||
parser.add_argument("--node-ip-address", required=True, type=str,
|
||||
parser.add_argument(
|
||||
"--node-ip-address",
|
||||
required=True,
|
||||
type=str,
|
||||
help="The IP address of the node this process is on.")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
|
@ -100,8 +100,8 @@ class Monitor(object):
|
|||
self.local_scheduler_id_to_ip_map = dict()
|
||||
self.load_metrics = LoadMetrics()
|
||||
if autoscaling_config:
|
||||
self.autoscaler = StandardAutoscaler(
|
||||
autoscaling_config, self.load_metrics)
|
||||
self.autoscaler = StandardAutoscaler(autoscaling_config,
|
||||
self.load_metrics)
|
||||
else:
|
||||
self.autoscaler = None
|
||||
|
||||
|
@ -160,11 +160,9 @@ class Monitor(object):
|
|||
# task as lost.
|
||||
key = binary_to_object_id(hex_to_binary(task_id))
|
||||
ok = self.state._execute_command(
|
||||
key, "RAY.TASK_TABLE_UPDATE",
|
||||
hex_to_binary(task_id),
|
||||
key, "RAY.TASK_TABLE_UPDATE", hex_to_binary(task_id),
|
||||
ray.experimental.state.TASK_STATUS_LOST, NIL_ID,
|
||||
task["ExecutionDependenciesString"],
|
||||
task["SpillbackCount"])
|
||||
task["ExecutionDependenciesString"], task["SpillbackCount"])
|
||||
if ok != b"OK":
|
||||
log.warn("Failed to update lost task for dead scheduler.")
|
||||
num_tasks_updated += 1
|
||||
|
@ -428,8 +426,8 @@ class Monitor(object):
|
|||
"""
|
||||
message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0)
|
||||
driver_id = message.DriverId()
|
||||
log.info(
|
||||
"Driver {} has been removed.".format(binary_to_hex(driver_id)))
|
||||
log.info("Driver {} has been removed.".format(
|
||||
binary_to_hex(driver_id)))
|
||||
|
||||
self._clean_up_entries_for_driver(driver_id)
|
||||
|
||||
|
@ -571,7 +569,8 @@ class Monitor(object):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description=("Parse Redis server for the "
|
||||
parser = argparse.ArgumentParser(
|
||||
description=("Parse Redis server for the "
|
||||
"monitor to connect to."))
|
||||
parser.add_argument(
|
||||
"--redis-address",
|
||||
|
|
|
@ -5,5 +5,6 @@ from __future__ import print_function
|
|||
from ray.plasma.plasma import (start_plasma_store, start_plasma_manager,
|
||||
DEFAULT_PLASMA_STORE_MEMORY)
|
||||
|
||||
__all__ = ["start_plasma_store", "start_plasma_manager",
|
||||
"DEFAULT_PLASMA_STORE_MEMORY"]
|
||||
__all__ = [
|
||||
"start_plasma_store", "start_plasma_manager", "DEFAULT_PLASMA_STORE_MEMORY"
|
||||
]
|
||||
|
|
|
@ -8,13 +8,13 @@ import subprocess
|
|||
import sys
|
||||
import time
|
||||
|
||||
__all__ = [
|
||||
"start_plasma_store", "start_plasma_manager", "DEFAULT_PLASMA_STORE_MEMORY"
|
||||
]
|
||||
|
||||
__all__ = ["start_plasma_store", "start_plasma_manager",
|
||||
"DEFAULT_PLASMA_STORE_MEMORY"]
|
||||
PLASMA_WAIT_TIMEOUT = 2**30
|
||||
|
||||
PLASMA_WAIT_TIMEOUT = 2 ** 30
|
||||
|
||||
DEFAULT_PLASMA_STORE_MEMORY = 10 ** 9
|
||||
DEFAULT_PLASMA_STORE_MEMORY = 10**9
|
||||
|
||||
|
||||
def random_name():
|
||||
|
@ -22,9 +22,12 @@ def random_name():
|
|||
|
||||
|
||||
def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
|
||||
use_valgrind=False, use_profiler=False,
|
||||
stdout_file=None, stderr_file=None,
|
||||
plasma_directory=None, huge_pages=False):
|
||||
use_valgrind=False,
|
||||
use_profiler=False,
|
||||
stdout_file=None,
|
||||
stderr_file=None,
|
||||
plasma_directory=None,
|
||||
huge_pages=False):
|
||||
"""Start a plasma store process.
|
||||
|
||||
Args:
|
||||
|
@ -48,8 +51,8 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
|
|||
if use_valgrind and use_profiler:
|
||||
raise Exception("Cannot use valgrind and profiler at the same time.")
|
||||
|
||||
if huge_pages and not (sys.platform == "linux" or
|
||||
sys.platform == "linux2"):
|
||||
if huge_pages and not (sys.platform == "linux"
|
||||
or sys.platform == "linux2"):
|
||||
raise Exception("The huge_pages argument is only supported on "
|
||||
"Linux.")
|
||||
|
||||
|
@ -57,29 +60,33 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
|
|||
raise Exception("If huge_pages is True, then the "
|
||||
"plasma_directory argument must be provided.")
|
||||
|
||||
plasma_store_executable = os.path.join(os.path.abspath(
|
||||
os.path.dirname(__file__)),
|
||||
plasma_store_executable = os.path.join(
|
||||
os.path.abspath(os.path.dirname(__file__)),
|
||||
"../core/src/plasma/plasma_store")
|
||||
plasma_store_name = "/tmp/plasma_store{}".format(random_name())
|
||||
command = [plasma_store_executable,
|
||||
"-s", plasma_store_name,
|
||||
"-m", str(plasma_store_memory)]
|
||||
command = [
|
||||
plasma_store_executable, "-s", plasma_store_name, "-m",
|
||||
str(plasma_store_memory)
|
||||
]
|
||||
if plasma_directory is not None:
|
||||
command += ["-d", plasma_directory]
|
||||
if huge_pages:
|
||||
command += ["-h"]
|
||||
if use_valgrind:
|
||||
pid = subprocess.Popen(["valgrind",
|
||||
"--track-origins=yes",
|
||||
"--leak-check=full",
|
||||
"--show-leak-kinds=all",
|
||||
"--leak-check-heuristics=stdstring",
|
||||
"--error-exitcode=1"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
pid = subprocess.Popen(
|
||||
[
|
||||
"valgrind", "--track-origins=yes", "--leak-check=full",
|
||||
"--show-leak-kinds=all", "--leak-check-heuristics=stdstring",
|
||||
"--error-exitcode=1"
|
||||
] + command,
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
time.sleep(1.0)
|
||||
elif use_profiler:
|
||||
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
pid = subprocess.Popen(
|
||||
["valgrind", "--tool=callgrind"] + command,
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
time.sleep(1.0)
|
||||
else:
|
||||
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
|
||||
|
@ -91,10 +98,14 @@ def new_port():
|
|||
return random.randint(10000, 65535)
|
||||
|
||||
|
||||
def start_plasma_manager(store_name, redis_address,
|
||||
node_ip_address="127.0.0.1", plasma_manager_port=None,
|
||||
num_retries=20, use_valgrind=False,
|
||||
run_profiler=False, stdout_file=None,
|
||||
def start_plasma_manager(store_name,
|
||||
redis_address,
|
||||
node_ip_address="127.0.0.1",
|
||||
plasma_manager_port=None,
|
||||
num_retries=20,
|
||||
use_valgrind=False,
|
||||
run_profiler=False,
|
||||
stdout_file=None,
|
||||
stderr_file=None):
|
||||
"""Start a plasma manager and return the ports it listens on.
|
||||
|
||||
|
@ -133,27 +144,35 @@ def start_plasma_manager(store_name, redis_address,
|
|||
while counter < num_retries:
|
||||
if counter > 0:
|
||||
print("Plasma manager failed to start, retrying now.")
|
||||
command = [plasma_manager_executable,
|
||||
"-s", store_name,
|
||||
"-m", plasma_manager_name,
|
||||
"-h", node_ip_address,
|
||||
"-p", str(plasma_manager_port),
|
||||
"-r", redis_address,
|
||||
command = [
|
||||
plasma_manager_executable,
|
||||
"-s",
|
||||
store_name,
|
||||
"-m",
|
||||
plasma_manager_name,
|
||||
"-h",
|
||||
node_ip_address,
|
||||
"-p",
|
||||
str(plasma_manager_port),
|
||||
"-r",
|
||||
redis_address,
|
||||
]
|
||||
if use_valgrind:
|
||||
process = subprocess.Popen(["valgrind",
|
||||
"--track-origins=yes",
|
||||
"--leak-check=full",
|
||||
"--show-leak-kinds=all",
|
||||
"--error-exitcode=1"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
elif run_profiler:
|
||||
process = subprocess.Popen((["valgrind", "--tool=callgrind"] +
|
||||
command),
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
else:
|
||||
process = subprocess.Popen(command, stdout=stdout_file,
|
||||
process = subprocess.Popen(
|
||||
[
|
||||
"valgrind", "--track-origins=yes", "--leak-check=full",
|
||||
"--show-leak-kinds=all", "--error-exitcode=1"
|
||||
] + command,
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
elif run_profiler:
|
||||
process = subprocess.Popen(
|
||||
(["valgrind", "--tool=callgrind"] + command),
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
else:
|
||||
process = subprocess.Popen(
|
||||
command, stdout=stdout_file, stderr=stderr_file)
|
||||
# This sleep is critical. If the plasma_manager fails to start because
|
||||
# the port is already in use, then we need it to fail within 0.1
|
||||
# seconds.
|
||||
|
|
|
@ -16,8 +16,8 @@ import unittest
|
|||
# The ray import must come before the pyarrow import because ray modifies the
|
||||
# python path so that the right version of pyarrow is found.
|
||||
import ray
|
||||
from ray.plasma.utils import (random_object_id,
|
||||
create_object_with_id, create_object)
|
||||
from ray.plasma.utils import (random_object_id, create_object_with_id,
|
||||
create_object)
|
||||
from ray import services
|
||||
import pyarrow as pa
|
||||
import pyarrow.plasma as plasma
|
||||
|
@ -30,8 +30,12 @@ def random_name():
|
|||
return str(random.randint(0, 99999999))
|
||||
|
||||
|
||||
def assert_get_object_equal(unit_test, client1, client2, object_id,
|
||||
memory_buffer=None, metadata=None):
|
||||
def assert_get_object_equal(unit_test,
|
||||
client1,
|
||||
client2,
|
||||
object_id,
|
||||
memory_buffer=None,
|
||||
metadata=None):
|
||||
client1_buff = client1.get_buffers([object_id])[0]
|
||||
client2_buff = client2.get_buffers([object_id])[0]
|
||||
client1_metadata = client1.get_metadata([object_id])[0]
|
||||
|
@ -39,27 +43,33 @@ def assert_get_object_equal(unit_test, client1, client2, object_id,
|
|||
unit_test.assertEqual(len(client1_buff), len(client2_buff))
|
||||
unit_test.assertEqual(len(client1_metadata), len(client2_metadata))
|
||||
# Check that the buffers from the two clients are the same.
|
||||
assert_equal(np.frombuffer(client1_buff, dtype="uint8"),
|
||||
assert_equal(
|
||||
np.frombuffer(client1_buff, dtype="uint8"),
|
||||
np.frombuffer(client2_buff, dtype="uint8"))
|
||||
# Check that the metadata buffers from the two clients are the same.
|
||||
assert_equal(np.frombuffer(client1_metadata, dtype="uint8"),
|
||||
assert_equal(
|
||||
np.frombuffer(client1_metadata, dtype="uint8"),
|
||||
np.frombuffer(client2_metadata, dtype="uint8"))
|
||||
# If a reference buffer was provided, check that it is the same as well.
|
||||
if memory_buffer is not None:
|
||||
assert_equal(np.frombuffer(memory_buffer, dtype="uint8"),
|
||||
assert_equal(
|
||||
np.frombuffer(memory_buffer, dtype="uint8"),
|
||||
np.frombuffer(client1_buff, dtype="uint8"))
|
||||
# If reference metadata was provided, check that it is the same as well.
|
||||
if metadata is not None:
|
||||
assert_equal(np.frombuffer(metadata, dtype="uint8"),
|
||||
assert_equal(
|
||||
np.frombuffer(metadata, dtype="uint8"),
|
||||
np.frombuffer(client1_metadata, dtype="uint8"))
|
||||
|
||||
|
||||
DEFAULT_PLASMA_STORE_MEMORY = 10 ** 9
|
||||
DEFAULT_PLASMA_STORE_MEMORY = 10**9
|
||||
|
||||
|
||||
def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
|
||||
use_valgrind=False, use_profiler=False,
|
||||
stdout_file=None, stderr_file=None):
|
||||
use_valgrind=False,
|
||||
use_profiler=False,
|
||||
stdout_file=None,
|
||||
stderr_file=None):
|
||||
"""Start a plasma store process.
|
||||
Args:
|
||||
use_valgrind (bool): True if the plasma store should be started inside
|
||||
|
@ -78,21 +88,25 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
|
|||
raise Exception("Cannot use valgrind and profiler at the same time.")
|
||||
plasma_store_executable = os.path.join(pa.__path__[0], "plasma_store")
|
||||
plasma_store_name = "/tmp/plasma_store{}".format(random_name())
|
||||
command = [plasma_store_executable,
|
||||
"-s", plasma_store_name,
|
||||
"-m", str(plasma_store_memory)]
|
||||
command = [
|
||||
plasma_store_executable, "-s", plasma_store_name, "-m",
|
||||
str(plasma_store_memory)
|
||||
]
|
||||
if use_valgrind:
|
||||
pid = subprocess.Popen(["valgrind",
|
||||
"--track-origins=yes",
|
||||
"--leak-check=full",
|
||||
"--show-leak-kinds=all",
|
||||
"--leak-check-heuristics=stdstring",
|
||||
"--error-exitcode=1"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
pid = subprocess.Popen(
|
||||
[
|
||||
"valgrind", "--track-origins=yes", "--leak-check=full",
|
||||
"--show-leak-kinds=all", "--leak-check-heuristics=stdstring",
|
||||
"--error-exitcode=1"
|
||||
] + command,
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
time.sleep(1.0)
|
||||
elif use_profiler:
|
||||
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
pid = subprocess.Popen(
|
||||
["valgrind", "--tool=callgrind"] + command,
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
time.sleep(1.0)
|
||||
else:
|
||||
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
|
||||
|
@ -104,13 +118,10 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
|
|||
|
||||
|
||||
class TestPlasmaManager(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Start two PlasmaStores.
|
||||
store_name1, self.p2 = start_plasma_store(
|
||||
use_valgrind=USE_VALGRIND)
|
||||
store_name2, self.p3 = start_plasma_store(
|
||||
use_valgrind=USE_VALGRIND)
|
||||
store_name1, self.p2 = start_plasma_store(use_valgrind=USE_VALGRIND)
|
||||
store_name2, self.p3 = start_plasma_store(use_valgrind=USE_VALGRIND)
|
||||
# Start a Redis server.
|
||||
redis_address, _ = services.start_redis("127.0.0.1")
|
||||
# Start two PlasmaManagers.
|
||||
|
@ -152,9 +163,8 @@ class TestPlasmaManager(unittest.TestCase):
|
|||
def test_fetch(self):
|
||||
for _ in range(10):
|
||||
# Create an object.
|
||||
object_id1, memory_buffer1, metadata1 = create_object(self.client1,
|
||||
2000,
|
||||
2000)
|
||||
object_id1, memory_buffer1, metadata1 = create_object(
|
||||
self.client1, 2000, 2000)
|
||||
self.client1.fetch([object_id1])
|
||||
self.assertEqual(self.client1.contains(object_id1), True)
|
||||
self.assertEqual(self.client2.contains(object_id1), False)
|
||||
|
@ -164,7 +174,10 @@ class TestPlasmaManager(unittest.TestCase):
|
|||
while not self.client2.contains(object_id1):
|
||||
self.client2.fetch([object_id1])
|
||||
# Compare the two buffers.
|
||||
assert_get_object_equal(self, self.client1, self.client2,
|
||||
assert_get_object_equal(
|
||||
self,
|
||||
self.client1,
|
||||
self.client2,
|
||||
object_id1,
|
||||
memory_buffer=memory_buffer1,
|
||||
metadata=metadata1)
|
||||
|
@ -173,9 +186,8 @@ class TestPlasmaManager(unittest.TestCase):
|
|||
object_id2 = random_object_id()
|
||||
self.client1.fetch([object_id2])
|
||||
self.assertEqual(self.client1.contains(object_id2), False)
|
||||
memory_buffer2, metadata2 = create_object_with_id(self.client2,
|
||||
object_id2,
|
||||
2000, 2000)
|
||||
memory_buffer2, metadata2 = create_object_with_id(
|
||||
self.client2, object_id2, 2000, 2000)
|
||||
# # Check that the object has been fetched.
|
||||
# self.assertEqual(self.client1.contains(object_id2), True)
|
||||
# Compare the two buffers.
|
||||
|
@ -190,66 +202,86 @@ class TestPlasmaManager(unittest.TestCase):
|
|||
for _ in range(10):
|
||||
self.client1.fetch([object_id3])
|
||||
self.client2.fetch([object_id3])
|
||||
memory_buffer3, metadata3 = create_object_with_id(self.client1,
|
||||
object_id3,
|
||||
2000, 2000)
|
||||
memory_buffer3, metadata3 = create_object_with_id(
|
||||
self.client1, object_id3, 2000, 2000)
|
||||
for _ in range(10):
|
||||
self.client1.fetch([object_id3])
|
||||
self.client2.fetch([object_id3])
|
||||
# TODO(rkn): Right now we must wait for the object table to be updated.
|
||||
while not self.client2.contains(object_id3):
|
||||
self.client2.fetch([object_id3])
|
||||
assert_get_object_equal(self, self.client1, self.client2, object_id3,
|
||||
assert_get_object_equal(
|
||||
self,
|
||||
self.client1,
|
||||
self.client2,
|
||||
object_id3,
|
||||
memory_buffer=memory_buffer3,
|
||||
metadata=metadata3)
|
||||
|
||||
def test_fetch_multiple(self):
|
||||
for _ in range(20):
|
||||
# Create two objects and a third fake one that doesn't exist.
|
||||
object_id1, memory_buffer1, metadata1 = create_object(self.client1,
|
||||
2000,
|
||||
2000)
|
||||
object_id1, memory_buffer1, metadata1 = create_object(
|
||||
self.client1, 2000, 2000)
|
||||
missing_object_id = random_object_id()
|
||||
object_id2, memory_buffer2, metadata2 = create_object(self.client1,
|
||||
2000,
|
||||
2000)
|
||||
object_id2, memory_buffer2, metadata2 = create_object(
|
||||
self.client1, 2000, 2000)
|
||||
object_ids = [object_id1, missing_object_id, object_id2]
|
||||
# Fetch the objects from the other plasma store. The second object
|
||||
# ID should timeout since it does not exist.
|
||||
# TODO(rkn): Right now we must wait for the object table to be
|
||||
# updated.
|
||||
while ((not self.client2.contains(object_id1)) or
|
||||
(not self.client2.contains(object_id2))):
|
||||
while ((not self.client2.contains(object_id1))
|
||||
or (not self.client2.contains(object_id2))):
|
||||
self.client2.fetch(object_ids)
|
||||
# Compare the buffers of the objects that do exist.
|
||||
assert_get_object_equal(self, self.client1, self.client2,
|
||||
object_id1, memory_buffer=memory_buffer1,
|
||||
assert_get_object_equal(
|
||||
self,
|
||||
self.client1,
|
||||
self.client2,
|
||||
object_id1,
|
||||
memory_buffer=memory_buffer1,
|
||||
metadata=metadata1)
|
||||
assert_get_object_equal(self, self.client1, self.client2,
|
||||
object_id2, memory_buffer=memory_buffer2,
|
||||
assert_get_object_equal(
|
||||
self,
|
||||
self.client1,
|
||||
self.client2,
|
||||
object_id2,
|
||||
memory_buffer=memory_buffer2,
|
||||
metadata=metadata2)
|
||||
# Fetch in the other direction. The fake object still does not
|
||||
# exist.
|
||||
self.client1.fetch(object_ids)
|
||||
assert_get_object_equal(self, self.client2, self.client1,
|
||||
object_id1, memory_buffer=memory_buffer1,
|
||||
assert_get_object_equal(
|
||||
self,
|
||||
self.client2,
|
||||
self.client1,
|
||||
object_id1,
|
||||
memory_buffer=memory_buffer1,
|
||||
metadata=metadata1)
|
||||
assert_get_object_equal(self, self.client2, self.client1,
|
||||
object_id2, memory_buffer=memory_buffer2,
|
||||
assert_get_object_equal(
|
||||
self,
|
||||
self.client2,
|
||||
self.client1,
|
||||
object_id2,
|
||||
memory_buffer=memory_buffer2,
|
||||
metadata=metadata2)
|
||||
|
||||
# Check that we can call fetch with duplicated object IDs.
|
||||
object_id3 = random_object_id()
|
||||
self.client1.fetch([object_id3, object_id3])
|
||||
object_id4, memory_buffer4, metadata4 = create_object(self.client1,
|
||||
2000,
|
||||
2000)
|
||||
object_id4, memory_buffer4, metadata4 = create_object(
|
||||
self.client1, 2000, 2000)
|
||||
time.sleep(0.1)
|
||||
# TODO(rkn): Right now we must wait for the object table to be updated.
|
||||
while not self.client2.contains(object_id4):
|
||||
self.client2.fetch([object_id3, object_id3, object_id4,
|
||||
object_id4])
|
||||
assert_get_object_equal(self, self.client2, self.client1, object_id4,
|
||||
self.client2.fetch(
|
||||
[object_id3, object_id3, object_id4, object_id4])
|
||||
assert_get_object_equal(
|
||||
self,
|
||||
self.client2,
|
||||
self.client1,
|
||||
object_id4,
|
||||
memory_buffer=memory_buffer4,
|
||||
metadata=metadata4)
|
||||
|
||||
|
@ -263,8 +295,8 @@ class TestPlasmaManager(unittest.TestCase):
|
|||
obj_id1 = random_object_id()
|
||||
self.client1.create(obj_id1, 1000)
|
||||
self.client1.seal(obj_id1)
|
||||
ready, waiting = self.client1.wait([obj_id1], timeout=100,
|
||||
num_returns=1)
|
||||
ready, waiting = self.client1.wait(
|
||||
[obj_id1], timeout=100, num_returns=1)
|
||||
self.assertEqual(set(ready), set([obj_id1]))
|
||||
self.assertEqual(waiting, [])
|
||||
|
||||
|
@ -273,8 +305,8 @@ class TestPlasmaManager(unittest.TestCase):
|
|||
obj_id2 = random_object_id()
|
||||
self.client1.create(obj_id2, 1000)
|
||||
# Don't seal.
|
||||
ready, waiting = self.client1.wait([obj_id2, obj_id1], timeout=100,
|
||||
num_returns=1)
|
||||
ready, waiting = self.client1.wait(
|
||||
[obj_id2, obj_id1], timeout=100, num_returns=1)
|
||||
self.assertEqual(set(ready), set([obj_id1]))
|
||||
self.assertEqual(set(waiting), set([obj_id2]))
|
||||
|
||||
|
@ -287,8 +319,8 @@ class TestPlasmaManager(unittest.TestCase):
|
|||
|
||||
t = threading.Timer(0.1, finish)
|
||||
t.start()
|
||||
ready, waiting = self.client1.wait([obj_id3, obj_id2, obj_id1],
|
||||
timeout=1000, num_returns=2)
|
||||
ready, waiting = self.client1.wait(
|
||||
[obj_id3, obj_id2, obj_id1], timeout=1000, num_returns=2)
|
||||
self.assertEqual(set(ready), set([obj_id1, obj_id3]))
|
||||
self.assertEqual(set(waiting), set([obj_id2]))
|
||||
|
||||
|
@ -319,26 +351,26 @@ class TestPlasmaManager(unittest.TestCase):
|
|||
waiting = object_ids
|
||||
retrieved = []
|
||||
for i in range(1, n + 1):
|
||||
ready, waiting = self.client1.wait(waiting, timeout=1000,
|
||||
num_returns=i)
|
||||
ready, waiting = self.client1.wait(
|
||||
waiting, timeout=1000, num_returns=i)
|
||||
self.assertEqual(len(ready), i)
|
||||
retrieved += ready
|
||||
self.assertEqual(set(retrieved), set(object_ids))
|
||||
ready, waiting = self.client1.wait(object_ids, timeout=1000,
|
||||
num_returns=len(object_ids))
|
||||
ready, waiting = self.client1.wait(
|
||||
object_ids, timeout=1000, num_returns=len(object_ids))
|
||||
self.assertEqual(set(ready), set(object_ids))
|
||||
self.assertEqual(waiting, [])
|
||||
# Try waiting for all of the object IDs on the second client.
|
||||
waiting = object_ids
|
||||
retrieved = []
|
||||
for i in range(1, n + 1):
|
||||
ready, waiting = self.client2.wait(waiting, timeout=1000,
|
||||
num_returns=i)
|
||||
ready, waiting = self.client2.wait(
|
||||
waiting, timeout=1000, num_returns=i)
|
||||
self.assertEqual(len(ready), i)
|
||||
retrieved += ready
|
||||
self.assertEqual(set(retrieved), set(object_ids))
|
||||
ready, waiting = self.client2.wait(object_ids, timeout=1000,
|
||||
num_returns=len(object_ids))
|
||||
ready, waiting = self.client2.wait(
|
||||
object_ids, timeout=1000, num_returns=len(object_ids))
|
||||
self.assertEqual(set(ready), set(object_ids))
|
||||
self.assertEqual(waiting, [])
|
||||
|
||||
|
@ -363,9 +395,8 @@ class TestPlasmaManager(unittest.TestCase):
|
|||
num_attempts = 100
|
||||
for _ in range(100):
|
||||
# Create an object.
|
||||
object_id1, memory_buffer1, metadata1 = create_object(self.client1,
|
||||
2000,
|
||||
2000)
|
||||
object_id1, memory_buffer1, metadata1 = create_object(
|
||||
self.client1, 2000, 2000)
|
||||
# Transfer the buffer to the the other Plasma store. There is a
|
||||
# race condition on the create and transfer of the object, so keep
|
||||
# trying until the object appears on the second Plasma store.
|
||||
|
@ -379,8 +410,12 @@ class TestPlasmaManager(unittest.TestCase):
|
|||
del buff
|
||||
|
||||
# Compare the two buffers.
|
||||
assert_get_object_equal(self, self.client1, self.client2,
|
||||
object_id1, memory_buffer=memory_buffer1,
|
||||
assert_get_object_equal(
|
||||
self,
|
||||
self.client1,
|
||||
self.client2,
|
||||
object_id1,
|
||||
memory_buffer=memory_buffer1,
|
||||
metadata=metadata1)
|
||||
# # Transfer the buffer again.
|
||||
# self.client1.transfer("127.0.0.1", self.port2, object_id1)
|
||||
|
@ -391,8 +426,8 @@ class TestPlasmaManager(unittest.TestCase):
|
|||
# metadata=metadata1)
|
||||
|
||||
# Create an object.
|
||||
object_id2, memory_buffer2, metadata2 = create_object(self.client2,
|
||||
20000, 20000)
|
||||
object_id2, memory_buffer2, metadata2 = create_object(
|
||||
self.client2, 20000, 20000)
|
||||
# Transfer the buffer to the the other Plasma store. There is a
|
||||
# race condition on the create and transfer of the object, so keep
|
||||
# trying until the object appears on the second Plasma store.
|
||||
|
@ -406,8 +441,12 @@ class TestPlasmaManager(unittest.TestCase):
|
|||
del buff
|
||||
|
||||
# Compare the two buffers.
|
||||
assert_get_object_equal(self, self.client1, self.client2,
|
||||
object_id2, memory_buffer=memory_buffer2,
|
||||
assert_get_object_equal(
|
||||
self,
|
||||
self.client1,
|
||||
self.client2,
|
||||
object_id2,
|
||||
memory_buffer=memory_buffer2,
|
||||
metadata=metadata2)
|
||||
|
||||
def test_illegal_functionality(self):
|
||||
|
@ -437,7 +476,6 @@ class TestPlasmaManager(unittest.TestCase):
|
|||
|
||||
|
||||
class TestPlasmaManagerRecovery(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Start a Plasma store.
|
||||
self.store_name, self.p2 = start_plasma_store(
|
||||
|
@ -446,9 +484,7 @@ class TestPlasmaManagerRecovery(unittest.TestCase):
|
|||
self.redis_address, _ = services.start_redis("127.0.0.1")
|
||||
# Start a PlasmaManagers.
|
||||
manager_name, self.p3, self.port1 = ray.plasma.start_plasma_manager(
|
||||
self.store_name,
|
||||
self.redis_address,
|
||||
use_valgrind=USE_VALGRIND)
|
||||
self.store_name, self.redis_address, use_valgrind=USE_VALGRIND)
|
||||
# Connect a PlasmaClient.
|
||||
self.client = plasma.connect(self.store_name, manager_name, 64)
|
||||
|
||||
|
@ -501,8 +537,8 @@ class TestPlasmaManagerRecovery(unittest.TestCase):
|
|||
client2 = plasma.connect(self.store_name, manager_name, 64)
|
||||
ready, waiting = [], object_ids
|
||||
while True:
|
||||
ready, waiting = client2.wait(object_ids, num_returns=num_objects,
|
||||
timeout=0)
|
||||
ready, waiting = client2.wait(
|
||||
object_ids, num_returns=num_objects, timeout=0)
|
||||
if len(ready) == len(object_ids):
|
||||
break
|
||||
|
||||
|
|
|
@ -18,8 +18,8 @@ def generate_metadata(length):
|
|||
metadata_buffer[0] = random.randint(0, 255)
|
||||
metadata_buffer[-1] = random.randint(0, 255)
|
||||
for _ in range(100):
|
||||
metadata_buffer[random.randint(0, length - 1)] = (
|
||||
random.randint(0, 255))
|
||||
metadata_buffer[random.randint(0, length - 1)] = (random.randint(
|
||||
0, 255))
|
||||
return metadata_buffer
|
||||
|
||||
|
||||
|
@ -32,7 +32,10 @@ def write_to_data_buffer(buff, length):
|
|||
array[random.randint(0, length - 1)] = random.randint(0, 255)
|
||||
|
||||
|
||||
def create_object_with_id(client, object_id, data_size, metadata_size,
|
||||
def create_object_with_id(client,
|
||||
object_id,
|
||||
data_size,
|
||||
metadata_size,
|
||||
seal=True):
|
||||
metadata = generate_metadata(metadata_size)
|
||||
memory_buffer = client.create(object_id, data_size, metadata)
|
||||
|
@ -44,7 +47,6 @@ def create_object_with_id(client, object_id, data_size, metadata_size,
|
|||
|
||||
def create_object(client, data_size, metadata_size, seal=True):
|
||||
object_id = random_object_id()
|
||||
memory_buffer, metadata = create_object_with_id(client, object_id,
|
||||
data_size, metadata_size,
|
||||
seal=seal)
|
||||
memory_buffer, metadata = create_object_with_id(
|
||||
client, object_id, data_size, metadata_size, seal=seal)
|
||||
return object_id, memory_buffer, metadata
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
"""Ray constants used in the Python code."""
|
||||
|
||||
|
||||
# Abort autoscaling if more than this number of errors are encountered. This
|
||||
# is a safety feature to prevent e.g. runaway node launches.
|
||||
AUTOSCALER_MAX_NUM_FAILURES = 5
|
||||
|
|
|
@ -7,8 +7,8 @@ import json
|
|||
import subprocess
|
||||
|
||||
import ray.services as services
|
||||
from ray.autoscaler.commands import (
|
||||
create_or_update_cluster, teardown_cluster, get_head_node_ip)
|
||||
from ray.autoscaler.commands import (create_or_update_cluster,
|
||||
teardown_cluster, get_head_node_ip)
|
||||
|
||||
|
||||
def check_no_existing_redis_clients(node_ip_address, redis_client):
|
||||
|
@ -41,58 +41,116 @@ def cli():
|
|||
|
||||
|
||||
@click.command()
|
||||
@click.option("--node-ip-address", required=False, type=str,
|
||||
@click.option(
|
||||
"--node-ip-address",
|
||||
required=False,
|
||||
type=str,
|
||||
help="the IP address of this node")
|
||||
@click.option("--redis-address", required=False, type=str,
|
||||
@click.option(
|
||||
"--redis-address",
|
||||
required=False,
|
||||
type=str,
|
||||
help="the address to use for connecting to Redis")
|
||||
@click.option("--redis-port", required=False, type=str,
|
||||
@click.option(
|
||||
"--redis-port",
|
||||
required=False,
|
||||
type=str,
|
||||
help="the port to use for starting Redis")
|
||||
@click.option("--num-redis-shards", required=False, type=int,
|
||||
@click.option(
|
||||
"--num-redis-shards",
|
||||
required=False,
|
||||
type=int,
|
||||
help=("the number of additional Redis shards to use in "
|
||||
"addition to the primary Redis shard"))
|
||||
@click.option("--redis-max-clients", required=False, type=int,
|
||||
@click.option(
|
||||
"--redis-max-clients",
|
||||
required=False,
|
||||
type=int,
|
||||
help=("If provided, attempt to configure Redis with this "
|
||||
"maximum number of clients."))
|
||||
@click.option("--redis-shard-ports", required=False, type=str,
|
||||
@click.option(
|
||||
"--redis-shard-ports",
|
||||
required=False,
|
||||
type=str,
|
||||
help="the port to use for the Redis shards other than the "
|
||||
"primary Redis shard")
|
||||
@click.option("--object-manager-port", required=False, type=int,
|
||||
@click.option(
|
||||
"--object-manager-port",
|
||||
required=False,
|
||||
type=int,
|
||||
help="the port to use for starting the object manager")
|
||||
@click.option("--object-store-memory", required=False, type=int,
|
||||
@click.option(
|
||||
"--object-store-memory",
|
||||
required=False,
|
||||
type=int,
|
||||
help="the maximum amount of memory (in bytes) to allow the "
|
||||
"object store to use")
|
||||
@click.option("--num-workers", required=False, type=int,
|
||||
@click.option(
|
||||
"--num-workers",
|
||||
required=False,
|
||||
type=int,
|
||||
help=("The initial number of workers to start on this node, "
|
||||
"note that the local scheduler may start additional "
|
||||
"workers. If you wish to control the total number of "
|
||||
"concurent tasks, then use --resources instead and "
|
||||
"specify the CPU field."))
|
||||
@click.option("--num-cpus", required=False, type=int,
|
||||
@click.option(
|
||||
"--num-cpus",
|
||||
required=False,
|
||||
type=int,
|
||||
help="the number of CPUs on this node")
|
||||
@click.option("--num-gpus", required=False, type=int,
|
||||
@click.option(
|
||||
"--num-gpus",
|
||||
required=False,
|
||||
type=int,
|
||||
help="the number of GPUs on this node")
|
||||
@click.option("--resources", required=False, default="{}", type=str,
|
||||
@click.option(
|
||||
"--resources",
|
||||
required=False,
|
||||
default="{}",
|
||||
type=str,
|
||||
help="a JSON serialized dictionary mapping resource name to "
|
||||
"resource quantity")
|
||||
@click.option("--head", is_flag=True, default=False,
|
||||
@click.option(
|
||||
"--head",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="provide this argument for the head node")
|
||||
@click.option("--no-ui", is_flag=True, default=False,
|
||||
@click.option(
|
||||
"--no-ui",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="provide this argument if the UI should not be started")
|
||||
@click.option("--block", is_flag=True, default=False,
|
||||
@click.option(
|
||||
"--block",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="provide this argument to block forever in this command")
|
||||
@click.option("--plasma-directory", required=False, type=str,
|
||||
@click.option(
|
||||
"--plasma-directory",
|
||||
required=False,
|
||||
type=str,
|
||||
help="object store directory for memory mapped files")
|
||||
@click.option("--huge-pages", is_flag=True, default=False,
|
||||
@click.option(
|
||||
"--huge-pages",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="enable support for huge pages in the object store")
|
||||
@click.option("--autoscaling-config", required=False, type=str,
|
||||
@click.option(
|
||||
"--autoscaling-config",
|
||||
required=False,
|
||||
type=str,
|
||||
help="the file that contains the autoscaling config")
|
||||
@click.option("--use-raylet", is_flag=True, default=False,
|
||||
@click.option(
|
||||
"--use-raylet",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="use the raylet code path, this is not supported yet")
|
||||
def start(node_ip_address, redis_address, redis_port, num_redis_shards,
|
||||
redis_max_clients, redis_shard_ports, object_manager_port,
|
||||
object_store_memory, num_workers, num_cpus, num_gpus, resources,
|
||||
head, no_ui, block, plasma_directory, huge_pages,
|
||||
autoscaling_config, use_raylet):
|
||||
head, no_ui, block, plasma_directory, huge_pages, autoscaling_config,
|
||||
use_raylet):
|
||||
# Convert hostnames to numerical IP address.
|
||||
if node_ip_address is not None:
|
||||
node_ip_address = services.address_to_ip(node_ip_address)
|
||||
|
@ -245,33 +303,54 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
|
|||
|
||||
@click.command()
|
||||
def stop():
|
||||
subprocess.call(["killall global_scheduler plasma_store plasma_manager "
|
||||
"local_scheduler raylet"], shell=True)
|
||||
subprocess.call(
|
||||
[
|
||||
"killall global_scheduler plasma_store plasma_manager "
|
||||
"local_scheduler raylet"
|
||||
],
|
||||
shell=True)
|
||||
|
||||
# Find the PID of the monitor process and kill it.
|
||||
subprocess.call(["kill $(ps aux | grep monitor.py | grep -v grep | "
|
||||
"awk '{ print $2 }') 2> /dev/null"], shell=True)
|
||||
subprocess.call(
|
||||
[
|
||||
"kill $(ps aux | grep monitor.py | grep -v grep | "
|
||||
"awk '{ print $2 }') 2> /dev/null"
|
||||
],
|
||||
shell=True)
|
||||
|
||||
# Find the PID of the Redis process and kill it.
|
||||
subprocess.call(["kill $(ps aux | grep redis-server | grep -v grep | "
|
||||
"awk '{ print $2 }') 2> /dev/null"], shell=True)
|
||||
subprocess.call(
|
||||
[
|
||||
"kill $(ps aux | grep redis-server | grep -v grep | "
|
||||
"awk '{ print $2 }') 2> /dev/null"
|
||||
],
|
||||
shell=True)
|
||||
|
||||
# Find the PIDs of the worker processes and kill them.
|
||||
subprocess.call(["kill -9 $(ps aux | grep default_worker.py | "
|
||||
"grep -v grep | awk '{ print $2 }') 2> /dev/null"],
|
||||
subprocess.call(
|
||||
[
|
||||
"kill -9 $(ps aux | grep default_worker.py | "
|
||||
"grep -v grep | awk '{ print $2 }') 2> /dev/null"
|
||||
],
|
||||
shell=True)
|
||||
|
||||
# Find the PID of the Ray log monitor process and kill it.
|
||||
subprocess.call(["kill $(ps aux | grep log_monitor.py | grep -v grep | "
|
||||
"awk '{ print $2 }') 2> /dev/null"], shell=True)
|
||||
subprocess.call(
|
||||
[
|
||||
"kill $(ps aux | grep log_monitor.py | grep -v grep | "
|
||||
"awk '{ print $2 }') 2> /dev/null"
|
||||
],
|
||||
shell=True)
|
||||
|
||||
# Find the PID of the jupyter process and kill it.
|
||||
try:
|
||||
from notebook.notebookapp import list_running_servers
|
||||
pids = [str(server["pid"]) for server in list_running_servers()
|
||||
if "/tmp/raylogs" in server["notebook_dir"]]
|
||||
subprocess.call(["kill {} 2> /dev/null".format(
|
||||
" ".join(pids))], shell=True)
|
||||
pids = [
|
||||
str(server["pid"]) for server in list_running_servers()
|
||||
if "/tmp/raylogs" in server["notebook_dir"]
|
||||
]
|
||||
subprocess.call(
|
||||
["kill {} 2> /dev/null".format(" ".join(pids))], shell=True)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
@ -279,29 +358,41 @@ def stop():
|
|||
@click.command()
|
||||
@click.argument("cluster_config_file", required=True, type=str)
|
||||
@click.option(
|
||||
"--no-restart", is_flag=True, default=False, help=(
|
||||
"Whether to skip restarting Ray services during the update. "
|
||||
"--no-restart",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help=("Whether to skip restarting Ray services during the update. "
|
||||
"This avoids interrupting running jobs."))
|
||||
@click.option(
|
||||
"--min-workers", required=False, type=int, help=(
|
||||
"Override the configured min worker node count for the cluster."))
|
||||
"--min-workers",
|
||||
required=False,
|
||||
type=int,
|
||||
help=("Override the configured min worker node count for the cluster."))
|
||||
@click.option(
|
||||
"--max-workers", required=False, type=int, help=(
|
||||
"Override the configured max worker node count for the cluster."))
|
||||
"--max-workers",
|
||||
required=False,
|
||||
type=int,
|
||||
help=("Override the configured max worker node count for the cluster."))
|
||||
@click.option(
|
||||
"--yes", "-y", is_flag=True, default=False, help=(
|
||||
"Don't ask for confirmation."))
|
||||
def create_or_update(
|
||||
cluster_config_file, min_workers, max_workers, no_restart, yes):
|
||||
create_or_update_cluster(
|
||||
cluster_config_file, min_workers, max_workers, no_restart, yes)
|
||||
"--yes",
|
||||
"-y",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help=("Don't ask for confirmation."))
|
||||
def create_or_update(cluster_config_file, min_workers, max_workers, no_restart,
|
||||
yes):
|
||||
create_or_update_cluster(cluster_config_file, min_workers, max_workers,
|
||||
no_restart, yes)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("cluster_config_file", required=True, type=str)
|
||||
@click.option(
|
||||
"--yes", "-y", is_flag=True, default=False, help=(
|
||||
"Don't ask for confirmation."))
|
||||
"--yes",
|
||||
"-y",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help=("Don't ask for confirmation."))
|
||||
def teardown(cluster_config_file, yes):
|
||||
teardown_cluster(cluster_config_file, yes)
|
||||
|
||||
|
|
|
@ -41,16 +41,12 @@ PROCESS_TYPE_WEB_UI = "web_ui"
|
|||
# important because it determines the order in which these processes will be
|
||||
# terminated when Ray exits, and certain orders will cause errors to be logged
|
||||
# to the screen.
|
||||
all_processes = OrderedDict([(PROCESS_TYPE_MONITOR, []),
|
||||
(PROCESS_TYPE_LOG_MONITOR, []),
|
||||
(PROCESS_TYPE_WORKER, []),
|
||||
(PROCESS_TYPE_RAYLET, []),
|
||||
(PROCESS_TYPE_LOCAL_SCHEDULER, []),
|
||||
(PROCESS_TYPE_PLASMA_MANAGER, []),
|
||||
(PROCESS_TYPE_PLASMA_STORE, []),
|
||||
(PROCESS_TYPE_GLOBAL_SCHEDULER, []),
|
||||
(PROCESS_TYPE_REDIS_SERVER, []),
|
||||
(PROCESS_TYPE_WEB_UI, [])],)
|
||||
all_processes = OrderedDict(
|
||||
[(PROCESS_TYPE_MONITOR, []), (PROCESS_TYPE_LOG_MONITOR, []),
|
||||
(PROCESS_TYPE_WORKER, []), (PROCESS_TYPE_RAYLET, []),
|
||||
(PROCESS_TYPE_LOCAL_SCHEDULER, []), (PROCESS_TYPE_PLASMA_MANAGER, []),
|
||||
(PROCESS_TYPE_PLASMA_STORE, []), (PROCESS_TYPE_GLOBAL_SCHEDULER, []),
|
||||
(PROCESS_TYPE_REDIS_SERVER, []), (PROCESS_TYPE_WEB_UI, [])], )
|
||||
|
||||
# True if processes are run in the valgrind profiler.
|
||||
RUN_RAYLET_PROFILER = False
|
||||
|
@ -82,17 +78,15 @@ RAYLET_MONITOR_EXECUTABLE = os.path.join(
|
|||
os.path.abspath(os.path.dirname(__file__)),
|
||||
"core/src/ray/raylet/raylet_monitor")
|
||||
RAYLET_EXECUTABLE = os.path.join(
|
||||
os.path.abspath(os.path.dirname(__file__)),
|
||||
"core/src/ray/raylet/raylet")
|
||||
os.path.abspath(os.path.dirname(__file__)), "core/src/ray/raylet/raylet")
|
||||
|
||||
# ObjectStoreAddress tuples contain all information necessary to connect to an
|
||||
# object store. The fields are:
|
||||
# - name: The socket name for the object store
|
||||
# - manager_name: The socket name for the object store manager
|
||||
# - manager_port: The Internet port that the object store manager listens on
|
||||
ObjectStoreAddress = namedtuple("ObjectStoreAddress", ["name",
|
||||
"manager_name",
|
||||
"manager_port"])
|
||||
ObjectStoreAddress = namedtuple("ObjectStoreAddress",
|
||||
["name", "manager_name", "manager_port"])
|
||||
|
||||
|
||||
def address(ip_address, port):
|
||||
|
@ -133,8 +127,10 @@ def kill_process(p):
|
|||
if p.poll() is not None:
|
||||
# The process has already terminated.
|
||||
return True
|
||||
if any([RUN_RAYLET_PROFILER, RUN_LOCAL_SCHEDULER_PROFILER,
|
||||
RUN_PLASMA_MANAGER_PROFILER, RUN_PLASMA_STORE_PROFILER]):
|
||||
if any([
|
||||
RUN_RAYLET_PROFILER, RUN_LOCAL_SCHEDULER_PROFILER,
|
||||
RUN_PLASMA_MANAGER_PROFILER, RUN_PLASMA_STORE_PROFILER
|
||||
]):
|
||||
# Give process signal to write profiler data.
|
||||
os.kill(p.pid, signal.SIGINT)
|
||||
# Wait for profiling data to be written.
|
||||
|
@ -260,8 +256,8 @@ def record_log_files_in_redis(redis_address, node_ip_address, log_files):
|
|||
for log_file in log_files:
|
||||
if log_file is not None:
|
||||
redis_ip_address, redis_port = redis_address.split(":")
|
||||
redis_client = redis.StrictRedis(host=redis_ip_address,
|
||||
port=redis_port)
|
||||
redis_client = redis.StrictRedis(
|
||||
host=redis_ip_address, port=redis_port)
|
||||
# The name of the key storing the list of log filenames for this IP
|
||||
# address.
|
||||
log_file_list_key = "LOG_FILENAMES:{}".format(node_ip_address)
|
||||
|
@ -304,8 +300,8 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5):
|
|||
while counter < num_retries:
|
||||
try:
|
||||
# Run some random command and see if it worked.
|
||||
print("Waiting for redis server at {}:{} to respond..."
|
||||
.format(redis_ip_address, redis_port))
|
||||
print("Waiting for redis server at {}:{} to respond...".format(
|
||||
redis_ip_address, redis_port))
|
||||
redis_client.client_list()
|
||||
except redis.ConnectionError as e:
|
||||
# Wait a little bit.
|
||||
|
@ -427,17 +423,19 @@ def start_credis(node_ip_address,
|
|||
"""
|
||||
|
||||
components = ["credis_master", "credis_head", "credis_tail"]
|
||||
modules = [CREDIS_MASTER_MODULE, CREDIS_MEMBER_MODULE,
|
||||
CREDIS_MEMBER_MODULE]
|
||||
modules = [
|
||||
CREDIS_MASTER_MODULE, CREDIS_MEMBER_MODULE, CREDIS_MEMBER_MODULE
|
||||
]
|
||||
ports = []
|
||||
|
||||
for i, component in enumerate(components):
|
||||
stdout_file, stderr_file = new_log_files(
|
||||
component, redirect_output)
|
||||
stdout_file, stderr_file = new_log_files(component, redirect_output)
|
||||
|
||||
new_port, _ = start_redis_instance(
|
||||
node_ip_address=node_ip_address, port=port,
|
||||
stdout_file=stdout_file, stderr_file=stderr_file,
|
||||
node_ip_address=node_ip_address,
|
||||
port=port,
|
||||
stdout_file=stdout_file,
|
||||
stderr_file=stderr_file,
|
||||
cleanup=cleanup,
|
||||
module=modules[i],
|
||||
executable=CREDIS_EXECUTABLE)
|
||||
|
@ -456,8 +454,7 @@ def start_credis(node_ip_address,
|
|||
|
||||
# Register credis master in redis
|
||||
redis_ip_address, redis_port = redis_address.split(":")
|
||||
redis_client = redis.StrictRedis(host=redis_ip_address,
|
||||
port=redis_port)
|
||||
redis_client = redis.StrictRedis(host=redis_ip_address, port=redis_port)
|
||||
redis_client.set("credis_address", credis_address)
|
||||
|
||||
return credis_address
|
||||
|
@ -509,9 +506,11 @@ def start_redis(node_ip_address,
|
|||
"number of Redis shards.")
|
||||
|
||||
assigned_port, _ = start_redis_instance(
|
||||
node_ip_address=node_ip_address, port=port,
|
||||
node_ip_address=node_ip_address,
|
||||
port=port,
|
||||
redis_max_clients=redis_max_clients,
|
||||
stdout_file=redis_stdout_file, stderr_file=redis_stderr_file,
|
||||
stdout_file=redis_stdout_file,
|
||||
stderr_file=redis_stderr_file,
|
||||
cleanup=cleanup)
|
||||
if port is not None:
|
||||
assert assigned_port == port
|
||||
|
@ -540,7 +539,8 @@ def start_redis(node_ip_address,
|
|||
node_ip_address=node_ip_address,
|
||||
port=redis_shard_ports[i],
|
||||
redis_max_clients=redis_max_clients,
|
||||
stdout_file=redis_stdout_file, stderr_file=redis_stderr_file,
|
||||
stdout_file=redis_stdout_file,
|
||||
stderr_file=redis_stderr_file,
|
||||
cleanup=cleanup)
|
||||
if redis_shard_ports[i] is not None:
|
||||
assert redis_shard_port == redis_shard_ports[i]
|
||||
|
@ -601,11 +601,13 @@ def start_redis_instance(node_ip_address="127.0.0.1",
|
|||
while counter < num_retries:
|
||||
if counter > 0:
|
||||
print("Redis failed to start, retrying now.")
|
||||
p = subprocess.Popen([executable,
|
||||
"--port", str(port),
|
||||
"--loglevel", "warning",
|
||||
"--loadmodule", module],
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
p = subprocess.Popen(
|
||||
[
|
||||
executable, "--port",
|
||||
str(port), "--loglevel", "warning", "--loadmodule", module
|
||||
],
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
time.sleep(0.1)
|
||||
# Check if Redis successfully started (or at least if it the executable
|
||||
# did not exit within 0.1 seconds).
|
||||
|
@ -652,8 +654,8 @@ def start_redis_instance(node_ip_address="127.0.0.1",
|
|||
# Increase the hard and soft limits for the redis client pubsub buffer to
|
||||
# 128MB. This is a hack to make it less likely for pubsub messages to be
|
||||
# dropped and for pubsub connections to therefore be killed.
|
||||
cur_config = (redis_client.config_get("client-output-buffer-limit")
|
||||
["client-output-buffer-limit"])
|
||||
cur_config = (redis_client.config_get("client-output-buffer-limit")[
|
||||
"client-output-buffer-limit"])
|
||||
cur_config_list = cur_config.split()
|
||||
assert len(cur_config_list) == 12
|
||||
cur_config_list[8:] = ["pubsub", "134217728", "134217728", "60"]
|
||||
|
@ -662,13 +664,17 @@ def start_redis_instance(node_ip_address="127.0.0.1",
|
|||
# Put a time stamp in Redis to indicate when it was started.
|
||||
redis_client.set("redis_start_time", time.time())
|
||||
# Record the log files in Redis.
|
||||
record_log_files_in_redis(address(node_ip_address, port), node_ip_address,
|
||||
record_log_files_in_redis(
|
||||
address(node_ip_address, port), node_ip_address,
|
||||
[stdout_file, stderr_file])
|
||||
return port, p
|
||||
|
||||
|
||||
def start_log_monitor(redis_address, node_ip_address, stdout_file=None,
|
||||
stderr_file=None, cleanup=cleanup):
|
||||
def start_log_monitor(redis_address,
|
||||
node_ip_address,
|
||||
stdout_file=None,
|
||||
stderr_file=None,
|
||||
cleanup=cleanup):
|
||||
"""Start a log monitor process.
|
||||
|
||||
Args:
|
||||
|
@ -684,20 +690,25 @@ def start_log_monitor(redis_address, node_ip_address, stdout_file=None,
|
|||
Python process that imported services exits.
|
||||
"""
|
||||
log_monitor_filepath = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"log_monitor.py")
|
||||
p = subprocess.Popen([sys.executable, "-u", log_monitor_filepath,
|
||||
"--redis-address", redis_address,
|
||||
"--node-ip-address", node_ip_address],
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
os.path.dirname(os.path.abspath(__file__)), "log_monitor.py")
|
||||
p = subprocess.Popen(
|
||||
[
|
||||
sys.executable, "-u", log_monitor_filepath, "--redis-address",
|
||||
redis_address, "--node-ip-address", node_ip_address
|
||||
],
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
if cleanup:
|
||||
all_processes[PROCESS_TYPE_LOG_MONITOR].append(p)
|
||||
record_log_files_in_redis(redis_address, node_ip_address,
|
||||
[stdout_file, stderr_file])
|
||||
|
||||
|
||||
def start_global_scheduler(redis_address, node_ip_address,
|
||||
stdout_file=None, stderr_file=None, cleanup=True):
|
||||
def start_global_scheduler(redis_address,
|
||||
node_ip_address,
|
||||
stdout_file=None,
|
||||
stderr_file=None,
|
||||
cleanup=True):
|
||||
"""Start a global scheduler process.
|
||||
|
||||
Args:
|
||||
|
@ -712,7 +723,8 @@ def start_global_scheduler(redis_address, node_ip_address,
|
|||
then this process will be killed by services.cleanup() when the
|
||||
Python process that imported services exits.
|
||||
"""
|
||||
p = global_scheduler.start_global_scheduler(redis_address,
|
||||
p = global_scheduler.start_global_scheduler(
|
||||
redis_address,
|
||||
node_ip_address,
|
||||
stdout_file=stdout_file,
|
||||
stderr_file=stderr_file)
|
||||
|
@ -737,8 +749,7 @@ def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True):
|
|||
"""
|
||||
new_env = os.environ.copy()
|
||||
notebook_filepath = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"WebUI.ipynb")
|
||||
os.path.dirname(os.path.abspath(__file__)), "WebUI.ipynb")
|
||||
# We copy the notebook file so that the original doesn't get modified by
|
||||
# the user.
|
||||
random_ui_id = random.randint(0, 100000)
|
||||
|
@ -759,19 +770,23 @@ def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True):
|
|||
# We generate the token used for authentication ourselves to avoid
|
||||
# querying the jupyter server.
|
||||
token = binascii.hexlify(os.urandom(24)).decode("ascii")
|
||||
command = ["jupyter", "notebook", "--no-browser",
|
||||
"--port={}".format(port),
|
||||
command = [
|
||||
"jupyter", "notebook", "--no-browser", "--port={}".format(port),
|
||||
"--NotebookApp.iopub_data_rate_limit=10000000000",
|
||||
"--NotebookApp.open_browser=False",
|
||||
"--NotebookApp.token={}".format(token)]
|
||||
"--NotebookApp.token={}".format(token)
|
||||
]
|
||||
# If the user is root, add the --allow-root flag.
|
||||
if os.geteuid() == 0:
|
||||
command.append("--allow-root")
|
||||
|
||||
try:
|
||||
ui_process = subprocess.Popen(command, env=new_env,
|
||||
ui_process = subprocess.Popen(
|
||||
command,
|
||||
env=new_env,
|
||||
cwd=new_notebook_directory,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
except Exception:
|
||||
print("Failed to start the UI, you may need to run "
|
||||
"'pip install jupyter'.")
|
||||
|
@ -836,8 +851,8 @@ def start_local_scheduler(redis_address,
|
|||
|
||||
# Check that the number of GPUs that the local scheduler wants doesn't
|
||||
# excede the amount allowed by CUDA_VISIBLE_DEVICES.
|
||||
if ("GPU" in resources and gpu_ids is not None and
|
||||
resources["GPU"] > len(gpu_ids)):
|
||||
if ("GPU" in resources and gpu_ids is not None
|
||||
and resources["GPU"] > len(gpu_ids)):
|
||||
raise Exception("Attempting to start local scheduler with {} GPUs, "
|
||||
"but CUDA_VISIBLE_DEVICES contains {}.".format(
|
||||
resources["GPU"], gpu_ids))
|
||||
|
@ -906,21 +921,14 @@ def start_raylet(redis_address,
|
|||
"--node-ip-address={} "
|
||||
"--object-store-name={} "
|
||||
"--raylet-name={} "
|
||||
"--redis-address={}"
|
||||
.format(sys.executable,
|
||||
worker_path,
|
||||
node_ip_address,
|
||||
plasma_store_name,
|
||||
raylet_name,
|
||||
redis_address))
|
||||
"--redis-address={}".format(
|
||||
sys.executable, worker_path, node_ip_address,
|
||||
plasma_store_name, raylet_name, redis_address))
|
||||
|
||||
command = [RAYLET_EXECUTABLE,
|
||||
raylet_name,
|
||||
plasma_store_name,
|
||||
node_ip_address,
|
||||
gcs_ip_address,
|
||||
gcs_port,
|
||||
start_worker_command]
|
||||
command = [
|
||||
RAYLET_EXECUTABLE, raylet_name, plasma_store_name, node_ip_address,
|
||||
gcs_ip_address, gcs_port, start_worker_command
|
||||
]
|
||||
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
|
||||
|
||||
if cleanup:
|
||||
|
@ -931,12 +939,18 @@ def start_raylet(redis_address,
|
|||
return raylet_name
|
||||
|
||||
|
||||
def start_objstore(node_ip_address, redis_address,
|
||||
object_manager_port=None, store_stdout_file=None,
|
||||
store_stderr_file=None, manager_stdout_file=None,
|
||||
manager_stderr_file=None, objstore_memory=None,
|
||||
cleanup=True, plasma_directory=None,
|
||||
huge_pages=False, use_raylet=False):
|
||||
def start_objstore(node_ip_address,
|
||||
redis_address,
|
||||
object_manager_port=None,
|
||||
store_stdout_file=None,
|
||||
store_stderr_file=None,
|
||||
manager_stdout_file=None,
|
||||
manager_stderr_file=None,
|
||||
objstore_memory=None,
|
||||
cleanup=True,
|
||||
plasma_directory=None,
|
||||
huge_pages=False,
|
||||
use_raylet=False):
|
||||
"""This method starts an object store process.
|
||||
|
||||
Args:
|
||||
|
@ -1049,9 +1063,15 @@ def start_objstore(node_ip_address, redis_address,
|
|||
plasma_manager_port)
|
||||
|
||||
|
||||
def start_worker(node_ip_address, object_store_name, object_store_manager_name,
|
||||
local_scheduler_name, redis_address, worker_path,
|
||||
stdout_file=None, stderr_file=None, cleanup=True):
|
||||
def start_worker(node_ip_address,
|
||||
object_store_name,
|
||||
object_store_manager_name,
|
||||
local_scheduler_name,
|
||||
redis_address,
|
||||
worker_path,
|
||||
stdout_file=None,
|
||||
stderr_file=None,
|
||||
cleanup=True):
|
||||
"""This method starts a worker process.
|
||||
|
||||
Args:
|
||||
|
@ -1072,14 +1092,14 @@ def start_worker(node_ip_address, object_store_name, object_store_manager_name,
|
|||
Python process that imported services exits. This is True by
|
||||
default.
|
||||
"""
|
||||
command = [sys.executable,
|
||||
"-u",
|
||||
worker_path,
|
||||
command = [
|
||||
sys.executable, "-u", worker_path,
|
||||
"--node-ip-address=" + node_ip_address,
|
||||
"--object-store-name=" + object_store_name,
|
||||
"--object-store-manager-name=" + object_store_manager_name,
|
||||
"--local-scheduler-name=" + local_scheduler_name,
|
||||
"--redis-address=" + str(redis_address)]
|
||||
"--redis-address=" + str(redis_address)
|
||||
]
|
||||
p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
|
||||
if cleanup:
|
||||
all_processes[PROCESS_TYPE_WORKER].append(p)
|
||||
|
@ -1087,8 +1107,12 @@ def start_worker(node_ip_address, object_store_name, object_store_manager_name,
|
|||
[stdout_file, stderr_file])
|
||||
|
||||
|
||||
def start_monitor(redis_address, node_ip_address, stdout_file=None,
|
||||
stderr_file=None, cleanup=True, autoscaling_config=None):
|
||||
def start_monitor(redis_address,
|
||||
node_ip_address,
|
||||
stdout_file=None,
|
||||
stderr_file=None,
|
||||
cleanup=True,
|
||||
autoscaling_config=None):
|
||||
"""Run a process to monitor the other processes.
|
||||
|
||||
Args:
|
||||
|
@ -1105,12 +1129,12 @@ def start_monitor(redis_address, node_ip_address, stdout_file=None,
|
|||
default.
|
||||
autoscaling_config: path to autoscaling config file.
|
||||
"""
|
||||
monitor_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||
"monitor.py")
|
||||
command = [sys.executable,
|
||||
"-u",
|
||||
monitor_path,
|
||||
"--redis-address=" + str(redis_address)]
|
||||
monitor_path = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "monitor.py")
|
||||
command = [
|
||||
sys.executable, "-u", monitor_path,
|
||||
"--redis-address=" + str(redis_address)
|
||||
]
|
||||
if autoscaling_config:
|
||||
command.append("--autoscaling-config=" + str(autoscaling_config))
|
||||
p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
|
||||
|
@ -1120,8 +1144,10 @@ def start_monitor(redis_address, node_ip_address, stdout_file=None,
|
|||
[stdout_file, stderr_file])
|
||||
|
||||
|
||||
def start_raylet_monitor(redis_address, stdout_file=None,
|
||||
stderr_file=None, cleanup=True):
|
||||
def start_raylet_monitor(redis_address,
|
||||
stdout_file=None,
|
||||
stderr_file=None,
|
||||
cleanup=True):
|
||||
"""Run a process to monitor the other processes.
|
||||
|
||||
Args:
|
||||
|
@ -1136,9 +1162,7 @@ def start_raylet_monitor(redis_address, stdout_file=None,
|
|||
default.
|
||||
"""
|
||||
gcs_ip_address, gcs_port = redis_address.split(":")
|
||||
command = [RAYLET_MONITOR_EXECUTABLE,
|
||||
gcs_ip_address,
|
||||
gcs_port]
|
||||
command = [RAYLET_MONITOR_EXECUTABLE, gcs_ip_address, gcs_port]
|
||||
p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
|
||||
if cleanup:
|
||||
all_processes[PROCESS_TYPE_MONITOR].append(p)
|
||||
|
@ -1238,15 +1262,16 @@ def start_ray_processes(address_info=None,
|
|||
workers_per_local_scheduler = []
|
||||
for resource_dict in resources:
|
||||
cpus = resource_dict.get("CPU")
|
||||
workers_per_local_scheduler.append(cpus if cpus is not None
|
||||
else psutil.cpu_count())
|
||||
workers_per_local_scheduler.append(cpus if cpus is not None else
|
||||
psutil.cpu_count())
|
||||
|
||||
if address_info is None:
|
||||
address_info = {}
|
||||
address_info["node_ip_address"] = node_ip_address
|
||||
|
||||
if worker_path is None:
|
||||
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||
worker_path = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"workers/default_worker.py")
|
||||
|
||||
# Start Redis if there isn't already an instance running. TODO(rkn): We are
|
||||
|
@ -1257,7 +1282,8 @@ def start_ray_processes(address_info=None,
|
|||
redis_shards = address_info.get("redis_shards", [])
|
||||
if redis_address is None:
|
||||
redis_address, redis_shards = start_redis(
|
||||
node_ip_address, port=redis_port,
|
||||
node_ip_address,
|
||||
port=redis_port,
|
||||
redis_shard_ports=redis_shard_ports,
|
||||
num_redis_shards=num_redis_shards,
|
||||
redis_max_clients=redis_max_clients,
|
||||
|
@ -1274,14 +1300,16 @@ def start_ray_processes(address_info=None,
|
|||
# Start monitoring the processes.
|
||||
monitor_stdout_file, monitor_stderr_file = new_log_files(
|
||||
"monitor", redirect_output)
|
||||
start_monitor(redis_address,
|
||||
start_monitor(
|
||||
redis_address,
|
||||
node_ip_address,
|
||||
stdout_file=monitor_stdout_file,
|
||||
stderr_file=monitor_stderr_file,
|
||||
cleanup=cleanup,
|
||||
autoscaling_config=autoscaling_config)
|
||||
if use_raylet:
|
||||
start_raylet_monitor(redis_address,
|
||||
start_raylet_monitor(
|
||||
redis_address,
|
||||
stdout_file=monitor_stdout_file,
|
||||
stderr_file=monitor_stderr_file,
|
||||
cleanup=cleanup)
|
||||
|
@ -1289,8 +1317,8 @@ def start_ray_processes(address_info=None,
|
|||
if redis_shards == []:
|
||||
# Get redis shards from primary redis instance.
|
||||
redis_ip_address, redis_port = redis_address.split(":")
|
||||
redis_client = redis.StrictRedis(host=redis_ip_address,
|
||||
port=redis_port)
|
||||
redis_client = redis.StrictRedis(
|
||||
host=redis_ip_address, port=redis_port)
|
||||
redis_shards = redis_client.lrange("RedisShards", start=0, end=-1)
|
||||
redis_shards = [shard.decode("ascii") for shard in redis_shards]
|
||||
address_info["redis_shards"] = redis_shards
|
||||
|
@ -1299,7 +1327,8 @@ def start_ray_processes(address_info=None,
|
|||
if include_log_monitor:
|
||||
log_monitor_stdout_file, log_monitor_stderr_file = new_log_files(
|
||||
"log_monitor", redirect_output=True)
|
||||
start_log_monitor(redis_address,
|
||||
start_log_monitor(
|
||||
redis_address,
|
||||
node_ip_address,
|
||||
stdout_file=log_monitor_stdout_file,
|
||||
stderr_file=log_monitor_stderr_file,
|
||||
|
@ -1309,7 +1338,8 @@ def start_ray_processes(address_info=None,
|
|||
if include_global_scheduler and not use_raylet:
|
||||
global_scheduler_stdout_file, global_scheduler_stderr_file = (
|
||||
new_log_files("global_scheduler", redirect_output))
|
||||
start_global_scheduler(redis_address,
|
||||
start_global_scheduler(
|
||||
redis_address,
|
||||
node_ip_address,
|
||||
stdout_file=global_scheduler_stdout_file,
|
||||
stderr_file=global_scheduler_stderr_file,
|
||||
|
@ -1324,9 +1354,8 @@ def start_ray_processes(address_info=None,
|
|||
local_scheduler_socket_names = address_info["local_scheduler_socket_names"]
|
||||
|
||||
# Get the ports to use for the object managers if any are provided.
|
||||
object_manager_ports = (address_info["object_manager_ports"]
|
||||
if "object_manager_ports" in address_info
|
||||
else None)
|
||||
object_manager_ports = (address_info["object_manager_ports"] if
|
||||
"object_manager_ports" in address_info else None)
|
||||
if not isinstance(object_manager_ports, list):
|
||||
object_manager_ports = num_local_schedulers * [object_manager_ports]
|
||||
assert len(object_manager_ports) == num_local_schedulers
|
||||
|
@ -1347,7 +1376,8 @@ def start_ray_processes(address_info=None,
|
|||
manager_stdout_file=plasma_manager_stdout_file,
|
||||
manager_stderr_file=plasma_manager_stderr_file,
|
||||
objstore_memory=object_store_memory,
|
||||
cleanup=cleanup, plasma_directory=plasma_directory,
|
||||
cleanup=cleanup,
|
||||
plasma_directory=plasma_directory,
|
||||
huge_pages=huge_pages,
|
||||
use_raylet=use_raylet)
|
||||
object_store_addresses.append(object_store_address)
|
||||
|
@ -1355,8 +1385,8 @@ def start_ray_processes(address_info=None,
|
|||
|
||||
# Start any local schedulers that do not yet exist.
|
||||
if not use_raylet:
|
||||
for i in range(len(local_scheduler_socket_names),
|
||||
num_local_schedulers):
|
||||
for i in range(
|
||||
len(local_scheduler_socket_names), num_local_schedulers):
|
||||
# Connect the local scheduler to the object store at the same
|
||||
# index.
|
||||
object_store_address = object_store_addresses[i]
|
||||
|
@ -1374,7 +1404,8 @@ def start_ray_processes(address_info=None,
|
|||
# redirect the worker output, then we cannot redirect the local
|
||||
# scheduler output.
|
||||
local_scheduler_stdout_file, local_scheduler_stderr_file = (
|
||||
new_log_files("local_scheduler_{}".format(i),
|
||||
new_log_files(
|
||||
"local_scheduler_{}".format(i),
|
||||
redirect_output=redirect_worker_output))
|
||||
local_scheduler_name = start_local_scheduler(
|
||||
redis_address,
|
||||
|
@ -1398,17 +1429,18 @@ def start_ray_processes(address_info=None,
|
|||
else:
|
||||
# Start the raylet. TODO(rkn): Modify this to allow starting
|
||||
# multiple raylets on the same machine.
|
||||
raylet_stdout_file, raylet_stderr_file = (
|
||||
new_log_files("raylet_{}".format(i),
|
||||
redirect_output=redirect_output))
|
||||
address_info["raylet_socket_names"] = [start_raylet(
|
||||
raylet_stdout_file, raylet_stderr_file = (new_log_files(
|
||||
"raylet_{}".format(i), redirect_output=redirect_output))
|
||||
address_info["raylet_socket_names"] = [
|
||||
start_raylet(
|
||||
redis_address,
|
||||
node_ip_address,
|
||||
object_store_addresses[i].name,
|
||||
worker_path,
|
||||
stdout_file=None,
|
||||
stderr_file=None,
|
||||
cleanup=cleanup)]
|
||||
cleanup=cleanup)
|
||||
]
|
||||
|
||||
if not use_raylet:
|
||||
# Start any workers that the local scheduler has not already started.
|
||||
|
@ -1419,7 +1451,8 @@ def start_ray_processes(address_info=None,
|
|||
for j in range(num_local_scheduler_workers):
|
||||
worker_stdout_file, worker_stderr_file = new_log_files(
|
||||
"worker_{}_{}".format(i, j), redirect_output)
|
||||
start_worker(node_ip_address,
|
||||
start_worker(
|
||||
node_ip_address,
|
||||
object_store_address.name,
|
||||
object_store_address.manager_name,
|
||||
local_scheduler_name,
|
||||
|
@ -1431,13 +1464,14 @@ def start_ray_processes(address_info=None,
|
|||
workers_per_local_scheduler[i] -= 1
|
||||
|
||||
# Make sure that we've started all the workers.
|
||||
assert(sum(workers_per_local_scheduler) == 0)
|
||||
assert (sum(workers_per_local_scheduler) == 0)
|
||||
|
||||
# Try to start the web UI.
|
||||
if include_webui:
|
||||
ui_stdout_file, ui_stderr_file = new_log_files(
|
||||
"webui", redirect_output=True)
|
||||
address_info["webui_url"] = start_ui(redis_address,
|
||||
address_info["webui_url"] = start_ui(
|
||||
redis_address,
|
||||
stdout_file=ui_stdout_file,
|
||||
stderr_file=ui_stderr_file,
|
||||
cleanup=cleanup)
|
||||
|
@ -1500,9 +1534,12 @@ def start_ray_node(node_ip_address,
|
|||
A dictionary of the address information for the processes that were
|
||||
started.
|
||||
"""
|
||||
address_info = {"redis_address": redis_address,
|
||||
"object_manager_ports": object_manager_ports}
|
||||
return start_ray_processes(address_info=address_info,
|
||||
address_info = {
|
||||
"redis_address": redis_address,
|
||||
"object_manager_ports": object_manager_ports
|
||||
}
|
||||
return start_ray_processes(
|
||||
address_info=address_info,
|
||||
node_ip_address=node_ip_address,
|
||||
num_workers=num_workers,
|
||||
num_local_schedulers=num_local_schedulers,
|
||||
|
|
|
@ -7,11 +7,10 @@ import funcsigs
|
|||
|
||||
from ray.utils import is_cython
|
||||
|
||||
FunctionSignature = namedtuple("FunctionSignature", ["arg_names",
|
||||
"arg_defaults",
|
||||
"arg_is_positionals",
|
||||
"keyword_names",
|
||||
"function_name"])
|
||||
FunctionSignature = namedtuple("FunctionSignature", [
|
||||
"arg_names", "arg_defaults", "arg_is_positionals", "keyword_names",
|
||||
"function_name"
|
||||
])
|
||||
"""This class is used to represent a function signature.
|
||||
|
||||
Attributes:
|
||||
|
@ -49,13 +48,16 @@ def get_signature_params(func):
|
|||
# The first condition for Cython functions, the latter for Cython instance
|
||||
# methods
|
||||
if is_cython(func):
|
||||
attrs = ["__code__", "__annotations__",
|
||||
"__defaults__", "__kwdefaults__"]
|
||||
attrs = [
|
||||
"__code__", "__annotations__", "__defaults__", "__kwdefaults__"
|
||||
]
|
||||
|
||||
if all([hasattr(func, attr) for attr in attrs]):
|
||||
original_func = func
|
||||
|
||||
def func(): return
|
||||
def func():
|
||||
return
|
||||
|
||||
for attr in attrs:
|
||||
setattr(func, attr, getattr(original_func, attr))
|
||||
else:
|
||||
|
@ -130,8 +132,8 @@ def extract_signature(func, ignore_first=False):
|
|||
if ignore_first:
|
||||
if len(sig_params) == 0:
|
||||
raise Exception("Methods must take a 'self' argument, but the "
|
||||
"method '{}' does not have one."
|
||||
.format(func.__name__))
|
||||
"method '{}' does not have one.".format(
|
||||
func.__name__))
|
||||
sig_params = sig_params[1:]
|
||||
|
||||
# Extract the names of the keyword arguments.
|
||||
|
@ -183,8 +185,8 @@ def extend_args(function_signature, args, kwargs):
|
|||
for keyword_name in kwargs:
|
||||
if keyword_name not in keyword_names:
|
||||
raise Exception("The name '{}' is not a valid keyword argument "
|
||||
"for the function '{}'."
|
||||
.format(keyword_name, function_name))
|
||||
"for the function '{}'.".format(
|
||||
keyword_name, function_name))
|
||||
|
||||
# Fill in the remaining arguments.
|
||||
zipped_info = list(zip(arg_names, arg_defaults,
|
||||
|
@ -201,12 +203,12 @@ def extend_args(function_signature, args, kwargs):
|
|||
# can be omitted.
|
||||
if not is_positional:
|
||||
raise Exception("No value was provided for the argument "
|
||||
"'{}' for the function '{}'."
|
||||
.format(keyword_name, function_name))
|
||||
"'{}' for the function '{}'.".format(
|
||||
keyword_name, function_name))
|
||||
|
||||
too_many_arguments = (len(args) > len(arg_names) and
|
||||
(len(arg_is_positionals) == 0 or
|
||||
not arg_is_positionals[-1]))
|
||||
too_many_arguments = (len(args) > len(arg_names)
|
||||
and (len(arg_is_positionals) == 0
|
||||
or not arg_is_positionals[-1]))
|
||||
if too_many_arguments:
|
||||
raise Exception("Too many arguments were passed to the function '{}'"
|
||||
.format(function_name))
|
||||
|
|
|
@ -13,6 +13,7 @@ import numpy as np
|
|||
def handle_int(a, b):
|
||||
return a + 1, b + 1
|
||||
|
||||
|
||||
# Test timing
|
||||
|
||||
|
||||
|
@ -25,6 +26,7 @@ def empty_function():
|
|||
def trivial_function():
|
||||
return 1
|
||||
|
||||
|
||||
# Test keyword arguments
|
||||
|
||||
|
||||
|
@ -42,6 +44,7 @@ def keyword_fct2(a="hello", b="world"):
|
|||
def keyword_fct3(a, b, c="hello", d="world"):
|
||||
return "{} {} {} {}".format(a, b, c, d)
|
||||
|
||||
|
||||
# Test variable numbers of arguments
|
||||
|
||||
|
||||
|
@ -56,17 +59,21 @@ def varargs_fct2(a, *b):
|
|||
|
||||
|
||||
try:
|
||||
|
||||
@ray.remote
|
||||
def kwargs_throw_exception(**c):
|
||||
return ()
|
||||
|
||||
kwargs_exception_thrown = False
|
||||
except Exception:
|
||||
kwargs_exception_thrown = True
|
||||
|
||||
try:
|
||||
|
||||
@ray.remote
|
||||
def varargs_and_kwargs_throw_exception(a, b="hi", *c):
|
||||
return "{} {} {}".format(a, b, c)
|
||||
|
||||
varargs_and_kwargs_exception_thrown = False
|
||||
except Exception:
|
||||
varargs_and_kwargs_exception_thrown = True
|
||||
|
@ -88,6 +95,7 @@ def throw_exception_fct2():
|
|||
def throw_exception_fct3(x):
|
||||
raise Exception("Test function 3 intentionally failed.")
|
||||
|
||||
|
||||
# test Python mode
|
||||
|
||||
|
||||
|
@ -101,6 +109,7 @@ def python_mode_g(x):
|
|||
x[0] = 1
|
||||
return x
|
||||
|
||||
|
||||
# test no return values
|
||||
|
||||
|
||||
|
|
|
@ -48,8 +48,8 @@ def _wait_for_nodes_to_join(num_nodes, timeout=20):
|
|||
if num_ready_nodes > num_nodes:
|
||||
# Too many nodes have joined. Something must be wrong.
|
||||
raise Exception("{} nodes have joined the cluster, but we were "
|
||||
"expecting {} nodes.".format(num_ready_nodes,
|
||||
num_nodes))
|
||||
"expecting {} nodes.".format(
|
||||
num_ready_nodes, num_nodes))
|
||||
time.sleep(0.1)
|
||||
|
||||
# If we get here then we timed out.
|
||||
|
|
|
@ -9,14 +9,7 @@ from ray.tune.result import TrainingResult
|
|||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.variant_generator import grid_search
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Trainable",
|
||||
"TrainingResult",
|
||||
"TuneError",
|
||||
"grid_search",
|
||||
"register_env",
|
||||
"register_trainable",
|
||||
"run_experiments",
|
||||
"Experiment"
|
||||
"Trainable", "TrainingResult", "TuneError", "grid_search", "register_env",
|
||||
"register_trainable", "run_experiments", "Experiment"
|
||||
]
|
||||
|
|
|
@ -35,10 +35,13 @@ class AsyncHyperBandScheduler(FIFOScheduler):
|
|||
halving rate, specified by the reduction factor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, time_attr='training_iteration',
|
||||
reward_attr='episode_reward_mean', max_t=100,
|
||||
grace_period=10, reduction_factor=3, brackets=3):
|
||||
def __init__(self,
|
||||
time_attr='training_iteration',
|
||||
reward_attr='episode_reward_mean',
|
||||
max_t=100,
|
||||
grace_period=10,
|
||||
reduction_factor=3,
|
||||
brackets=3):
|
||||
assert max_t > 0, "Max (time_attr) not valid!"
|
||||
assert max_t >= grace_period, "grace_period must be <= max_t!"
|
||||
assert grace_period > 0, "grace_period must be positive!"
|
||||
|
@ -51,8 +54,10 @@ class AsyncHyperBandScheduler(FIFOScheduler):
|
|||
self._trial_info = {} # Stores Trial -> Bracket
|
||||
|
||||
# Tracks state for new trial add
|
||||
self._brackets = [_Bracket(
|
||||
grace_period, max_t, reduction_factor, s) for s in range(brackets)]
|
||||
self._brackets = [
|
||||
_Bracket(grace_period, max_t, reduction_factor, s)
|
||||
for s in range(brackets)
|
||||
]
|
||||
self._counter = 0 # for
|
||||
self._num_stopped = 0
|
||||
self._reward_attr = reward_attr
|
||||
|
@ -60,7 +65,7 @@ class AsyncHyperBandScheduler(FIFOScheduler):
|
|||
|
||||
def on_trial_add(self, trial_runner, trial):
|
||||
sizes = np.array([len(b._rungs) for b in self._brackets])
|
||||
probs = np.e ** (sizes - sizes.max())
|
||||
probs = np.e**(sizes - sizes.max())
|
||||
normalized = probs / probs.sum()
|
||||
idx = np.random.choice(len(self._brackets), p=normalized)
|
||||
self._trial_info[trial.trial_id] = self._brackets[idx]
|
||||
|
@ -71,9 +76,7 @@ class AsyncHyperBandScheduler(FIFOScheduler):
|
|||
action = TrialScheduler.STOP
|
||||
else:
|
||||
bracket = self._trial_info[trial.trial_id]
|
||||
action = bracket.on_result(
|
||||
trial,
|
||||
getattr(result, self._time_attr),
|
||||
action = bracket.on_result(trial, getattr(result, self._time_attr),
|
||||
getattr(result, self._reward_attr))
|
||||
if action == TrialScheduler.STOP:
|
||||
self._num_stopped += 1
|
||||
|
@ -81,9 +84,7 @@ class AsyncHyperBandScheduler(FIFOScheduler):
|
|||
|
||||
def on_trial_complete(self, trial_runner, trial, result):
|
||||
bracket = self._trial_info[trial.trial_id]
|
||||
bracket.on_result(
|
||||
trial,
|
||||
getattr(result, self._time_attr),
|
||||
bracket.on_result(trial, getattr(result, self._time_attr),
|
||||
getattr(result, self._reward_attr))
|
||||
del self._trial_info[trial.trial_id]
|
||||
|
||||
|
@ -91,8 +92,7 @@ class AsyncHyperBandScheduler(FIFOScheduler):
|
|||
del self._trial_info[trial.trial_id]
|
||||
|
||||
def debug_string(self):
|
||||
out = "Using AsyncHyperBand: num_stopped={}".format(
|
||||
self._num_stopped)
|
||||
out = "Using AsyncHyperBand: num_stopped={}".format(self._num_stopped)
|
||||
out += "\n" + "\n".join([b.debug_str() for b in self._brackets])
|
||||
return out
|
||||
|
||||
|
@ -111,6 +111,7 @@ class _Bracket():
|
|||
>>> b.on_result(trial3, 1, 1) # STOP
|
||||
>>> b.cutoff(b._rungs[0][1]) == 2.0
|
||||
"""
|
||||
|
||||
def __init__(self, min_t, max_t, reduction_factor, s):
|
||||
self.rf = reduction_factor
|
||||
MAX_RUNGS = int(np.log(max_t / min_t) / np.log(self.rf) - s + 1)
|
||||
|
@ -140,9 +141,10 @@ class _Bracket():
|
|||
return action
|
||||
|
||||
def debug_str(self):
|
||||
iters = " | ".join(
|
||||
["Iter {:.3f}: {}".format(milestone, self.cutoff(recorded))
|
||||
for milestone, recorded in self._rungs])
|
||||
iters = " | ".join([
|
||||
"Iter {:.3f}: {}".format(milestone, self.cutoff(recorded))
|
||||
for milestone, recorded in self._rungs
|
||||
])
|
||||
return "Bracket: " + iters
|
||||
|
||||
|
||||
|
|
|
@ -2,7 +2,6 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
|
@ -24,8 +23,8 @@ def json_to_resources(data):
|
|||
"Unknown resource type {}, must be one of {}".format(
|
||||
k, Resources._fields))
|
||||
return Resources(
|
||||
data.get("cpu", 1), data.get("gpu", 0),
|
||||
data.get("extra_cpu", 0), data.get("extra_gpu", 0))
|
||||
data.get("cpu", 1), data.get("gpu", 0), data.get("extra_cpu", 0),
|
||||
data.get("extra_gpu", 0))
|
||||
|
||||
|
||||
def resources_to_json(resources):
|
||||
|
@ -50,59 +49,85 @@ def make_parser(**kwargs):
|
|||
|
||||
# Note: keep this in sync with rllib/train.py
|
||||
parser.add_argument(
|
||||
"--run", default=None, type=str,
|
||||
"--run",
|
||||
default=None,
|
||||
type=str,
|
||||
help="The algorithm or model to train. This may refer to the name "
|
||||
"of a built-on algorithm (e.g. RLLib's DQN or PPO), or a "
|
||||
"user-defined trainable function or class registered in the "
|
||||
"tune registry.")
|
||||
parser.add_argument(
|
||||
"--stop", default="{}", type=json.loads,
|
||||
"--stop",
|
||||
default="{}",
|
||||
type=json.loads,
|
||||
help="The stopping criteria, specified in JSON. The keys may be any "
|
||||
"field in TrainingResult, e.g. "
|
||||
"'{\"time_total_s\": 600, \"timesteps_total\": 100000}' to stop "
|
||||
"after 600 seconds or 100k timesteps, whichever is reached first.")
|
||||
parser.add_argument(
|
||||
"--config", default="{}", type=json.loads,
|
||||
"--config",
|
||||
default="{}",
|
||||
type=json.loads,
|
||||
help="Algorithm-specific configuration (e.g. env, hyperparams), "
|
||||
"specified in JSON.")
|
||||
parser.add_argument(
|
||||
"--resources", help="Deprecated, use --trial-resources.",
|
||||
type=lambda v: _tune_error(
|
||||
"The `resources` argument is no longer supported. "
|
||||
"Use `trial_resources` or --trial-resources instead."))
|
||||
"--resources",
|
||||
help="Deprecated, use --trial-resources.",
|
||||
type=lambda v: _tune_error("The `resources` argument is no longer "
|
||||
"supported. Use `trial_resources` or "
|
||||
"--trial-resources instead."))
|
||||
parser.add_argument(
|
||||
"--trial-resources", default='{"cpu": 1}', type=json_to_resources,
|
||||
"--trial-resources",
|
||||
default='{"cpu": 1}',
|
||||
type=json_to_resources,
|
||||
help="Machine resources to allocate per trial, e.g. "
|
||||
"'{\"cpu\": 64, \"gpu\": 8}'. Note that GPUs will not be assigned "
|
||||
"unless you specify them here.")
|
||||
parser.add_argument(
|
||||
"--repeat", default=1, type=int,
|
||||
"--repeat",
|
||||
default=1,
|
||||
type=int,
|
||||
help="Number of times to repeat each trial.")
|
||||
parser.add_argument(
|
||||
"--local-dir", default=DEFAULT_RESULTS_DIR, type=str,
|
||||
"--local-dir",
|
||||
default=DEFAULT_RESULTS_DIR,
|
||||
type=str,
|
||||
help="Local dir to save training results to. Defaults to '{}'.".format(
|
||||
DEFAULT_RESULTS_DIR))
|
||||
parser.add_argument(
|
||||
"--upload-dir", default="", type=str,
|
||||
"--upload-dir",
|
||||
default="",
|
||||
type=str,
|
||||
help="Optional URI to sync training results to (e.g. s3://bucket).")
|
||||
parser.add_argument(
|
||||
"--checkpoint-freq", default=0, type=int,
|
||||
"--checkpoint-freq",
|
||||
default=0,
|
||||
type=int,
|
||||
help="How many training iterations between checkpoints. "
|
||||
"A value of 0 (default) disables checkpointing.")
|
||||
parser.add_argument(
|
||||
"--max-failures", default=3, type=int,
|
||||
"--max-failures",
|
||||
default=3,
|
||||
type=int,
|
||||
help="Try to recover a trial from its last checkpoint at least this "
|
||||
"many times. Only applies if checkpointing is enabled.")
|
||||
parser.add_argument(
|
||||
"--scheduler", default="FIFO", type=str,
|
||||
"--scheduler",
|
||||
default="FIFO",
|
||||
type=str,
|
||||
help="FIFO (default), MedianStopping, AsyncHyperBand,"
|
||||
"HyperBand, or HyperOpt.")
|
||||
parser.add_argument(
|
||||
"--scheduler-config", default="{}", type=json.loads,
|
||||
"--scheduler-config",
|
||||
default="{}",
|
||||
type=json.loads,
|
||||
help="Config options to pass to the scheduler.")
|
||||
|
||||
# Note: this currently only makes sense when running a single trial
|
||||
parser.add_argument("--restore", default=None, type=str,
|
||||
parser.add_argument(
|
||||
"--restore",
|
||||
default=None,
|
||||
type=str,
|
||||
help="If specified, restore from this checkpoint.")
|
||||
|
||||
return parser
|
||||
|
|
|
@ -60,18 +60,27 @@ if __name__ == "__main__":
|
|||
# `episode_reward_mean` as the
|
||||
# objective and `timesteps_total` as the time unit.
|
||||
ahb = AsyncHyperBandScheduler(
|
||||
time_attr="timesteps_total", reward_attr="episode_reward_mean",
|
||||
grace_period=5, max_t=100)
|
||||
time_attr="timesteps_total",
|
||||
reward_attr="episode_reward_mean",
|
||||
grace_period=5,
|
||||
max_t=100)
|
||||
|
||||
run_experiments({
|
||||
run_experiments(
|
||||
{
|
||||
"asynchyperband_test": {
|
||||
"run": "my_class",
|
||||
"stop": {"training_iteration": 1 if args.smoke_test else 99999},
|
||||
"stop": {
|
||||
"training_iteration": 1 if args.smoke_test else 99999
|
||||
},
|
||||
"repeat": 20,
|
||||
"trial_resources": {"cpu": 1, "gpu": 0},
|
||||
"trial_resources": {
|
||||
"cpu": 1,
|
||||
"gpu": 0
|
||||
},
|
||||
"config": {
|
||||
"width": lambda spec: 10 + int(90 * random.random()),
|
||||
"height": lambda spec: int(100 * random.random()),
|
||||
},
|
||||
}
|
||||
}, scheduler=ahb)
|
||||
},
|
||||
scheduler=ahb)
|
||||
|
|
|
@ -59,7 +59,8 @@ if __name__ == "__main__":
|
|||
# Hyperband early stopping, configured with `episode_reward_mean` as the
|
||||
# objective and `timesteps_total` as the time unit.
|
||||
hyperband = HyperBandScheduler(
|
||||
time_attr="timesteps_total", reward_attr="episode_reward_mean",
|
||||
time_attr="timesteps_total",
|
||||
reward_attr="episode_reward_mean",
|
||||
max_t=100)
|
||||
|
||||
exp = Experiment(
|
||||
|
|
|
@ -12,8 +12,8 @@ def easy_objective(config, reporter):
|
|||
time.sleep(0.2)
|
||||
reporter(
|
||||
timesteps_total=1,
|
||||
episode_reward_mean=-((config["height"]-14) ** 2
|
||||
+ abs(config["width"]-3)))
|
||||
episode_reward_mean=-(
|
||||
(config["height"] - 14)**2 + abs(config["width"] - 3)))
|
||||
time.sleep(0.2)
|
||||
|
||||
|
||||
|
@ -34,12 +34,18 @@ if __name__ == '__main__':
|
|||
'height': hp.uniform('height', -100, 100),
|
||||
}
|
||||
|
||||
config = {"my_exp": {
|
||||
config = {
|
||||
"my_exp": {
|
||||
"run": "exp",
|
||||
"repeat": 5 if args.smoke_test else 1000,
|
||||
"stop": {"training_iteration": 1},
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"config": {
|
||||
"space": space}}}
|
||||
"space": space
|
||||
}
|
||||
}
|
||||
}
|
||||
hpo_sched = HyperOptScheduler()
|
||||
|
||||
run_experiments(config, verbose=False, scheduler=hpo_sched)
|
||||
|
|
|
@ -42,8 +42,11 @@ class MyTrainableClass(Trainable):
|
|||
def _save(self, checkpoint_dir):
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
with open(path, "w") as f:
|
||||
f.write(json.dumps(
|
||||
{"timestep": self.timestep, "value": self.current_value}))
|
||||
f.write(
|
||||
json.dumps({
|
||||
"timestep": self.timestep,
|
||||
"value": self.current_value
|
||||
}))
|
||||
return path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
|
@ -63,7 +66,8 @@ if __name__ == "__main__":
|
|||
ray.init()
|
||||
|
||||
pbt = PopulationBasedTraining(
|
||||
time_attr="training_iteration", reward_attr="episode_reward_mean",
|
||||
time_attr="training_iteration",
|
||||
reward_attr="episode_reward_mean",
|
||||
perturbation_interval=10,
|
||||
hyperparam_mutations={
|
||||
# Allow for scaling-based perturbations, with a uniform backing
|
||||
|
@ -74,15 +78,23 @@ if __name__ == "__main__":
|
|||
})
|
||||
|
||||
# Try to find the best factor 1 and factor 2
|
||||
run_experiments({
|
||||
run_experiments(
|
||||
{
|
||||
"pbt_test": {
|
||||
"run": "my_class",
|
||||
"stop": {"training_iteration": 2 if args.smoke_test else 99999},
|
||||
"stop": {
|
||||
"training_iteration": 2 if args.smoke_test else 99999
|
||||
},
|
||||
"repeat": 10,
|
||||
"trial_resources": {"cpu": 1, "gpu": 0},
|
||||
"trial_resources": {
|
||||
"cpu": 1,
|
||||
"gpu": 0
|
||||
},
|
||||
"config": {
|
||||
"factor_1": 4.0,
|
||||
"factor_2": 1.0,
|
||||
},
|
||||
}
|
||||
}, scheduler=pbt, verbose=False)
|
||||
},
|
||||
scheduler=pbt,
|
||||
verbose=False)
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
"""Example of using PBT with RLlib.
|
||||
|
||||
Note that this requires a cluster with at least 8 GPUs in order for all trials
|
||||
|
@ -30,7 +29,8 @@ if __name__ == "__main__":
|
|||
return config
|
||||
|
||||
pbt = PopulationBasedTraining(
|
||||
time_attr="time_total_s", reward_attr="episode_reward_mean",
|
||||
time_attr="time_total_s",
|
||||
reward_attr="episode_reward_mean",
|
||||
perturbation_interval=120,
|
||||
resample_probability=0.25,
|
||||
# Specifies the mutations of these hyperparams
|
||||
|
@ -45,26 +45,40 @@ if __name__ == "__main__":
|
|||
custom_explore_fn=explore)
|
||||
|
||||
ray.init()
|
||||
run_experiments({
|
||||
run_experiments(
|
||||
{
|
||||
"pbt_humanoid_test": {
|
||||
"run": "PPO",
|
||||
"env": "Humanoid-v1",
|
||||
"repeat": 8,
|
||||
"trial_resources": {"cpu": 4, "gpu": 1},
|
||||
"trial_resources": {
|
||||
"cpu": 4,
|
||||
"gpu": 1
|
||||
},
|
||||
"config": {
|
||||
"kl_coeff": 1.0,
|
||||
"num_workers": 8,
|
||||
"kl_coeff":
|
||||
1.0,
|
||||
"num_workers":
|
||||
8,
|
||||
"devices": ["/gpu:0"],
|
||||
"model": {"free_log_std": True},
|
||||
"model": {
|
||||
"free_log_std": True
|
||||
},
|
||||
# These params are tuned from a fixed starting value.
|
||||
"lambda": 0.95,
|
||||
"clip_param": 0.2,
|
||||
"sgd_stepsize": 1e-4,
|
||||
"lambda":
|
||||
0.95,
|
||||
"clip_param":
|
||||
0.2,
|
||||
"sgd_stepsize":
|
||||
1e-4,
|
||||
# These params start off randomly drawn from a set.
|
||||
"num_sgd_iter": lambda spec: random.choice([10, 20, 30]),
|
||||
"sgd_batchsize": lambda spec: random.choice([128, 512, 2048]),
|
||||
"num_sgd_iter":
|
||||
lambda spec: random.choice([10, 20, 30]),
|
||||
"sgd_batchsize":
|
||||
lambda spec: random.choice([128, 512, 2048]),
|
||||
"timesteps_per_batch":
|
||||
lambda spec: random.choice([10000, 20000, 40000])
|
||||
},
|
||||
},
|
||||
}, scheduler=pbt)
|
||||
},
|
||||
scheduler=pbt)
|
||||
|
|
|
@ -29,12 +29,10 @@ from ray.tune import Trainable
|
|||
from ray.tune import TrainingResult
|
||||
from ray.tune.pbt import PopulationBasedTraining
|
||||
|
||||
|
||||
num_classes = 10
|
||||
|
||||
|
||||
class Cifar10Model(Trainable):
|
||||
|
||||
def _read_data(self):
|
||||
# The data, split between train and test sets:
|
||||
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
|
||||
|
@ -54,27 +52,51 @@ class Cifar10Model(Trainable):
|
|||
x = Input(shape=(32, 32, 3))
|
||||
y = x
|
||||
y = Convolution2D(
|
||||
filters=64, kernel_size=3, strides=1, padding="same",
|
||||
activation="relu", kernel_initializer="he_normal")(y)
|
||||
filters=64,
|
||||
kernel_size=3,
|
||||
strides=1,
|
||||
padding="same",
|
||||
activation="relu",
|
||||
kernel_initializer="he_normal")(y)
|
||||
y = Convolution2D(
|
||||
filters=64, kernel_size=3, strides=1, padding="same",
|
||||
activation="relu", kernel_initializer="he_normal")(y)
|
||||
filters=64,
|
||||
kernel_size=3,
|
||||
strides=1,
|
||||
padding="same",
|
||||
activation="relu",
|
||||
kernel_initializer="he_normal")(y)
|
||||
y = MaxPooling2D(pool_size=2, strides=2, padding="same")(y)
|
||||
|
||||
y = Convolution2D(
|
||||
filters=128, kernel_size=3, strides=1, padding="same",
|
||||
activation="relu", kernel_initializer="he_normal")(y)
|
||||
filters=128,
|
||||
kernel_size=3,
|
||||
strides=1,
|
||||
padding="same",
|
||||
activation="relu",
|
||||
kernel_initializer="he_normal")(y)
|
||||
y = Convolution2D(
|
||||
filters=128, kernel_size=3, strides=1, padding="same",
|
||||
activation="relu", kernel_initializer="he_normal")(y)
|
||||
filters=128,
|
||||
kernel_size=3,
|
||||
strides=1,
|
||||
padding="same",
|
||||
activation="relu",
|
||||
kernel_initializer="he_normal")(y)
|
||||
y = MaxPooling2D(pool_size=2, strides=2, padding="same")(y)
|
||||
|
||||
y = Convolution2D(
|
||||
filters=256, kernel_size=3, strides=1, padding="same",
|
||||
activation="relu", kernel_initializer="he_normal")(y)
|
||||
filters=256,
|
||||
kernel_size=3,
|
||||
strides=1,
|
||||
padding="same",
|
||||
activation="relu",
|
||||
kernel_initializer="he_normal")(y)
|
||||
y = Convolution2D(
|
||||
filters=256, kernel_size=3, strides=1, padding="same",
|
||||
activation="relu", kernel_initializer="he_normal")(y)
|
||||
filters=256,
|
||||
kernel_size=3,
|
||||
strides=1,
|
||||
padding="same",
|
||||
activation="relu",
|
||||
kernel_initializer="he_normal")(y)
|
||||
y = MaxPooling2D(pool_size=2, strides=2, padding="same")(y)
|
||||
|
||||
y = Flatten()(y)
|
||||
|
@ -91,7 +113,8 @@ class Cifar10Model(Trainable):
|
|||
model = self._build_model(x_train.shape[1:])
|
||||
|
||||
opt = tf.keras.optimizers.Adadelta()
|
||||
model.compile(loss="categorical_crossentropy",
|
||||
model.compile(
|
||||
loss="categorical_crossentropy",
|
||||
optimizer=opt,
|
||||
metrics=["accuracy"])
|
||||
self.model = model
|
||||
|
@ -134,8 +157,7 @@ class Cifar10Model(Trainable):
|
|||
|
||||
# loss, accuracy
|
||||
_, accuracy = self.model.evaluate(x_test, y_test, verbose=0)
|
||||
return TrainingResult(timesteps_this_iter=10,
|
||||
mean_accuracy=accuracy)
|
||||
return TrainingResult(timesteps_this_iter=10, mean_accuracy=accuracy)
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
file_path = checkpoint_dir + "/model"
|
||||
|
@ -154,15 +176,17 @@ class Cifar10Model(Trainable):
|
|||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--smoke-test",
|
||||
action="store_true",
|
||||
help="Finish quickly for testing")
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
register_trainable("train_cifar10", Cifar10Model)
|
||||
train_spec = {
|
||||
"run": "train_cifar10",
|
||||
"trial_resources": {"cpu": 1, "gpu": 1},
|
||||
"trial_resources": {
|
||||
"cpu": 1,
|
||||
"gpu": 1
|
||||
},
|
||||
"stop": {
|
||||
"mean_accuracy": 0.80,
|
||||
"timesteps_total": 300,
|
||||
|
@ -170,7 +194,7 @@ if __name__ == "__main__":
|
|||
"config": {
|
||||
"epochs": 1,
|
||||
"batch_size": 64,
|
||||
"lr": grid_search([10 ** -4, 10 ** -5]),
|
||||
"lr": grid_search([10**-4, 10**-5]),
|
||||
"decay": lambda spec: spec.config.lr / 100.0,
|
||||
"dropout": grid_search([0.25, 0.5]),
|
||||
},
|
||||
|
@ -178,17 +202,17 @@ if __name__ == "__main__":
|
|||
}
|
||||
|
||||
if args.smoke_test:
|
||||
train_spec["config"]["lr"] = 10 ** -4
|
||||
train_spec["config"]["lr"] = 10**-4
|
||||
train_spec["config"]["dropout"] = 0.5
|
||||
|
||||
ray.init()
|
||||
|
||||
pbt = PopulationBasedTraining(
|
||||
time_attr="timesteps_total", reward_attr="mean_accuracy",
|
||||
time_attr="timesteps_total",
|
||||
reward_attr="mean_accuracy",
|
||||
perturbation_interval=10,
|
||||
hyperparam_mutations={
|
||||
"dropout": lambda _: np.random.uniform(0, 1),
|
||||
})
|
||||
|
||||
run_experiments({"pbt_cifar10": train_spec},
|
||||
scheduler=pbt)
|
||||
run_experiments({"pbt_cifar10": train_spec}, scheduler=pbt)
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""A deep MNIST classifier using convolutional layers.
|
||||
|
||||
See extensive documentation at
|
||||
|
@ -90,7 +89,7 @@ def deepnn(x):
|
|||
W_fc1 = weight_variable([7 * 7 * 64, 1024])
|
||||
b_fc1 = bias_variable([1024])
|
||||
|
||||
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
|
||||
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
|
||||
h_fc1 = activation_fn(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
|
||||
|
||||
# Dropout - controls the complexity of the model, prevents co-adaptation of
|
||||
|
@ -173,7 +172,10 @@ def main(_):
|
|||
batch = mnist.train.next_batch(50)
|
||||
if i % 10 == 0:
|
||||
train_accuracy = accuracy.eval(feed_dict={
|
||||
x: batch[0], y_: batch[1], keep_prob: 1.0})
|
||||
x: batch[0],
|
||||
y_: batch[1],
|
||||
keep_prob: 1.0
|
||||
})
|
||||
|
||||
# !!! Report status to ray.tune !!!
|
||||
if status_reporter:
|
||||
|
@ -181,11 +183,17 @@ def main(_):
|
|||
timesteps_total=i, mean_accuracy=train_accuracy)
|
||||
|
||||
print('step %d, training accuracy %g' % (i, train_accuracy))
|
||||
train_step.run(
|
||||
feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
|
||||
train_step.run(feed_dict={
|
||||
x: batch[0],
|
||||
y_: batch[1],
|
||||
keep_prob: 0.5
|
||||
})
|
||||
|
||||
print('test accuracy %g' % accuracy.eval(feed_dict={
|
||||
x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
|
||||
x: mnist.test.images,
|
||||
y_: mnist.test.labels,
|
||||
keep_prob: 1.0
|
||||
}))
|
||||
|
||||
|
||||
# !!! Entrypoint for ray.tune !!!
|
||||
|
@ -195,7 +203,9 @@ def train(config={'activation': 'relu'}, reporter=None):
|
|||
activation_fn = getattr(tf.nn, config['activation'])
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',
|
||||
'--data_dir',
|
||||
type=str,
|
||||
default='/tmp/tensorflow/mnist/input_data',
|
||||
help='Directory for storing input data')
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
|
@ -228,8 +238,12 @@ if __name__ == '__main__':
|
|||
ray.init()
|
||||
|
||||
from ray.tune.async_hyperband import AsyncHyperBandScheduler
|
||||
run_experiments({'tune_mnist_test': mnist_spec},
|
||||
run_experiments(
|
||||
{
|
||||
'tune_mnist_test': mnist_spec
|
||||
},
|
||||
scheduler=AsyncHyperBandScheduler(
|
||||
time_attr="timesteps_total",
|
||||
reward_attr="mean_accuracy",
|
||||
max_t=600,))
|
||||
max_t=600,
|
||||
))
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""A deep MNIST classifier using convolutional layers.
|
||||
|
||||
See extensive documentation at
|
||||
|
@ -90,7 +89,7 @@ def deepnn(x):
|
|||
W_fc1 = weight_variable([7 * 7 * 64, 1024])
|
||||
b_fc1 = bias_variable([1024])
|
||||
|
||||
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
|
||||
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
|
||||
h_fc1 = activation_fn(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
|
||||
|
||||
# Dropout - controls the complexity of the model, prevents co-adaptation of
|
||||
|
@ -173,7 +172,10 @@ def main(_):
|
|||
batch = mnist.train.next_batch(50)
|
||||
if i % 10 == 0:
|
||||
train_accuracy = accuracy.eval(feed_dict={
|
||||
x: batch[0], y_: batch[1], keep_prob: 1.0})
|
||||
x: batch[0],
|
||||
y_: batch[1],
|
||||
keep_prob: 1.0
|
||||
})
|
||||
|
||||
# !!! Report status to ray.tune !!!
|
||||
if status_reporter:
|
||||
|
@ -181,11 +183,17 @@ def main(_):
|
|||
timesteps_total=i, mean_accuracy=train_accuracy)
|
||||
|
||||
print('step %d, training accuracy %g' % (i, train_accuracy))
|
||||
train_step.run(
|
||||
feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
|
||||
train_step.run(feed_dict={
|
||||
x: batch[0],
|
||||
y_: batch[1],
|
||||
keep_prob: 0.5
|
||||
})
|
||||
|
||||
print('test accuracy %g' % accuracy.eval(feed_dict={
|
||||
x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
|
||||
x: mnist.test.images,
|
||||
y_: mnist.test.labels,
|
||||
keep_prob: 1.0
|
||||
}))
|
||||
|
||||
|
||||
# !!! Entrypoint for ray.tune !!!
|
||||
|
@ -195,7 +203,9 @@ def train(config={'activation': 'relu'}, reporter=None):
|
|||
activation_fn = getattr(tf.nn, config['activation'])
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',
|
||||
'--data_dir',
|
||||
type=str,
|
||||
default='/tmp/tensorflow/mnist/input_data',
|
||||
help='Directory for storing input data')
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""A deep MNIST classifier using convolutional layers.
|
||||
See extensive documentation at
|
||||
https://www.tensorflow.org/get_started/mnist/pros
|
||||
|
@ -85,7 +84,7 @@ def setupCNN(x):
|
|||
W_fc1 = weight_variable([7 * 7 * 64, 1024])
|
||||
b_fc1 = bias_variable([1024])
|
||||
|
||||
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
|
||||
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
|
||||
h_fc1 = activation_fn(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
|
||||
|
||||
# Dropout - controls the complexity of the model, prevents co-adaptation of
|
||||
|
@ -182,14 +181,18 @@ class TrainMNIST(Trainable):
|
|||
self.sess.run(
|
||||
self.train_step,
|
||||
feed_dict={
|
||||
self.x: batch[0], self.y_: batch[1], self.keep_prob: 0.5
|
||||
self.x: batch[0],
|
||||
self.y_: batch[1],
|
||||
self.keep_prob: 0.5
|
||||
})
|
||||
|
||||
batch = self.mnist.train.next_batch(50)
|
||||
train_accuracy = self.sess.run(
|
||||
self.accuracy,
|
||||
feed_dict={
|
||||
self.x: batch[0], self.y_: batch[1], self.keep_prob: 1.0
|
||||
self.x: batch[0],
|
||||
self.y_: batch[1],
|
||||
self.keep_prob: 1.0
|
||||
})
|
||||
|
||||
self.iterations += 1
|
||||
|
@ -219,7 +222,7 @@ if __name__ == '__main__':
|
|||
'time_total_s': 600,
|
||||
},
|
||||
'config': {
|
||||
'learning_rate': lambda spec: 10 ** np.random.uniform(-5, -3),
|
||||
'learning_rate': lambda spec: 10**np.random.uniform(-5, -3),
|
||||
'activation': grid_search(['relu', 'elu', 'tanh']),
|
||||
},
|
||||
"repeat": 10,
|
||||
|
@ -231,8 +234,6 @@ if __name__ == '__main__':
|
|||
|
||||
ray.init()
|
||||
hyperband = HyperBandScheduler(
|
||||
time_attr="timesteps_total", reward_attr="mean_accuracy",
|
||||
max_t=100)
|
||||
time_attr="timesteps_total", reward_attr="mean_accuracy", max_t=100)
|
||||
|
||||
run_experiments(
|
||||
{'mnist_hyperband_test': mnist_spec}, scheduler=hyperband)
|
||||
run_experiments({'mnist_hyperband_test': mnist_spec}, scheduler=hyperband)
|
||||
|
|
|
@ -35,14 +35,26 @@ class Experiment(object):
|
|||
checkpoint at least this many times. Only applies if
|
||||
checkpointing is enabled. Defaults to 3.
|
||||
"""
|
||||
def __init__(self, name, run, stop=None, config=None,
|
||||
trial_resources=None, repeat=1, local_dir=None,
|
||||
upload_dir="", checkpoint_freq=0, max_failures=3):
|
||||
|
||||
def __init__(self,
|
||||
name,
|
||||
run,
|
||||
stop=None,
|
||||
config=None,
|
||||
trial_resources=None,
|
||||
repeat=1,
|
||||
local_dir=None,
|
||||
upload_dir="",
|
||||
checkpoint_freq=0,
|
||||
max_failures=3):
|
||||
spec = {
|
||||
"run": run,
|
||||
"stop": stop or {},
|
||||
"config": config or {},
|
||||
"trial_resources": trial_resources or {"cpu": 1, "gpu": 0},
|
||||
"trial_resources": trial_resources or {
|
||||
"cpu": 1,
|
||||
"gpu": 0
|
||||
},
|
||||
"repeat": repeat,
|
||||
"local_dir": local_dir or DEFAULT_RESULTS_DIR,
|
||||
"upload_dir": upload_dir,
|
||||
|
|
|
@ -91,8 +91,8 @@ class FunctionRunner(Trainable):
|
|||
for k in self._default_config:
|
||||
if k in scrubbed_config:
|
||||
del scrubbed_config[k]
|
||||
self._runner = _RunnerThread(
|
||||
entrypoint, scrubbed_config, self._status_reporter)
|
||||
self._runner = _RunnerThread(entrypoint, scrubbed_config,
|
||||
self._status_reporter)
|
||||
self._start_time = time.time()
|
||||
self._last_reported_timestep = 0
|
||||
self._runner.start()
|
||||
|
@ -104,8 +104,7 @@ class FunctionRunner(Trainable):
|
|||
|
||||
def _train(self):
|
||||
time.sleep(
|
||||
self.config.get(
|
||||
"script_min_iter_time_s",
|
||||
self.config.get("script_min_iter_time_s",
|
||||
self._default_config["script_min_iter_time_s"]))
|
||||
result = self._status_reporter._get_and_clear_status()
|
||||
while result is None:
|
||||
|
|
|
@ -102,9 +102,8 @@ class HyperOptScheduler(FIFOScheduler):
|
|||
self._hpopt_trials.refresh()
|
||||
|
||||
# Get new suggestion from
|
||||
new_trials = self.algo(
|
||||
new_ids, self.domain, self._hpopt_trials,
|
||||
self.rstate.randint(2 ** 31 - 1))
|
||||
new_trials = self.algo(new_ids, self.domain, self._hpopt_trials,
|
||||
self.rstate.randint(2**31 - 1))
|
||||
self._hpopt_trials.insert_trial_docs(new_trials)
|
||||
self._hpopt_trials.refresh()
|
||||
new_trial = new_trials[0]
|
||||
|
@ -112,8 +111,11 @@ class HyperOptScheduler(FIFOScheduler):
|
|||
suggested_config = hpo.base.spec_from_misc(new_trial["misc"])
|
||||
new_cfg.update(suggested_config)
|
||||
|
||||
kv_str = "_".join(["{}={}".format(k, str(v)[:5])
|
||||
for k, v in sorted(suggested_config.items())])
|
||||
kv_str = "_".join([
|
||||
"{}={}".format(k,
|
||||
str(v)[:5])
|
||||
for k, v in sorted(suggested_config.items())
|
||||
])
|
||||
experiment_tag = "{}_{}".format(new_trial_id, kv_str)
|
||||
|
||||
# Keep this consistent with tune.variant_generator
|
||||
|
@ -166,8 +168,7 @@ class HyperOptScheduler(FIFOScheduler):
|
|||
del self._tune_to_hp[trial]
|
||||
|
||||
def _to_hyperopt_result(self, result):
|
||||
return {"loss": -getattr(result, self._reward_attr),
|
||||
"status": "ok"}
|
||||
return {"loss": -getattr(result, self._reward_attr), "status": "ok"}
|
||||
|
||||
def _get_hyperopt_trial(self, tid):
|
||||
return [t for t in self._hpopt_trials.trials if t["tid"] == tid][0]
|
||||
|
@ -183,8 +184,9 @@ class HyperOptScheduler(FIFOScheduler):
|
|||
experiments and trials left to run. If self._max_concurrent is None,
|
||||
scheduler will add new trial if there is none that are pending.
|
||||
"""
|
||||
pending = [t for t in trial_runner.get_trials()
|
||||
if t.status == Trial.PENDING]
|
||||
pending = [
|
||||
t for t in trial_runner.get_trials() if t.status == Trial.PENDING
|
||||
]
|
||||
if self._num_trials_left <= 0:
|
||||
return
|
||||
if self._max_concurrent is None:
|
||||
|
|
|
@ -66,9 +66,10 @@ class HyperBandScheduler(FIFOScheduler):
|
|||
mentioned in the original HyperBand paper.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, time_attr='training_iteration',
|
||||
reward_attr='episode_reward_mean', max_t=81):
|
||||
def __init__(self,
|
||||
time_attr='training_iteration',
|
||||
reward_attr='episode_reward_mean',
|
||||
max_t=81):
|
||||
assert max_t > 0, "Max (time_attr) not valid!"
|
||||
FIFOScheduler.__init__(self)
|
||||
self._eta = 3
|
||||
|
@ -78,13 +79,12 @@ class HyperBandScheduler(FIFOScheduler):
|
|||
self._get_n0 = lambda s: int(
|
||||
np.ceil(self._s_max_1/(s+1) * self._eta**s))
|
||||
# bracket initial iterations
|
||||
self._get_r0 = lambda s: int((max_t*self._eta**(-s)))
|
||||
self._get_r0 = lambda s: int((max_t * self._eta**(-s)))
|
||||
self._hyperbands = [[]] # list of hyperband iterations
|
||||
self._trial_info = {} # Stores Trial -> Bracket, Band Iteration
|
||||
|
||||
# Tracks state for new trial add
|
||||
self._state = {"bracket": None,
|
||||
"band_idx": 0}
|
||||
self._state = {"bracket": None, "band_idx": 0}
|
||||
self._num_stopped = 0
|
||||
self._reward_attr = reward_attr
|
||||
self._time_attr = time_attr
|
||||
|
@ -116,9 +116,9 @@ class HyperBandScheduler(FIFOScheduler):
|
|||
cur_bracket = None
|
||||
else:
|
||||
retry = False
|
||||
cur_bracket = Bracket(
|
||||
self._time_attr, self._get_n0(s), self._get_r0(s),
|
||||
self._max_t_attr, self._eta, s)
|
||||
cur_bracket = Bracket(self._time_attr, self._get_n0(s),
|
||||
self._get_r0(s), self._max_t_attr,
|
||||
self._eta, s)
|
||||
cur_band.append(cur_bracket)
|
||||
self._state["bracket"] = cur_bracket
|
||||
|
||||
|
@ -217,11 +217,11 @@ class HyperBandScheduler(FIFOScheduler):
|
|||
"""
|
||||
|
||||
for hyperband in self._hyperbands:
|
||||
for bracket in sorted(hyperband,
|
||||
key=lambda b: b.completion_percentage()):
|
||||
for bracket in sorted(
|
||||
hyperband, key=lambda b: b.completion_percentage()):
|
||||
for trial in bracket.current_trials():
|
||||
if (trial.status == Trial.PENDING and
|
||||
trial_runner.has_resources(trial.resources)):
|
||||
if (trial.status == Trial.PENDING
|
||||
and trial_runner.has_resources(trial.resources)):
|
||||
return trial
|
||||
return None
|
||||
|
||||
|
@ -258,6 +258,7 @@ class Bracket():
|
|||
|
||||
Also keeps track of progress to ensure good scheduling.
|
||||
"""
|
||||
|
||||
def __init__(self, time_attr, max_trials, init_t_attr, max_t_attr, eta, s):
|
||||
self._live_trials = {} # maps trial -> current result
|
||||
self._all_trials = []
|
||||
|
@ -287,7 +288,8 @@ class Bracket():
|
|||
"""Checks if all iterations have completed.
|
||||
|
||||
TODO(rliaw): also check that `t.iterations == self._r`"""
|
||||
return all(self._get_result_time(result) >= self._cumul_r
|
||||
return all(
|
||||
self._get_result_time(result) >= self._cumul_r
|
||||
for result in self._live_trials.values())
|
||||
|
||||
def finished(self):
|
||||
|
@ -379,7 +381,7 @@ class Bracket():
|
|||
def _calculate_total_work(self, n, r, s):
|
||||
work = 0
|
||||
cumulative_r = r
|
||||
for i in range(s+1):
|
||||
for i in range(s + 1):
|
||||
work += int(n) * int(r)
|
||||
n /= self._eta
|
||||
n = int(np.ceil(n))
|
||||
|
@ -389,11 +391,11 @@ class Bracket():
|
|||
|
||||
def __repr__(self):
|
||||
status = ", ".join([
|
||||
"Max Size (n)={}".format(self._n),
|
||||
"Milestone (r)={}".format(self._cumul_r),
|
||||
"completed={:.1%}".format(self.completion_percentage())
|
||||
"Max Size (n)={}".format(self._n), "Milestone (r)={}".format(
|
||||
self._cumul_r), "completed={:.1%}".format(
|
||||
self.completion_percentage())
|
||||
])
|
||||
counts = collections.Counter([t.status for t in self._all_trials])
|
||||
trial_statuses = ", ".join(sorted(
|
||||
["{}: {}".format(k, v) for k, v in counts.items()]))
|
||||
trial_statuses = ", ".join(
|
||||
sorted(["{}: {}".format(k, v) for k, v in counts.items()]))
|
||||
return "Bracket({}): {{{}}} ".format(status, trial_statuses)
|
||||
|
|
|
@ -13,7 +13,6 @@ from ray.tune.cluster_info import get_ssh_key, get_ssh_user
|
|||
from ray.tune.error import TuneError
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
|
||||
|
||||
# Map from (logdir, remote_dir) -> syncer
|
||||
_syncers = {}
|
||||
|
||||
|
@ -69,8 +68,7 @@ class _LogSyncer(object):
|
|||
def sync_now(self, force=False):
|
||||
self.last_sync_time = time.time()
|
||||
if not self.worker_ip:
|
||||
print(
|
||||
"Worker ip unknown, skipping log sync for {}".format(
|
||||
print("Worker ip unknown, skipping log sync for {}".format(
|
||||
self.local_dir))
|
||||
return
|
||||
|
||||
|
@ -80,22 +78,20 @@ class _LogSyncer(object):
|
|||
ssh_key = get_ssh_key()
|
||||
ssh_user = get_ssh_user()
|
||||
if ssh_key is None or ssh_user is None:
|
||||
print(
|
||||
"Error: log sync requires cluster to be setup with "
|
||||
print("Error: log sync requires cluster to be setup with "
|
||||
"`ray create_or_update`.")
|
||||
return
|
||||
if not distutils.spawn.find_executable("rsync"):
|
||||
print("Error: log sync requires rsync to be installed.")
|
||||
return
|
||||
worker_to_local_sync_cmd = (
|
||||
("""rsync -avz -e "ssh -i '{}' -o ConnectTimeout=120s """
|
||||
worker_to_local_sync_cmd = ((
|
||||
"""rsync -avz -e "ssh -i '{}' -o ConnectTimeout=120s """
|
||||
"""-o StrictHostKeyChecking=no" '{}@{}:{}/' '{}/'""").format(
|
||||
ssh_key, ssh_user, self.worker_ip,
|
||||
pipes.quote(self.local_dir), pipes.quote(self.local_dir)))
|
||||
|
||||
if self.remote_dir:
|
||||
local_to_remote_sync_cmd = (
|
||||
"aws s3 sync '{}' '{}'".format(
|
||||
local_to_remote_sync_cmd = ("aws s3 sync '{}' '{}'".format(
|
||||
pipes.quote(self.local_dir), pipes.quote(self.remote_dir)))
|
||||
else:
|
||||
local_to_remote_sync_cmd = None
|
||||
|
|
|
@ -110,9 +110,9 @@ def to_tf_values(result, path):
|
|||
for attr, value in result.items():
|
||||
if value is not None:
|
||||
if type(value) in [int, float]:
|
||||
values.append(tf.Summary.Value(
|
||||
tag="/".join(path + [attr]),
|
||||
simple_value=value))
|
||||
values.append(
|
||||
tf.Summary.Value(
|
||||
tag="/".join(path + [attr]), simple_value=value))
|
||||
elif type(value) is dict:
|
||||
values.extend(to_tf_values(value, path + [attr]))
|
||||
return values
|
||||
|
@ -125,8 +125,8 @@ class _TFLogger(Logger):
|
|||
def on_result(self, result):
|
||||
tmp = result._asdict()
|
||||
for k in [
|
||||
"config", "pid", "timestamp", "time_total_s",
|
||||
"timesteps_total"]:
|
||||
"config", "pid", "timestamp", "time_total_s", "timesteps_total"
|
||||
]:
|
||||
del tmp[k] # not useful to tf log these
|
||||
values = to_tf_values(tmp, ["ray", "tune"])
|
||||
train_stats = tf.Summary(value=values)
|
||||
|
|
|
@ -32,10 +32,13 @@ class MedianStoppingRule(FIFOScheduler):
|
|||
time a trial reports. Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, time_attr="time_total_s", reward_attr="episode_reward_mean",
|
||||
grace_period=60.0, min_samples_required=3,
|
||||
hard_stop=True, verbose=True):
|
||||
def __init__(self,
|
||||
time_attr="time_total_s",
|
||||
reward_attr="episode_reward_mean",
|
||||
grace_period=60.0,
|
||||
min_samples_required=3,
|
||||
hard_stop=True,
|
||||
verbose=True):
|
||||
FIFOScheduler.__init__(self)
|
||||
self._stopped_trials = set()
|
||||
self._completed_trials = set()
|
||||
|
@ -103,9 +106,10 @@ class MedianStoppingRule(FIFOScheduler):
|
|||
results = self._results[trial]
|
||||
# TODO(ekl) we could do interpolation to be more precise, but for now
|
||||
# assume len(results) is large and the time diffs are roughly equal
|
||||
return np.mean(
|
||||
[getattr(r, self._reward_attr)
|
||||
for r in results if getattr(r, self._time_attr) <= t_max])
|
||||
return np.mean([
|
||||
getattr(r, self._reward_attr) for r in results
|
||||
if getattr(r, self._time_attr) <= t_max
|
||||
])
|
||||
|
||||
def _best_result(self, trial):
|
||||
results = self._results[trial]
|
||||
|
|
|
@ -11,7 +11,6 @@ from ray.tune.trial import Trial
|
|||
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.variant_generator import _format_vars
|
||||
|
||||
|
||||
# Parameters are transferred from the top PBT_QUANTILE fraction of trials to
|
||||
# the bottom PBT_QUANTILE fraction.
|
||||
PBT_QUANTILE = 0.25
|
||||
|
@ -27,8 +26,7 @@ class PBTTrialState(object):
|
|||
self.last_perturbation_time = 0
|
||||
|
||||
def __repr__(self):
|
||||
return str((
|
||||
self.last_score, self.last_checkpoint,
|
||||
return str((self.last_score, self.last_checkpoint,
|
||||
self.last_perturbation_time))
|
||||
|
||||
|
||||
|
@ -51,11 +49,12 @@ def explore(config, mutations, resample_probability, custom_explore_fn):
|
|||
config[key] not in distribution:
|
||||
new_config[key] = random.choice(distribution)
|
||||
elif random.random() > 0.5:
|
||||
new_config[key] = distribution[
|
||||
max(0, distribution.index(config[key]) - 1)]
|
||||
new_config[key] = distribution[max(
|
||||
0,
|
||||
distribution.index(config[key]) - 1)]
|
||||
else:
|
||||
new_config[key] = distribution[
|
||||
min(len(distribution) - 1,
|
||||
new_config[key] = distribution[min(
|
||||
len(distribution) - 1,
|
||||
distribution.index(config[key]) + 1)]
|
||||
else:
|
||||
if random.random() < resample_probability:
|
||||
|
@ -70,8 +69,8 @@ def explore(config, mutations, resample_probability, custom_explore_fn):
|
|||
new_config = custom_explore_fn(new_config)
|
||||
assert new_config is not None, \
|
||||
"Custom explore fn failed to return new config"
|
||||
print(
|
||||
"[explore] perturbed config from {} -> {}".format(config, new_config))
|
||||
print("[explore] perturbed config from {} -> {}".format(
|
||||
config, new_config))
|
||||
return new_config
|
||||
|
||||
|
||||
|
@ -148,10 +147,13 @@ class PopulationBasedTraining(FIFOScheduler):
|
|||
>>> run_experiments({...}, scheduler=pbt)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, time_attr="time_total_s", reward_attr="episode_reward_mean",
|
||||
perturbation_interval=60.0, hyperparam_mutations={},
|
||||
resample_probability=0.25, custom_explore_fn=None):
|
||||
def __init__(self,
|
||||
time_attr="time_total_s",
|
||||
reward_attr="episode_reward_mean",
|
||||
perturbation_interval=60.0,
|
||||
hyperparam_mutations={},
|
||||
resample_probability=0.25,
|
||||
custom_explore_fn=None):
|
||||
if not hyperparam_mutations and not custom_explore_fn:
|
||||
raise TuneError(
|
||||
"You must specify at least one of `hyperparam_mutations` or "
|
||||
|
@ -209,11 +211,10 @@ class PopulationBasedTraining(FIFOScheduler):
|
|||
if not new_state.last_checkpoint:
|
||||
print("[pbt] warn: no checkpoint for trial, skip exploit", trial)
|
||||
return
|
||||
new_config = explore(
|
||||
trial_to_clone.config, self._hyperparam_mutations,
|
||||
self._resample_probability, self._custom_explore_fn)
|
||||
print(
|
||||
"[exploit] transferring weights from trial "
|
||||
new_config = explore(trial_to_clone.config, self._hyperparam_mutations,
|
||||
self._resample_probability,
|
||||
self._custom_explore_fn)
|
||||
print("[exploit] transferring weights from trial "
|
||||
"{} (score {}) -> {} (score {})".format(
|
||||
trial_to_clone, new_state.last_score, trial,
|
||||
trial_state.last_score))
|
||||
|
@ -242,9 +243,8 @@ class PopulationBasedTraining(FIFOScheduler):
|
|||
if len(trials) <= 1:
|
||||
return [], []
|
||||
else:
|
||||
return (
|
||||
trials[:int(math.ceil(len(trials)*PBT_QUANTILE))],
|
||||
trials[int(math.floor(-len(trials)*PBT_QUANTILE)):])
|
||||
return (trials[:int(math.ceil(len(trials) * PBT_QUANTILE))],
|
||||
trials[int(math.floor(-len(trials) * PBT_QUANTILE)):])
|
||||
|
||||
def choose_trial_to_run(self, trial_runner):
|
||||
"""Ensures all trials get fair share of time (as defined by time_attr).
|
||||
|
|
|
@ -14,7 +14,8 @@ ENV_CREATOR = "env_creator"
|
|||
RLLIB_MODEL = "rllib_model"
|
||||
RLLIB_PREPROCESSOR = "rllib_preprocessor"
|
||||
KNOWN_CATEGORIES = [
|
||||
TRAINABLE_CLASS, ENV_CREATOR, RLLIB_MODEL, RLLIB_PREPROCESSOR]
|
||||
TRAINABLE_CLASS, ENV_CREATOR, RLLIB_MODEL, RLLIB_PREPROCESSOR
|
||||
]
|
||||
|
||||
|
||||
def register_trainable(name, trainable):
|
||||
|
@ -32,8 +33,8 @@ def register_trainable(name, trainable):
|
|||
if isinstance(trainable, FunctionType):
|
||||
trainable = wrap_function(trainable)
|
||||
if not issubclass(trainable, Trainable):
|
||||
raise TypeError(
|
||||
"Second argument must be convertable to Trainable", trainable)
|
||||
raise TypeError("Second argument must be convertable to Trainable",
|
||||
trainable)
|
||||
_default_registry.register(TRAINABLE_CLASS, name, trainable)
|
||||
|
||||
|
||||
|
@ -46,8 +47,7 @@ def register_env(name, env_creator):
|
|||
"""
|
||||
|
||||
if not isinstance(env_creator, FunctionType):
|
||||
raise TypeError(
|
||||
"Second argument must be a function.", env_creator)
|
||||
raise TypeError("Second argument must be a function.", env_creator)
|
||||
_default_registry.register(ENV_CREATOR, name, env_creator)
|
||||
|
||||
|
||||
|
|
|
@ -4,8 +4,6 @@ from __future__ import print_function
|
|||
|
||||
from collections import namedtuple
|
||||
import os
|
||||
|
||||
|
||||
"""
|
||||
When using ray.tune with custom training scripts, you must periodically report
|
||||
training status back to Ray by calling reporter(result).
|
||||
|
@ -18,8 +16,9 @@ In RLlib, the supplied algorithms fill in TrainingResult for you.
|
|||
# Where ray.tune writes result files by default
|
||||
DEFAULT_RESULTS_DIR = os.path.expanduser("~/ray_results")
|
||||
|
||||
|
||||
TrainingResult = namedtuple("TrainingResult", [
|
||||
TrainingResult = namedtuple(
|
||||
"TrainingResult",
|
||||
[
|
||||
# (Required) Accumulated timesteps for this entire experiment.
|
||||
"timesteps_total",
|
||||
|
||||
|
@ -50,18 +49,21 @@ TrainingResult = namedtuple("TrainingResult", [
|
|||
# (Auto-filled) The negated current training loss.
|
||||
"neg_mean_loss",
|
||||
|
||||
# (Auto-filled) Unique string identifier for this experiment. This id is
|
||||
# preserved across checkpoint / restore calls.
|
||||
# (Auto-filled) Unique string identifier for this experiment.
|
||||
# This id is preserved across checkpoint / restore calls.
|
||||
"experiment_id",
|
||||
|
||||
# (Auto-filled) The index of this training iteration, e.g. call to train().
|
||||
# (Auto-filled) The index of this training iteration,
|
||||
# e.g. call to train().
|
||||
"training_iteration",
|
||||
|
||||
# (Auto-filled) Number of timesteps in the simulator in this iteration.
|
||||
# (Auto-filled) Number of timesteps in the simulator
|
||||
# in this iteration.
|
||||
"timesteps_this_iter",
|
||||
|
||||
# (Auto-filled) Time in seconds this iteration took to run. This may be
|
||||
# overriden in order to override the system-computed time difference.
|
||||
# (Auto-filled) Time in seconds this iteration took to run. This may
|
||||
# be overriden in order to override the system-computed
|
||||
# time difference.
|
||||
"time_this_iter_s",
|
||||
|
||||
# (Auto-filled) Accumulated time in seconds for this entire experiment.
|
||||
|
@ -76,15 +78,16 @@ TrainingResult = namedtuple("TrainingResult", [
|
|||
# (Auto-filled) A UNIX timestamp of when the result was processed.
|
||||
"timestamp",
|
||||
|
||||
# (Auto-filled) The hostname of the machine hosting the training process.
|
||||
# (Auto-filled) The hostname of the machine hosting the
|
||||
# training process.
|
||||
"hostname",
|
||||
|
||||
# (Auto-filled) The node ip of the machine hosting the training process.
|
||||
# (Auto-filled) The node ip of the machine hosting the
|
||||
# training process.
|
||||
"node_ip",
|
||||
|
||||
# (Auto=filled) The current hyperparameter configuration.
|
||||
"config",
|
||||
])
|
||||
])
|
||||
|
||||
|
||||
TrainingResult.__new__.__defaults__ = (None,) * len(TrainingResult._fields)
|
||||
TrainingResult.__new__.__defaults__ = (None, ) * len(TrainingResult._fields)
|
||||
|
|
|
@ -20,7 +20,9 @@ if __name__ == "__main__":
|
|||
run_experiments({
|
||||
"test": {
|
||||
"run": "my_class",
|
||||
"stop": {"training_iteration": 1}
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
}
|
||||
}
|
||||
})
|
||||
assert 'ray.rllib' not in sys.modules, "RLlib should not be imported"
|
||||
|
|
|
@ -60,163 +60,209 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
|||
def testRewriteEnv(self):
|
||||
def train(config, reporter):
|
||||
reporter(timesteps_total=1)
|
||||
|
||||
register_trainable("f1", train)
|
||||
|
||||
[trial] = run_experiments({"foo": {
|
||||
[trial] = run_experiments({
|
||||
"foo": {
|
||||
"run": "f1",
|
||||
"env": "CartPole-v0",
|
||||
}})
|
||||
}
|
||||
})
|
||||
self.assertEqual(trial.config["env"], "CartPole-v0")
|
||||
|
||||
def testConfigPurity(self):
|
||||
def train(config, reporter):
|
||||
assert config == {"a": "b"}, config
|
||||
reporter(timesteps_total=1)
|
||||
|
||||
register_trainable("f1", train)
|
||||
run_experiments({"foo": {
|
||||
run_experiments({
|
||||
"foo": {
|
||||
"run": "f1",
|
||||
"config": {"a": "b"},
|
||||
}})
|
||||
"config": {
|
||||
"a": "b"
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
def testLogdir(self):
|
||||
def train(config, reporter):
|
||||
assert "/tmp/logdir/foo" in os.getcwd(), os.getcwd()
|
||||
reporter(timesteps_total=1)
|
||||
|
||||
register_trainable("f1", train)
|
||||
run_experiments({"foo": {
|
||||
run_experiments({
|
||||
"foo": {
|
||||
"run": "f1",
|
||||
"local_dir": "/tmp/logdir",
|
||||
"config": {"a": "b"},
|
||||
}})
|
||||
"config": {
|
||||
"a": "b"
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
def testLongFilename(self):
|
||||
def train(config, reporter):
|
||||
assert "/tmp/logdir/foo" in os.getcwd(), os.getcwd()
|
||||
reporter(timesteps_total=1)
|
||||
|
||||
register_trainable("f1", train)
|
||||
run_experiments({"foo": {
|
||||
run_experiments({
|
||||
"foo": {
|
||||
"run": "f1",
|
||||
"local_dir": "/tmp/logdir",
|
||||
"config": {
|
||||
"a" * 50: lambda spec: 5.0 / 7,
|
||||
"b" * 50: lambda spec: "long" * 40},
|
||||
}})
|
||||
"b" * 50: lambda spec: "long" * 40
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
def testBadParams(self):
|
||||
def f():
|
||||
run_experiments({"foo": {}})
|
||||
|
||||
self.assertRaises(TuneError, f)
|
||||
|
||||
def testBadParams2(self):
|
||||
def f():
|
||||
run_experiments({"foo": {
|
||||
run_experiments({
|
||||
"foo": {
|
||||
"run": "asdf",
|
||||
"bah": "this param is not allowed",
|
||||
}})
|
||||
}
|
||||
})
|
||||
|
||||
self.assertRaises(TuneError, f)
|
||||
|
||||
def testBadParams3(self):
|
||||
def f():
|
||||
run_experiments({"foo": {
|
||||
run_experiments({
|
||||
"foo": {
|
||||
"run": grid_search("invalid grid search"),
|
||||
}})
|
||||
}
|
||||
})
|
||||
|
||||
self.assertRaises(TuneError, f)
|
||||
|
||||
def testBadParams4(self):
|
||||
def f():
|
||||
run_experiments({"foo": {
|
||||
run_experiments({
|
||||
"foo": {
|
||||
"run": "asdf",
|
||||
}})
|
||||
}
|
||||
})
|
||||
|
||||
self.assertRaises(TuneError, f)
|
||||
|
||||
def testBadParams5(self):
|
||||
def f():
|
||||
run_experiments({"foo": {
|
||||
"run": "PPO",
|
||||
"stop": {"asdf": 1}
|
||||
}})
|
||||
run_experiments({"foo": {"run": "PPO", "stop": {"asdf": 1}}})
|
||||
|
||||
self.assertRaises(TuneError, f)
|
||||
|
||||
def testBadParams6(self):
|
||||
def f():
|
||||
run_experiments({"foo": {
|
||||
run_experiments({
|
||||
"foo": {
|
||||
"run": "PPO",
|
||||
"trial_resources": {"asdf": 1}
|
||||
}})
|
||||
"trial_resources": {
|
||||
"asdf": 1
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
self.assertRaises(TuneError, f)
|
||||
|
||||
def testBadReturn(self):
|
||||
def train(config, reporter):
|
||||
reporter()
|
||||
|
||||
register_trainable("f1", train)
|
||||
|
||||
def f():
|
||||
run_experiments({"foo": {
|
||||
run_experiments({
|
||||
"foo": {
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0,
|
||||
},
|
||||
}})
|
||||
}
|
||||
})
|
||||
|
||||
self.assertRaises(TuneError, f)
|
||||
|
||||
def testEarlyReturn(self):
|
||||
def train(config, reporter):
|
||||
reporter(timesteps_total=100, done=True)
|
||||
time.sleep(99999)
|
||||
|
||||
register_trainable("f1", train)
|
||||
[trial] = run_experiments({"foo": {
|
||||
[trial] = run_experiments({
|
||||
"foo": {
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0,
|
||||
},
|
||||
}})
|
||||
}
|
||||
})
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result.timesteps_total, 100)
|
||||
|
||||
def testAbruptReturn(self):
|
||||
def train(config, reporter):
|
||||
reporter(timesteps_total=100)
|
||||
|
||||
register_trainable("f1", train)
|
||||
[trial] = run_experiments({"foo": {
|
||||
[trial] = run_experiments({
|
||||
"foo": {
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0,
|
||||
},
|
||||
}})
|
||||
}
|
||||
})
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result.timesteps_total, 100)
|
||||
|
||||
def testErrorReturn(self):
|
||||
def train(config, reporter):
|
||||
raise Exception("uh oh")
|
||||
|
||||
register_trainable("f1", train)
|
||||
|
||||
def f():
|
||||
run_experiments({"foo": {
|
||||
run_experiments({
|
||||
"foo": {
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0,
|
||||
},
|
||||
}})
|
||||
}
|
||||
})
|
||||
|
||||
self.assertRaises(TuneError, f)
|
||||
|
||||
def testSuccess(self):
|
||||
def train(config, reporter):
|
||||
for i in range(100):
|
||||
reporter(timesteps_total=i)
|
||||
|
||||
register_trainable("f1", train)
|
||||
[trial] = run_experiments({"foo": {
|
||||
[trial] = run_experiments({
|
||||
"foo": {
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0,
|
||||
},
|
||||
}})
|
||||
}
|
||||
})
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result.timesteps_total, 99)
|
||||
|
||||
|
||||
class RunExperimentTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
ray.init()
|
||||
|
||||
|
@ -228,6 +274,7 @@ class RunExperimentTest(unittest.TestCase):
|
|||
def train(config, reporter):
|
||||
for i in range(100):
|
||||
reporter(timesteps_total=i)
|
||||
|
||||
register_trainable("f1", train)
|
||||
trials = run_experiments({
|
||||
"foo": {
|
||||
|
@ -251,6 +298,7 @@ class RunExperimentTest(unittest.TestCase):
|
|||
def train(config, reporter):
|
||||
for i in range(100):
|
||||
reporter(timesteps_total=i)
|
||||
|
||||
register_trainable("f1", train)
|
||||
exp1 = Experiment(**{
|
||||
"name": "foo",
|
||||
|
@ -267,6 +315,7 @@ class RunExperimentTest(unittest.TestCase):
|
|||
def train(config, reporter):
|
||||
for i in range(100):
|
||||
reporter(timesteps_total=i)
|
||||
|
||||
register_trainable("f1", train)
|
||||
exp1 = Experiment(**{
|
||||
"name": "foo",
|
||||
|
@ -306,8 +355,7 @@ class VariantGeneratorTest(unittest.TestCase):
|
|||
self.assertEqual(trials[0].trainable_name, "PPO")
|
||||
self.assertEqual(trials[0].experiment_tag, "0")
|
||||
self.assertEqual(trials[0].max_failures, 5)
|
||||
self.assertEqual(
|
||||
trials[0].local_dir,
|
||||
self.assertEqual(trials[0].local_dir,
|
||||
os.path.join(DEFAULT_RESULTS_DIR, "tune-pong"))
|
||||
self.assertEqual(trials[1].experiment_tag, "1")
|
||||
|
||||
|
@ -392,11 +440,13 @@ class VariantGeneratorTest(unittest.TestCase):
|
|||
trials = generate_trials({
|
||||
"run": "PPO",
|
||||
"config": {
|
||||
"x": grid_search([
|
||||
"x":
|
||||
grid_search([
|
||||
lambda spec: spec.config.y * 100,
|
||||
lambda spec: spec.config.y * 200
|
||||
]),
|
||||
"y": lambda spec: 1,
|
||||
"y":
|
||||
lambda spec: 1,
|
||||
},
|
||||
})
|
||||
trials = list(trials)
|
||||
|
@ -406,7 +456,8 @@ class VariantGeneratorTest(unittest.TestCase):
|
|||
|
||||
def testRecursiveDep(self):
|
||||
try:
|
||||
list(generate_trials({
|
||||
list(
|
||||
generate_trials({
|
||||
"run": "PPO",
|
||||
"config": {
|
||||
"foo": lambda spec: spec.config.foo,
|
||||
|
@ -442,12 +493,15 @@ class TrialRunnerTest(unittest.TestCase):
|
|||
|
||||
register_trainable("f1", train)
|
||||
|
||||
experiments = {"foo": {
|
||||
experiments = {
|
||||
"foo": {
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"a" * 50: lambda spec: 5.0 / 7,
|
||||
"b" * 50: lambda spec: "long" * 40},
|
||||
}}
|
||||
"b" * 50: lambda spec: "long" * 40
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
for name, spec in experiments.items():
|
||||
for trial in generate_trials(spec, name):
|
||||
|
@ -468,12 +522,12 @@ class TrialRunnerTest(unittest.TestCase):
|
|||
ray.init(num_cpus=4, num_gpus=2)
|
||||
runner = TrialRunner()
|
||||
kwargs = {
|
||||
"stopping_criterion": {"training_iteration": 1},
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"resources": Resources(cpu=1, gpu=0, extra_cpu=3, extra_gpu=1),
|
||||
}
|
||||
trials = [
|
||||
Trial("__fake", **kwargs),
|
||||
Trial("__fake", **kwargs)]
|
||||
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
|
||||
|
@ -489,12 +543,12 @@ class TrialRunnerTest(unittest.TestCase):
|
|||
ray.init(num_cpus=4, num_gpus=1)
|
||||
runner = TrialRunner()
|
||||
kwargs = {
|
||||
"stopping_criterion": {"training_iteration": 1},
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"resources": Resources(cpu=1, gpu=1),
|
||||
}
|
||||
trials = [
|
||||
Trial("__fake", **kwargs),
|
||||
Trial("__fake", **kwargs)]
|
||||
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
|
||||
|
@ -518,12 +572,12 @@ class TrialRunnerTest(unittest.TestCase):
|
|||
ray.init(num_cpus=4, num_gpus=2)
|
||||
runner = TrialRunner()
|
||||
kwargs = {
|
||||
"stopping_criterion": {"training_iteration": 5},
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 5
|
||||
},
|
||||
"resources": Resources(cpu=1, gpu=1),
|
||||
}
|
||||
trials = [
|
||||
Trial("__fake", **kwargs),
|
||||
Trial("__fake", **kwargs)]
|
||||
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
|
||||
|
@ -547,13 +601,13 @@ class TrialRunnerTest(unittest.TestCase):
|
|||
ray.init(num_cpus=4, num_gpus=2)
|
||||
runner = TrialRunner()
|
||||
kwargs = {
|
||||
"stopping_criterion": {"training_iteration": 1},
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"resources": Resources(cpu=1, gpu=1),
|
||||
}
|
||||
_default_registry.register(TRAINABLE_CLASS, "asdf", None)
|
||||
trials = [
|
||||
Trial("asdf", **kwargs),
|
||||
Trial("__fake", **kwargs)]
|
||||
trials = [Trial("asdf", **kwargs), Trial("__fake", **kwargs)]
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
|
||||
|
@ -644,7 +698,9 @@ class TrialRunnerTest(unittest.TestCase):
|
|||
ray.init(num_cpus=1, num_gpus=1)
|
||||
runner = TrialRunner()
|
||||
kwargs = {
|
||||
"stopping_criterion": {"training_iteration": 1},
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"resources": Resources(cpu=1, gpu=1),
|
||||
}
|
||||
runner.add_trial(Trial("__fake", **kwargs))
|
||||
|
@ -675,7 +731,9 @@ class TrialRunnerTest(unittest.TestCase):
|
|||
ray.init(num_cpus=1, num_gpus=1)
|
||||
runner = TrialRunner()
|
||||
kwargs = {
|
||||
"stopping_criterion": {"training_iteration": 2},
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 2
|
||||
},
|
||||
"resources": Resources(cpu=1, gpu=1),
|
||||
}
|
||||
runner.add_trial(Trial("__fake", **kwargs))
|
||||
|
@ -692,7 +750,9 @@ class TrialRunnerTest(unittest.TestCase):
|
|||
ray.init(num_cpus=1, num_gpus=1)
|
||||
runner = TrialRunner()
|
||||
kwargs = {
|
||||
"stopping_criterion": {"training_iteration": 2},
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 2
|
||||
},
|
||||
"resources": Resources(cpu=1, gpu=1),
|
||||
}
|
||||
runner.add_trial(Trial("__fake", **kwargs))
|
||||
|
@ -721,14 +781,17 @@ class TrialRunnerTest(unittest.TestCase):
|
|||
ray.init(num_cpus=4, num_gpus=2)
|
||||
runner = TrialRunner()
|
||||
kwargs = {
|
||||
"stopping_criterion": {"training_iteration": 5},
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 5
|
||||
},
|
||||
"resources": Resources(cpu=1, gpu=1),
|
||||
}
|
||||
trials = [
|
||||
Trial("__fake", **kwargs),
|
||||
Trial("__fake", **kwargs),
|
||||
Trial("__fake", **kwargs),
|
||||
Trial("__fake", **kwargs)]
|
||||
Trial("__fake", **kwargs)
|
||||
]
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
runner.step()
|
||||
|
|
|
@ -19,9 +19,8 @@ _register_all()
|
|||
|
||||
|
||||
def result(t, rew):
|
||||
return TrainingResult(time_total_s=t,
|
||||
episode_reward_mean=rew,
|
||||
training_iteration=int(t))
|
||||
return TrainingResult(
|
||||
time_total_s=t, episode_reward_mean=rew, training_iteration=int(t))
|
||||
|
||||
|
||||
class EarlyStoppingSuite(unittest.TestCase):
|
||||
|
@ -76,8 +75,7 @@ class EarlyStoppingSuite(unittest.TestCase):
|
|||
rule.on_trial_result(None, t3, result(2, 10)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
rule.on_trial_result(None, t3, result(3, 10)),
|
||||
TrialScheduler.STOP)
|
||||
rule.on_trial_result(None, t3, result(3, 10)), TrialScheduler.STOP)
|
||||
|
||||
def testMedianStoppingMinSamples(self):
|
||||
rule = MedianStoppingRule(grace_period=0, min_samples_required=2)
|
||||
|
@ -89,8 +87,7 @@ class EarlyStoppingSuite(unittest.TestCase):
|
|||
TrialScheduler.CONTINUE)
|
||||
rule.on_trial_complete(None, t2, result(10, 1000))
|
||||
self.assertEqual(
|
||||
rule.on_trial_result(None, t3, result(3, 10)),
|
||||
TrialScheduler.STOP)
|
||||
rule.on_trial_result(None, t3, result(3, 10)), TrialScheduler.STOP)
|
||||
|
||||
def testMedianStoppingUsesMedian(self):
|
||||
rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
|
||||
|
@ -124,8 +121,10 @@ class EarlyStoppingSuite(unittest.TestCase):
|
|||
return TrainingResult(training_iteration=t, neg_mean_loss=rew)
|
||||
|
||||
rule = MedianStoppingRule(
|
||||
grace_period=0, min_samples_required=1,
|
||||
time_attr='training_iteration', reward_attr='neg_mean_loss')
|
||||
grace_period=0,
|
||||
min_samples_required=1,
|
||||
time_attr='training_iteration',
|
||||
reward_attr='neg_mean_loss')
|
||||
t1 = Trial("PPO") # mean is 450, max 900, t_max=10
|
||||
t2 = Trial("PPO") # mean is 450, max 450, t_max=5
|
||||
for i in range(10):
|
||||
|
@ -185,7 +184,6 @@ class _MockTrialRunner():
|
|||
|
||||
|
||||
class HyperbandSuite(unittest.TestCase):
|
||||
|
||||
def schedulerSetup(self, num_trials):
|
||||
"""Setup a scheduler and Runner with max Iter = 9
|
||||
|
||||
|
@ -206,7 +204,10 @@ class HyperbandSuite(unittest.TestCase):
|
|||
"""Default statistics for HyperBand"""
|
||||
sched = HyperBandScheduler()
|
||||
res = {
|
||||
str(s): {"n": sched._get_n0(s), "r": sched._get_r0(s)}
|
||||
str(s): {
|
||||
"n": sched._get_n0(s),
|
||||
"r": sched._get_r0(s)
|
||||
}
|
||||
for s in range(sched._s_max_1)
|
||||
}
|
||||
res["max_trials"] = sum(v["n"] for v in res.values())
|
||||
|
@ -298,8 +299,8 @@ class HyperbandSuite(unittest.TestCase):
|
|||
|
||||
# Provides results from 0 to 8 in order, keeping last one running
|
||||
for i, trl in enumerate(trials):
|
||||
action = sched.on_trial_result(
|
||||
mock_runner, trl, result(cur_units, i))
|
||||
action = sched.on_trial_result(mock_runner, trl,
|
||||
result(cur_units, i))
|
||||
if i < current_length - 1:
|
||||
self.assertEqual(action, TrialScheduler.PAUSE)
|
||||
mock_runner.process_action(trl, action)
|
||||
|
@ -321,8 +322,8 @@ class HyperbandSuite(unittest.TestCase):
|
|||
# # Provides result in reverse order, killing the last one
|
||||
cur_units = stats[str(1)]["r"]
|
||||
for i, trl in reversed(list(enumerate(big_bracket.current_trials()))):
|
||||
action = sched.on_trial_result(
|
||||
mock_runner, trl, result(cur_units, i))
|
||||
action = sched.on_trial_result(mock_runner, trl,
|
||||
result(cur_units, i))
|
||||
mock_runner.process_action(trl, action)
|
||||
|
||||
self.assertEqual(action, TrialScheduler.STOP)
|
||||
|
@ -338,8 +339,8 @@ class HyperbandSuite(unittest.TestCase):
|
|||
# # Provides result in reverse order, killing the last one
|
||||
cur_units = stats[str(0)]["r"]
|
||||
for i, trl in enumerate(big_bracket.current_trials()):
|
||||
action = sched.on_trial_result(
|
||||
mock_runner, trl, result(cur_units, i))
|
||||
action = sched.on_trial_result(mock_runner, trl,
|
||||
result(cur_units, i))
|
||||
mock_runner.process_action(trl, action)
|
||||
|
||||
self.assertEqual(action, TrialScheduler.STOP)
|
||||
|
@ -354,14 +355,12 @@ class HyperbandSuite(unittest.TestCase):
|
|||
mock_runner._launch_trial(t)
|
||||
|
||||
sched.on_trial_error(mock_runner, t3)
|
||||
self.assertEqual(
|
||||
TrialScheduler.PAUSE,
|
||||
sched.on_trial_result(
|
||||
mock_runner, t1, result(stats[str(1)]["r"], 10)))
|
||||
self.assertEqual(
|
||||
TrialScheduler.CONTINUE,
|
||||
sched.on_trial_result(
|
||||
mock_runner, t2, result(stats[str(1)]["r"], 10)))
|
||||
self.assertEqual(TrialScheduler.PAUSE,
|
||||
sched.on_trial_result(mock_runner, t1,
|
||||
result(stats[str(1)]["r"], 10)))
|
||||
self.assertEqual(TrialScheduler.CONTINUE,
|
||||
sched.on_trial_result(mock_runner, t2,
|
||||
result(stats[str(1)]["r"], 10)))
|
||||
|
||||
def testTrialErrored2(self):
|
||||
"""Check successive halving happened even when last trial failed"""
|
||||
|
@ -371,12 +370,13 @@ class HyperbandSuite(unittest.TestCase):
|
|||
trials = sched._state["bracket"].current_trials()
|
||||
for t in trials[:-1]:
|
||||
mock_runner._launch_trial(t)
|
||||
sched.on_trial_result(
|
||||
mock_runner, t, result(stats[str(1)]["r"], 10))
|
||||
sched.on_trial_result(mock_runner, t, result(
|
||||
stats[str(1)]["r"], 10))
|
||||
|
||||
mock_runner._launch_trial(trials[-1])
|
||||
sched.on_trial_error(mock_runner, trials[-1])
|
||||
self.assertEqual(len(sched._state["bracket"].current_trials()),
|
||||
self.assertEqual(
|
||||
len(sched._state["bracket"].current_trials()),
|
||||
self.downscale(stats[str(1)]["n"], sched))
|
||||
|
||||
def testTrialEndedEarly(self):
|
||||
|
@ -390,14 +390,12 @@ class HyperbandSuite(unittest.TestCase):
|
|||
mock_runner._launch_trial(t)
|
||||
|
||||
sched.on_trial_complete(mock_runner, t3, result(1, 12))
|
||||
self.assertEqual(
|
||||
TrialScheduler.PAUSE,
|
||||
sched.on_trial_result(
|
||||
mock_runner, t1, result(stats[str(1)]["r"], 10)))
|
||||
self.assertEqual(
|
||||
TrialScheduler.CONTINUE,
|
||||
sched.on_trial_result(
|
||||
mock_runner, t2, result(stats[str(1)]["r"], 10)))
|
||||
self.assertEqual(TrialScheduler.PAUSE,
|
||||
sched.on_trial_result(mock_runner, t1,
|
||||
result(stats[str(1)]["r"], 10)))
|
||||
self.assertEqual(TrialScheduler.CONTINUE,
|
||||
sched.on_trial_result(mock_runner, t2,
|
||||
result(stats[str(1)]["r"], 10)))
|
||||
|
||||
def testTrialEndedEarly2(self):
|
||||
"""Check successive halving happened even when last trial failed"""
|
||||
|
@ -407,12 +405,13 @@ class HyperbandSuite(unittest.TestCase):
|
|||
trials = sched._state["bracket"].current_trials()
|
||||
for t in trials[:-1]:
|
||||
mock_runner._launch_trial(t)
|
||||
sched.on_trial_result(
|
||||
mock_runner, t, result(stats[str(1)]["r"], 10))
|
||||
sched.on_trial_result(mock_runner, t, result(
|
||||
stats[str(1)]["r"], 10))
|
||||
|
||||
mock_runner._launch_trial(trials[-1])
|
||||
sched.on_trial_complete(mock_runner, trials[-1], result(100, 12))
|
||||
self.assertEqual(len(sched._state["bracket"].current_trials()),
|
||||
self.assertEqual(
|
||||
len(sched._state["bracket"].current_trials()),
|
||||
self.downscale(stats[str(1)]["n"], sched))
|
||||
|
||||
def testAddAfterHalving(self):
|
||||
|
@ -426,8 +425,8 @@ class HyperbandSuite(unittest.TestCase):
|
|||
mock_runner._launch_trial(t)
|
||||
|
||||
for i, t in enumerate(bracket_trials):
|
||||
action = sched.on_trial_result(
|
||||
mock_runner, t, result(init_units, i))
|
||||
action = sched.on_trial_result(mock_runner, t, result(
|
||||
init_units, i))
|
||||
self.assertEqual(action, TrialScheduler.CONTINUE)
|
||||
t = Trial("__fake")
|
||||
sched.on_trial_add(None, t)
|
||||
|
@ -435,13 +434,13 @@ class HyperbandSuite(unittest.TestCase):
|
|||
self.assertEqual(len(sched._state["bracket"].current_trials()), 2)
|
||||
|
||||
# Make sure that newly added trial gets fair computation (not just 1)
|
||||
self.assertEqual(
|
||||
TrialScheduler.CONTINUE,
|
||||
sched.on_trial_result(mock_runner, t, result(init_units, 12)))
|
||||
self.assertEqual(TrialScheduler.CONTINUE,
|
||||
sched.on_trial_result(mock_runner, t,
|
||||
result(init_units, 12)))
|
||||
new_units = init_units + int(init_units * sched._eta)
|
||||
self.assertEqual(
|
||||
TrialScheduler.PAUSE,
|
||||
sched.on_trial_result(mock_runner, t, result(new_units, 12)))
|
||||
self.assertEqual(TrialScheduler.PAUSE,
|
||||
sched.on_trial_result(mock_runner, t,
|
||||
result(new_units, 12)))
|
||||
|
||||
def testAlternateMetrics(self):
|
||||
"""Checking that alternate metrics will pass."""
|
||||
|
@ -539,7 +538,6 @@ class _MockTrial(Trial):
|
|||
|
||||
|
||||
class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
|
||||
def basicSetup(self, resample_prob=0.0, explore=None):
|
||||
pbt = PopulationBasedTraining(
|
||||
time_attr="training_iteration",
|
||||
|
@ -554,9 +552,12 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
|||
runner = _MockTrialRunner(pbt)
|
||||
for i in range(5):
|
||||
trial = _MockTrial(
|
||||
i,
|
||||
{"id_factor": i, "float_factor": 2.0, "const_factor": 3,
|
||||
"int_factor": 10})
|
||||
i, {
|
||||
"id_factor": i,
|
||||
"float_factor": 2.0,
|
||||
"const_factor": 3,
|
||||
"int_factor": 10
|
||||
})
|
||||
runner.add_trial(trial)
|
||||
trial.status = Trial.RUNNING
|
||||
self.assertEqual(
|
||||
|
@ -570,27 +571,23 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
|||
trials = runner.get_trials()
|
||||
|
||||
# no checkpoint: haven't hit next perturbation interval yet
|
||||
self.assertEqual(
|
||||
pbt.last_scores(trials), [0, 50, 100, 150, 200])
|
||||
self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[0], result(15, 200)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
pbt.last_scores(trials), [0, 50, 100, 150, 200])
|
||||
self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
|
||||
self.assertEqual(pbt._num_checkpoints, 0)
|
||||
|
||||
# checkpoint: both past interval and upper quantile
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[0], result(20, 200)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
pbt.last_scores(trials), [200, 50, 100, 150, 200])
|
||||
self.assertEqual(pbt.last_scores(trials), [200, 50, 100, 150, 200])
|
||||
self.assertEqual(pbt._num_checkpoints, 1)
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[1], result(30, 201)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
pbt.last_scores(trials), [200, 201, 100, 150, 200])
|
||||
self.assertEqual(pbt.last_scores(trials), [200, 201, 100, 150, 200])
|
||||
self.assertEqual(pbt._num_checkpoints, 2)
|
||||
|
||||
# not upper quantile any more
|
||||
|
@ -608,8 +605,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
|||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[0], result(15, -100)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
pbt.last_scores(trials), [0, 50, 100, 150, 200])
|
||||
self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
|
||||
self.assertTrue("@perturbed" not in trials[0].experiment_tag)
|
||||
self.assertEqual(pbt._num_perturbations, 0)
|
||||
|
||||
|
@ -617,8 +613,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
|||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[0], result(20, -100)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
pbt.last_scores(trials), [-100, 50, 100, 150, 200])
|
||||
self.assertEqual(pbt.last_scores(trials), [-100, 50, 100, 150, 200])
|
||||
self.assertTrue("@perturbed" in trials[0].experiment_tag)
|
||||
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
|
||||
self.assertEqual(pbt._num_perturbations, 1)
|
||||
|
@ -627,8 +622,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
|||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[2], result(20, 40)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
pbt.last_scores(trials), [-100, 50, 40, 150, 200])
|
||||
self.assertEqual(pbt.last_scores(trials), [-100, 50, 40, 150, 200])
|
||||
self.assertEqual(pbt._num_perturbations, 2)
|
||||
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
|
||||
self.assertTrue("@perturbed" in trials[2].experiment_tag)
|
||||
|
@ -662,7 +656,6 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
|||
self.assertEqual(trials[0].config["const_factor"], 3)
|
||||
|
||||
def testPerturbationValues(self):
|
||||
|
||||
def assertProduces(fn, values):
|
||||
random.seed(0)
|
||||
seen = set()
|
||||
|
@ -712,8 +705,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
|||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[1], result(20, 1000)),
|
||||
TrialScheduler.PAUSE)
|
||||
self.assertEqual(
|
||||
pbt.last_scores(trials), [0, 1000, 100, 150, 200])
|
||||
self.assertEqual(pbt.last_scores(trials), [0, 1000, 100, 150, 200])
|
||||
self.assertEqual(pbt.choose_trial_to_run(runner), trials[0])
|
||||
|
||||
def testSchedulesMostBehindTrialToRun(self):
|
||||
|
@ -748,6 +740,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
|||
new_config["id_factor"] = 42
|
||||
new_config["float_factor"] = 43
|
||||
return new_config
|
||||
|
||||
pbt, runner = self.basicSetup(resample_prob=0.0, explore=explore)
|
||||
trials = runner.get_trials()
|
||||
self.assertEqual(
|
||||
|
@ -774,8 +767,7 @@ class AsyncHyperBandSuite(unittest.TestCase):
|
|||
return t1, t2
|
||||
|
||||
def testAsyncHBOnComplete(self):
|
||||
scheduler = AsyncHyperBandScheduler(
|
||||
max_t=10, brackets=1)
|
||||
scheduler = AsyncHyperBandScheduler(max_t=10, brackets=1)
|
||||
t1, t2 = self.basicSetup(scheduler)
|
||||
t3 = Trial("PPO")
|
||||
scheduler.on_trial_add(None, t3)
|
||||
|
@ -803,8 +795,7 @@ class AsyncHyperBandSuite(unittest.TestCase):
|
|||
TrialScheduler.STOP)
|
||||
|
||||
def testAsyncHBAllCompletes(self):
|
||||
scheduler = AsyncHyperBandScheduler(
|
||||
max_t=10, brackets=10)
|
||||
scheduler = AsyncHyperBandScheduler(max_t=10, brackets=10)
|
||||
trials = [Trial("PPO") for i in range(10)]
|
||||
for t in trials:
|
||||
scheduler.on_trial_add(None, t)
|
||||
|
@ -834,8 +825,10 @@ class AsyncHyperBandSuite(unittest.TestCase):
|
|||
return TrainingResult(training_iteration=t, neg_mean_loss=rew)
|
||||
|
||||
scheduler = AsyncHyperBandScheduler(
|
||||
grace_period=1, time_attr='training_iteration',
|
||||
reward_attr='neg_mean_loss', brackets=1)
|
||||
grace_period=1,
|
||||
time_attr='training_iteration',
|
||||
reward_attr='neg_mean_loss',
|
||||
brackets=1)
|
||||
t1 = Trial("PPO") # mean is 450, max 900, t_max=10
|
||||
t2 = Trial("PPO") # mean is 450, max 450, t_max=5
|
||||
scheduler.on_trial_add(None, t1)
|
||||
|
|
|
@ -30,16 +30,15 @@ class TuneServerSuite(unittest.TestCase):
|
|||
def basicSetup(self):
|
||||
ray.init(num_cpus=4, num_gpus=1)
|
||||
port = get_valid_port()
|
||||
self.runner = TrialRunner(
|
||||
launch_web_server=True, server_port=port)
|
||||
self.runner = TrialRunner(launch_web_server=True, server_port=port)
|
||||
runner = self.runner
|
||||
kwargs = {
|
||||
"stopping_criterion": {"training_iteration": 3},
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 3
|
||||
},
|
||||
"resources": Resources(cpu=1, gpu=1),
|
||||
}
|
||||
trials = [
|
||||
Trial("__fake", **kwargs),
|
||||
Trial("__fake", **kwargs)]
|
||||
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
client = TuneClient("localhost:{}".format(port))
|
||||
|
@ -61,7 +60,9 @@ class TuneServerSuite(unittest.TestCase):
|
|||
runner.step()
|
||||
spec = {
|
||||
"run": "__fake",
|
||||
"stop": {"training_iteration": 3},
|
||||
"stop": {
|
||||
"training_iteration": 3
|
||||
},
|
||||
"trial_resources": dict(cpu=1, gpu=1),
|
||||
}
|
||||
client.add_trial("test", spec)
|
||||
|
|
|
@ -114,8 +114,8 @@ class Trainable(object):
|
|||
time_this_iter = time.time() - start
|
||||
|
||||
if result.timesteps_this_iter is None:
|
||||
raise TuneError(
|
||||
"Must specify timesteps_this_iter in result", result)
|
||||
raise TuneError("Must specify timesteps_this_iter in result",
|
||||
result)
|
||||
|
||||
self._time_total += time_this_iter
|
||||
self._timesteps_total += result.timesteps_this_iter
|
||||
|
@ -159,10 +159,10 @@ class Trainable(object):
|
|||
"""
|
||||
|
||||
checkpoint_path = self._save(checkpoint_dir or self.logdir)
|
||||
pickle.dump(
|
||||
[self._experiment_id, self._iteration, self._timesteps_total,
|
||||
self._time_total],
|
||||
open(checkpoint_path + ".tune_metadata", "wb"))
|
||||
pickle.dump([
|
||||
self._experiment_id, self._iteration, self._timesteps_total,
|
||||
self._time_total
|
||||
], open(checkpoint_path + ".tune_metadata", "wb"))
|
||||
return checkpoint_path
|
||||
|
||||
def save_to_object(self):
|
||||
|
@ -186,8 +186,10 @@ class Trainable(object):
|
|||
out = io.BytesIO()
|
||||
with gzip.GzipFile(fileobj=out, mode="wb") as f:
|
||||
compressed = pickle.dumps({
|
||||
"checkpoint_name": os.path.basename(checkpoint_prefix),
|
||||
"data": data,
|
||||
"checkpoint_name":
|
||||
os.path.basename(checkpoint_prefix),
|
||||
"data":
|
||||
data,
|
||||
})
|
||||
if len(compressed) > 10e6: # getting pretty large
|
||||
print("Checkpoint size is {} bytes".format(len(compressed)))
|
||||
|
|
|
@ -42,12 +42,12 @@ class Resources(
|
|||
__slots__ = ()
|
||||
|
||||
def __new__(cls, cpu, gpu, extra_cpu=0, extra_gpu=0):
|
||||
return super(Resources, cls).__new__(
|
||||
cls, cpu, gpu, extra_cpu, extra_gpu)
|
||||
return super(Resources, cls).__new__(cls, cpu, gpu, extra_cpu,
|
||||
extra_gpu)
|
||||
|
||||
def summary_string(self):
|
||||
return "{} CPUs, {} GPUs".format(
|
||||
self.cpu + self.extra_cpu, self.gpu + self.extra_gpu)
|
||||
return "{} CPUs, {} GPUs".format(self.cpu + self.extra_cpu,
|
||||
self.gpu + self.extra_gpu)
|
||||
|
||||
def cpu_total(self):
|
||||
return self.cpu + self.extra_cpu
|
||||
|
@ -77,11 +77,17 @@ class Trial(object):
|
|||
TERMINATED = "TERMINATED"
|
||||
ERROR = "ERROR"
|
||||
|
||||
def __init__(
|
||||
self, trainable_name, config=None, local_dir=DEFAULT_RESULTS_DIR,
|
||||
experiment_tag="", resources=Resources(cpu=1, gpu=0),
|
||||
stopping_criterion=None, checkpoint_freq=0,
|
||||
restore_path=None, upload_dir=None, max_failures=0):
|
||||
def __init__(self,
|
||||
trainable_name,
|
||||
config=None,
|
||||
local_dir=DEFAULT_RESULTS_DIR,
|
||||
experiment_tag="",
|
||||
resources=Resources(cpu=1, gpu=0),
|
||||
stopping_criterion=None,
|
||||
checkpoint_freq=0,
|
||||
restore_path=None,
|
||||
upload_dir=None,
|
||||
max_failures=0):
|
||||
"""Initialize a new trial.
|
||||
|
||||
The args here take the same meaning as the command line flags defined
|
||||
|
@ -166,15 +172,16 @@ class Trial(object):
|
|||
try:
|
||||
if error_msg and self.logdir:
|
||||
self.num_failures += 1
|
||||
error_file = os.path.join(
|
||||
self.logdir, "error_{}.txt".format(date_str()))
|
||||
error_file = os.path.join(self.logdir, "error_{}.txt".format(
|
||||
date_str()))
|
||||
with open(error_file, "w") as f:
|
||||
f.write(error_msg)
|
||||
self.error_file = error_file
|
||||
if self.runner:
|
||||
stop_tasks = []
|
||||
stop_tasks.append(self.runner.stop.remote())
|
||||
stop_tasks.append(self.runner.__ray_terminate__.remote(
|
||||
stop_tasks.append(
|
||||
self.runner.__ray_terminate__.remote(
|
||||
self.runner._ray_actor_id.id()))
|
||||
# TODO(ekl) seems like wait hangs when killing actors
|
||||
_, unfinished = ray.wait(
|
||||
|
@ -252,12 +259,12 @@ class Trial(object):
|
|||
return '{} pid={}'.format(hostname, pid)
|
||||
|
||||
pieces = [
|
||||
'{} [{}]'.format(
|
||||
self._status_string(),
|
||||
location_string(
|
||||
self.last_result.hostname, self.last_result.pid)),
|
||||
'{} s'.format(int(self.last_result.time_total_s)),
|
||||
'{} ts'.format(int(self.last_result.timesteps_total))]
|
||||
'{} [{}]'.format(self._status_string(),
|
||||
location_string(self.last_result.hostname,
|
||||
self.last_result.pid)),
|
||||
'{} s'.format(int(self.last_result.time_total_s)), '{} ts'.format(
|
||||
int(self.last_result.timesteps_total))
|
||||
]
|
||||
|
||||
if self.last_result.episode_reward_mean is not None:
|
||||
pieces.append('{} rew'.format(
|
||||
|
@ -274,10 +281,8 @@ class Trial(object):
|
|||
return ', '.join(pieces)
|
||||
|
||||
def _status_string(self):
|
||||
return "{}{}".format(
|
||||
self.status,
|
||||
", {} failures: {}".format(self.num_failures, self.error_file)
|
||||
if self.error_file else "")
|
||||
return "{}{}".format(self.status, ", {} failures: {}".format(
|
||||
self.num_failures, self.error_file) if self.error_file else "")
|
||||
|
||||
def has_checkpoint(self):
|
||||
return self._checkpoint_path is not None or \
|
||||
|
@ -335,9 +340,8 @@ class Trial(object):
|
|||
def update_last_result(self, result, terminate=False):
|
||||
if terminate:
|
||||
result = result._replace(done=True)
|
||||
if self.verbose and (
|
||||
terminate or
|
||||
time.time() - self.last_debug > DEBUG_PRINT_INTERVAL):
|
||||
if self.verbose and (terminate or time.time() - self.last_debug >
|
||||
DEBUG_PRINT_INTERVAL):
|
||||
print("TrainingResult for {}:".format(self))
|
||||
print(" {}".format(pretty_print(result).replace("\n", "\n ")))
|
||||
self.last_debug = time.time()
|
||||
|
@ -358,8 +362,8 @@ class Trial(object):
|
|||
prefix="{}_{}".format(
|
||||
str(self)[:MAX_LEN_IDENTIFIER], date_str()),
|
||||
dir=self.local_dir)
|
||||
self.result_logger = UnifiedLogger(
|
||||
self.config, self.logdir, self.upload_dir)
|
||||
self.result_logger = UnifiedLogger(self.config, self.logdir,
|
||||
self.upload_dir)
|
||||
remote_logdir = self.logdir
|
||||
|
||||
def logger_creator(config):
|
||||
|
@ -372,7 +376,8 @@ class Trial(object):
|
|||
# Logging for trials is handled centrally by TrialRunner, so
|
||||
# configure the remote runner to use a noop-logger.
|
||||
self.runner = cls.remote(
|
||||
config=self.config, registry=ray.tune.registry.get_registry(),
|
||||
config=self.config,
|
||||
registry=ray.tune.registry.get_registry(),
|
||||
logger_creator=logger_creator)
|
||||
|
||||
def set_verbose(self, verbose):
|
||||
|
@ -387,8 +392,8 @@ class Trial(object):
|
|||
def __str__(self):
|
||||
"""Combines ``env`` with ``trainable_name`` and ``experiment_tag``."""
|
||||
if "env" in self.config:
|
||||
identifier = "{}_{}".format(
|
||||
self.trainable_name, self.config["env"])
|
||||
identifier = "{}_{}".format(self.trainable_name,
|
||||
self.config["env"])
|
||||
else:
|
||||
identifier = self.trainable_name
|
||||
if self.experiment_tag:
|
||||
|
|
|
@ -13,7 +13,6 @@ from ray.tune.web_server import TuneServer
|
|||
from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
|
||||
|
||||
MAX_DEBUG_TRIALS = 20
|
||||
|
||||
|
||||
|
@ -39,8 +38,11 @@ class TrialRunner(object):
|
|||
misleading benchmark results.
|
||||
"""
|
||||
|
||||
def __init__(self, scheduler=None, launch_web_server=False,
|
||||
server_port=TuneServer.DEFAULT_PORT, verbose=True):
|
||||
def __init__(self,
|
||||
scheduler=None,
|
||||
launch_web_server=False,
|
||||
server_port=TuneServer.DEFAULT_PORT,
|
||||
verbose=True):
|
||||
"""Initializes a new TrialRunner.
|
||||
|
||||
Args:
|
||||
|
@ -73,8 +75,7 @@ class TrialRunner(object):
|
|||
"""Returns whether all trials have finished running."""
|
||||
|
||||
if self._total_time > self._global_time_limit:
|
||||
print(
|
||||
"Exceeded global time limit {} / {}".format(
|
||||
print("Exceeded global time limit {} / {}".format(
|
||||
self._total_time, self._global_time_limit))
|
||||
return True
|
||||
|
||||
|
@ -98,8 +99,8 @@ class TrialRunner(object):
|
|||
for trial in self._trials:
|
||||
if trial.status == Trial.PENDING:
|
||||
if not self.has_resources(trial.resources):
|
||||
raise TuneError((
|
||||
"Insufficient cluster resources to launch trial: "
|
||||
raise TuneError(
|
||||
("Insufficient cluster resources to launch trial: "
|
||||
"trial requested {} but the cluster only has {} "
|
||||
"available.").format(
|
||||
trial.resources.summary_string(),
|
||||
|
@ -165,24 +166,20 @@ class TrialRunner(object):
|
|||
for state, trials in sorted(states.items()):
|
||||
limit = limit_per_state[state]
|
||||
messages.append("{} trials:".format(state))
|
||||
for t in sorted(
|
||||
trials, key=lambda t: t.experiment_tag)[:limit]:
|
||||
for t in sorted(trials, key=lambda t: t.experiment_tag)[:limit]:
|
||||
messages.append(" - {}:\t{}".format(t, t.progress_string()))
|
||||
if len(trials) > limit:
|
||||
messages.append(" ... {} more not shown".format(
|
||||
len(trials) - limit))
|
||||
messages.append(
|
||||
" ... {} more not shown".format(len(trials) - limit))
|
||||
return "\n".join(messages) + "\n"
|
||||
|
||||
def _debug_messages(self):
|
||||
messages = ["== Status =="]
|
||||
messages.append(self._scheduler_alg.debug_string())
|
||||
if self._resources_initialized:
|
||||
messages.append(
|
||||
"Resources used: {}/{} CPUs, {}/{} GPUs".format(
|
||||
self._committed_resources.cpu,
|
||||
self._avail_resources.cpu,
|
||||
self._committed_resources.gpu,
|
||||
self._avail_resources.gpu))
|
||||
messages.append("Resources used: {}/{} CPUs, {}/{} GPUs".format(
|
||||
self._committed_resources.cpu, self._avail_resources.cpu,
|
||||
self._committed_resources.gpu, self._avail_resources.gpu))
|
||||
return messages
|
||||
|
||||
def has_resources(self, resources):
|
||||
|
@ -190,9 +187,8 @@ class TrialRunner(object):
|
|||
|
||||
cpu_avail = self._avail_resources.cpu - self._committed_resources.cpu
|
||||
gpu_avail = self._avail_resources.gpu - self._committed_resources.gpu
|
||||
return (
|
||||
resources.cpu_total() <= cpu_avail and
|
||||
resources.gpu_total() <= gpu_avail)
|
||||
return (resources.cpu_total() <= cpu_avail
|
||||
and resources.gpu_total() <= gpu_avail)
|
||||
|
||||
def _get_next_trial(self):
|
||||
self._update_avail_resources()
|
||||
|
@ -307,8 +303,9 @@ class TrialRunner(object):
|
|||
self._scheduler_alg.on_trial_remove(self, trial)
|
||||
elif trial.status is Trial.RUNNING:
|
||||
# NOTE: There should only be one...
|
||||
result_id = [rid for rid, t in self._running.items()
|
||||
if t is trial][0]
|
||||
result_id = [
|
||||
rid for rid, t in self._running.items() if t is trial
|
||||
][0]
|
||||
self._running.pop(result_id)
|
||||
try:
|
||||
result = ray.get(result_id)
|
||||
|
@ -339,9 +336,8 @@ class TrialRunner(object):
|
|||
def _update_avail_resources(self):
|
||||
clients = ray.global_state.client_table()
|
||||
local_schedulers = [
|
||||
entry for client in clients.values() for entry in client
|
||||
if (entry['ClientType'] == 'local_scheduler' and not
|
||||
entry['Deleted'])
|
||||
entry for client in clients.values() for entry in client if
|
||||
(entry['ClientType'] == 'local_scheduler' and not entry['Deleted'])
|
||||
]
|
||||
num_cpus = sum(ls['CPU'] for ls in local_schedulers)
|
||||
num_gpus = sum(ls.get('GPU', 0) for ls in local_schedulers)
|
||||
|
|
|
@ -99,12 +99,12 @@ class FIFOScheduler(TrialScheduler):
|
|||
|
||||
def choose_trial_to_run(self, trial_runner):
|
||||
for trial in trial_runner.get_trials():
|
||||
if (trial.status == Trial.PENDING and
|
||||
trial_runner.has_resources(trial.resources)):
|
||||
if (trial.status == Trial.PENDING
|
||||
and trial_runner.has_resources(trial.resources)):
|
||||
return trial
|
||||
for trial in trial_runner.get_trials():
|
||||
if (trial.status == Trial.PAUSED and
|
||||
trial_runner.has_resources(trial.resources)):
|
||||
if (trial.status == Trial.PAUSED
|
||||
and trial_runner.has_resources(trial.resources)):
|
||||
return trial
|
||||
return None
|
||||
|
||||
|
|
|
@ -16,7 +16,6 @@ from ray.tune.trial_scheduler import FIFOScheduler
|
|||
from ray.tune.web_server import TuneServer
|
||||
from ray.tune.experiment import Experiment
|
||||
|
||||
|
||||
_SCHEDULERS = {
|
||||
"FIFO": FIFOScheduler,
|
||||
"MedianStopping": MedianStoppingRule,
|
||||
|
@ -30,13 +29,15 @@ def _make_scheduler(args):
|
|||
if args.scheduler in _SCHEDULERS:
|
||||
return _SCHEDULERS[args.scheduler](**args.scheduler_config)
|
||||
else:
|
||||
raise TuneError(
|
||||
"Unknown scheduler: {}, should be one of {}".format(
|
||||
raise TuneError("Unknown scheduler: {}, should be one of {}".format(
|
||||
args.scheduler, _SCHEDULERS.keys()))
|
||||
|
||||
|
||||
def run_experiments(experiments, scheduler=None, with_server=False,
|
||||
server_port=TuneServer.DEFAULT_PORT, verbose=True):
|
||||
def run_experiments(experiments,
|
||||
scheduler=None,
|
||||
with_server=False,
|
||||
server_port=TuneServer.DEFAULT_PORT,
|
||||
verbose=True):
|
||||
"""Tunes experiments.
|
||||
|
||||
Args:
|
||||
|
@ -54,17 +55,21 @@ def run_experiments(experiments, scheduler=None, with_server=False,
|
|||
scheduler = FIFOScheduler()
|
||||
|
||||
runner = TrialRunner(
|
||||
scheduler, launch_web_server=with_server, server_port=server_port,
|
||||
scheduler,
|
||||
launch_web_server=with_server,
|
||||
server_port=server_port,
|
||||
verbose=verbose)
|
||||
exp_list = experiments
|
||||
if isinstance(experiments, Experiment):
|
||||
exp_list = [experiments]
|
||||
elif type(experiments) is dict:
|
||||
exp_list = [Experiment.from_json(name, spec)
|
||||
for name, spec in experiments.items()]
|
||||
exp_list = [
|
||||
Experiment.from_json(name, spec)
|
||||
for name, spec in experiments.items()
|
||||
]
|
||||
|
||||
if (type(exp_list) is list and
|
||||
all(isinstance(exp, Experiment) for exp in exp_list)):
|
||||
if (type(exp_list) is list
|
||||
and all(isinstance(exp, Experiment) for exp in exp_list)):
|
||||
for experiment in exp_list:
|
||||
scheduler.add_experiment(experiment, runner)
|
||||
else:
|
||||
|
|
|
@ -7,7 +7,6 @@ import base64
|
|||
import ray
|
||||
from ray.tune.registry import _to_pinnable, _from_pinnable
|
||||
|
||||
|
||||
_pinned_objects = []
|
||||
PINNED_OBJECT_PREFIX = "ray.tune.PinnedObject:"
|
||||
|
||||
|
@ -15,14 +14,15 @@ PINNED_OBJECT_PREFIX = "ray.tune.PinnedObject:"
|
|||
def pin_in_object_store(obj):
|
||||
obj_id = ray.put(_to_pinnable(obj))
|
||||
_pinned_objects.append(ray.get(obj_id))
|
||||
return "{}{}".format(
|
||||
PINNED_OBJECT_PREFIX, base64.b64encode(obj_id.id()).decode("utf-8"))
|
||||
return "{}{}".format(PINNED_OBJECT_PREFIX,
|
||||
base64.b64encode(obj_id.id()).decode("utf-8"))
|
||||
|
||||
|
||||
def get_pinned_object(pinned_id):
|
||||
from ray.local_scheduler import ObjectID
|
||||
return _from_pinnable(ray.get(ObjectID(
|
||||
base64.b64decode(pinned_id[len(PINNED_OBJECT_PREFIX):]))))
|
||||
return _from_pinnable(
|
||||
ray.get(
|
||||
ObjectID(base64.b64decode(pinned_id[len(PINNED_OBJECT_PREFIX):]))))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -163,8 +163,8 @@ def _generate_variants(spec):
|
|||
for path, value in grid_vars:
|
||||
resolved_vars[path] = _get_value(spec, path)
|
||||
for k, v in resolved.items():
|
||||
if (k in resolved_vars and v != resolved_vars[k] and
|
||||
_is_resolved(resolved_vars[k])):
|
||||
if (k in resolved_vars and v != resolved_vars[k]
|
||||
and _is_resolved(resolved_vars[k])):
|
||||
raise ValueError(
|
||||
"The variable `{}` could not be unambiguously "
|
||||
"resolved to a single value. Consider simplifying "
|
||||
|
@ -262,16 +262,16 @@ def _unresolved_values(spec):
|
|||
for k, v in spec.items():
|
||||
resolved, v = _try_resolve(v)
|
||||
if not resolved:
|
||||
found[(k,)] = v
|
||||
found[(k, )] = v
|
||||
elif isinstance(v, dict):
|
||||
# Recurse into a dict
|
||||
for (path, value) in _unresolved_values(v).items():
|
||||
found[(k,) + path] = value
|
||||
found[(k, ) + path] = value
|
||||
elif isinstance(v, list):
|
||||
# Recurse into a list
|
||||
for i, elem in enumerate(v):
|
||||
for (path, value) in _unresolved_values({i: elem}).items():
|
||||
found[(k,) + path] = value
|
||||
found[(k, ) + path] = value
|
||||
return found
|
||||
|
||||
|
||||
|
|
|
@ -61,8 +61,10 @@ def _resolve(directory, result_fname):
|
|||
|
||||
|
||||
def load_results_to_df(directory, result_name="result.json"):
|
||||
exp_directories = [dirpath for dirpath, dirs, files in os.walk(directory)
|
||||
for f in files if f == result_name]
|
||||
exp_directories = [
|
||||
dirpath for dirpath, dirs, files in os.walk(directory) for f in files
|
||||
if f == result_name
|
||||
]
|
||||
data = [_resolve(d, result_name) for d in exp_directories]
|
||||
data = [d for d in data if d]
|
||||
return pd.DataFrame(data)
|
||||
|
@ -76,8 +78,9 @@ def generate_plotly_dim_dict(df, field):
|
|||
dim_dict["values"] = column
|
||||
elif is_string_dtype(column):
|
||||
texts = column.unique()
|
||||
dim_dict["values"] = [np.argwhere(texts == x).flatten()[0]
|
||||
for x in column]
|
||||
dim_dict["values"] = [
|
||||
np.argwhere(texts == x).flatten()[0] for x in column
|
||||
]
|
||||
dim_dict["tickvals"] = list(range(len(texts)))
|
||||
dim_dict["ticktext"] = texts
|
||||
else:
|
||||
|
|
|
@ -39,28 +39,30 @@ class TuneClient(object):
|
|||
|
||||
def get_all_trials(self):
|
||||
"""Returns a list of all trials (trial_id, config, status)."""
|
||||
return self._get_response(
|
||||
{"command": TuneClient.GET_LIST})
|
||||
return self._get_response({"command": TuneClient.GET_LIST})
|
||||
|
||||
def get_trial(self, trial_id):
|
||||
"""Returns the last result for queried trial."""
|
||||
return self._get_response(
|
||||
{"command": TuneClient.GET_TRIAL,
|
||||
"trial_id": trial_id})
|
||||
return self._get_response({
|
||||
"command": TuneClient.GET_TRIAL,
|
||||
"trial_id": trial_id
|
||||
})
|
||||
|
||||
def add_trial(self, name, trial_spec):
|
||||
"""Adds a trial of `name` with configurations."""
|
||||
# TODO(rliaw): have better way of specifying a new trial
|
||||
return self._get_response(
|
||||
{"command": TuneClient.ADD,
|
||||
return self._get_response({
|
||||
"command": TuneClient.ADD,
|
||||
"name": name,
|
||||
"spec": trial_spec})
|
||||
"spec": trial_spec
|
||||
})
|
||||
|
||||
def stop_trial(self, trial_id):
|
||||
"""Requests to stop trial."""
|
||||
return self._get_response(
|
||||
{"command": TuneClient.STOP,
|
||||
"trial_id": trial_id})
|
||||
return self._get_response({
|
||||
"command": TuneClient.STOP,
|
||||
"trial_id": trial_id
|
||||
})
|
||||
|
||||
def _get_response(self, data):
|
||||
payload = json.dumps(data).encode()
|
||||
|
@ -71,7 +73,6 @@ class TuneClient(object):
|
|||
|
||||
def RunnerHandler(runner):
|
||||
class Handler(SimpleHTTPRequestHandler):
|
||||
|
||||
def do_GET(self):
|
||||
content_len = int(self.headers.get('Content-Length'), 0)
|
||||
raw_body = self.rfile.read(content_len)
|
||||
|
@ -82,8 +83,7 @@ def RunnerHandler(runner):
|
|||
else:
|
||||
self.send_response(400)
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps(
|
||||
response).encode())
|
||||
self.wfile.write(json.dumps(response).encode())
|
||||
|
||||
def trial_info(self, trial):
|
||||
if trial.last_result:
|
||||
|
@ -112,8 +112,9 @@ def RunnerHandler(runner):
|
|||
response = {}
|
||||
try:
|
||||
if command == TuneClient.GET_LIST:
|
||||
response["trials"] = [self.trial_info(t)
|
||||
for t in runner.get_trials()]
|
||||
response["trials"] = [
|
||||
self.trial_info(t) for t in runner.get_trials()
|
||||
]
|
||||
elif command == TuneClient.GET_TRIAL:
|
||||
trial = get_trial()
|
||||
response["trial_info"] = self.trial_info(trial)
|
||||
|
@ -147,8 +148,7 @@ class TuneServer(threading.Thread):
|
|||
self._port = port if port else self.DEFAULT_PORT
|
||||
address = ('localhost', self._port)
|
||||
print("Starting Tune Server...")
|
||||
self._server = HTTPServer(
|
||||
address, RunnerHandler(runner))
|
||||
self._server = HTTPServer(address, RunnerHandler(runner))
|
||||
self.start()
|
||||
|
||||
def run(self):
|
||||
|
|
|
@ -46,7 +46,10 @@ def format_error_message(exception_message, task_exception=False):
|
|||
return "\n".join(lines)
|
||||
|
||||
|
||||
def push_error_to_driver(redis_client, error_type, message, driver_id=None,
|
||||
def push_error_to_driver(redis_client,
|
||||
error_type,
|
||||
message,
|
||||
driver_id=None,
|
||||
data=None):
|
||||
"""Push an error message to the driver to be printed in the background.
|
||||
|
||||
|
@ -64,9 +67,11 @@ def push_error_to_driver(redis_client, error_type, message, driver_id=None,
|
|||
driver_id = DRIVER_ID_LENGTH * b"\x00"
|
||||
error_key = ERROR_KEY_PREFIX + driver_id + b":" + _random_string()
|
||||
data = {} if data is None else data
|
||||
redis_client.hmset(error_key, {"type": error_type,
|
||||
redis_client.hmset(error_key, {
|
||||
"type": error_type,
|
||||
"message": message,
|
||||
"data": data})
|
||||
"data": data
|
||||
})
|
||||
redis_client.rpush("ErrorKeys", error_key)
|
||||
|
||||
|
||||
|
@ -134,10 +139,8 @@ def hex_to_binary(hex_identifier):
|
|||
return binascii.unhexlify(hex_identifier)
|
||||
|
||||
|
||||
FunctionProperties = collections.namedtuple("FunctionProperties",
|
||||
["num_return_vals",
|
||||
"resources",
|
||||
"max_calls"])
|
||||
FunctionProperties = collections.namedtuple(
|
||||
"FunctionProperties", ["num_return_vals", "resources", "max_calls"])
|
||||
"""FunctionProperties: A named tuple storing remote functions information."""
|
||||
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -8,34 +8,51 @@ import traceback
|
|||
import ray
|
||||
import ray.actor
|
||||
|
||||
parser = argparse.ArgumentParser(description=("Parse addresses for the worker "
|
||||
parser = argparse.ArgumentParser(
|
||||
description=("Parse addresses for the worker "
|
||||
"to connect to."))
|
||||
parser.add_argument("--node-ip-address", required=True, type=str,
|
||||
parser.add_argument(
|
||||
"--node-ip-address",
|
||||
required=True,
|
||||
type=str,
|
||||
help="the ip address of the worker's node")
|
||||
parser.add_argument("--redis-address", required=True, type=str,
|
||||
parser.add_argument(
|
||||
"--redis-address",
|
||||
required=True,
|
||||
type=str,
|
||||
help="the address to use for Redis")
|
||||
parser.add_argument("--object-store-name", required=True, type=str,
|
||||
parser.add_argument(
|
||||
"--object-store-name",
|
||||
required=True,
|
||||
type=str,
|
||||
help="the object store's name")
|
||||
parser.add_argument("--object-store-manager-name", required=False, type=str,
|
||||
parser.add_argument(
|
||||
"--object-store-manager-name",
|
||||
required=False,
|
||||
type=str,
|
||||
help="the object store manager's name")
|
||||
parser.add_argument("--local-scheduler-name", required=False, type=str,
|
||||
parser.add_argument(
|
||||
"--local-scheduler-name",
|
||||
required=False,
|
||||
type=str,
|
||||
help="the local scheduler's name")
|
||||
parser.add_argument("--raylet-name", required=False, type=str,
|
||||
help="the raylet's name")
|
||||
|
||||
parser.add_argument(
|
||||
"--raylet-name", required=False, type=str, help="the raylet's name")
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
info = {"node_ip_address": args.node_ip_address,
|
||||
info = {
|
||||
"node_ip_address": args.node_ip_address,
|
||||
"redis_address": args.redis_address,
|
||||
"store_socket_name": args.object_store_name,
|
||||
"manager_socket_name": args.object_store_manager_name,
|
||||
"local_scheduler_socket_name": args.local_scheduler_name,
|
||||
"raylet_socket_name": args.raylet_name}
|
||||
"raylet_socket_name": args.raylet_name
|
||||
}
|
||||
|
||||
ray.worker.connect(info, mode=ray.WORKER_MODE,
|
||||
use_raylet=(args.raylet_name is not None))
|
||||
ray.worker.connect(
|
||||
info, mode=ray.WORKER_MODE, use_raylet=(args.raylet_name is not None))
|
||||
|
||||
error_explanation = """
|
||||
This error is unexpected and should not have happened. Somehow a worker
|
||||
|
@ -54,8 +71,8 @@ if __name__ == "__main__":
|
|||
traceback_str = traceback.format_exc() + error_explanation
|
||||
# Create a Redis client.
|
||||
redis_client = ray.services.create_redis_client(args.redis_address)
|
||||
ray.utils.push_error_to_driver(redis_client, "worker_crash",
|
||||
traceback_str, driver_id=None)
|
||||
ray.utils.push_error_to_driver(
|
||||
redis_client, "worker_crash", traceback_str, driver_id=None)
|
||||
# TODO(rkn): Note that if the worker was in the middle of executing
|
||||
# a task, then any worker or driver that is blocking in a get call
|
||||
# and waiting for the output of that task will hang. We need to
|
||||
|
|
|
@ -18,13 +18,11 @@ import setuptools.command.build_ext as _build_ext
|
|||
ray_files = [
|
||||
"ray/core/src/common/thirdparty/redis/src/redis-server",
|
||||
"ray/core/src/common/redis_module/libray_redis_module.so",
|
||||
"ray/core/src/plasma/plasma_store",
|
||||
"ray/core/src/plasma/plasma_manager",
|
||||
"ray/core/src/plasma/plasma_store", "ray/core/src/plasma/plasma_manager",
|
||||
"ray/core/src/local_scheduler/local_scheduler",
|
||||
"ray/core/src/local_scheduler/liblocal_scheduler_library.so",
|
||||
"ray/core/src/global_scheduler/global_scheduler",
|
||||
"ray/core/src/ray/raylet/raylet_monitor",
|
||||
"ray/core/src/ray/raylet/raylet",
|
||||
"ray/core/src/ray/raylet/raylet_monitor", "ray/core/src/ray/raylet/raylet",
|
||||
"ray/WebUI.ipynb"
|
||||
]
|
||||
|
||||
|
@ -35,14 +33,14 @@ ray_ui_files = [
|
|||
"ray/core/src/catapult_files/trace_viewer_full.html"
|
||||
]
|
||||
|
||||
ray_autoscaler_files = [
|
||||
"ray/autoscaler/aws/example-full.yaml"
|
||||
]
|
||||
ray_autoscaler_files = ["ray/autoscaler/aws/example-full.yaml"]
|
||||
|
||||
if "RAY_USE_NEW_GCS" in os.environ and os.environ["RAY_USE_NEW_GCS"] == "on":
|
||||
ray_files += ["ray/core/src/credis/build/src/libmember.so",
|
||||
ray_files += [
|
||||
"ray/core/src/credis/build/src/libmember.so",
|
||||
"ray/core/src/credis/build/src/libmaster.so",
|
||||
"ray/core/src/credis/redis/src/redis-server"]
|
||||
"ray/core/src/credis/redis/src/redis-server"
|
||||
]
|
||||
|
||||
# The UI files are mandatory if the INCLUDE_UI environment variable equals 1.
|
||||
# Otherwise, they are optional.
|
||||
|
@ -54,9 +52,8 @@ else:
|
|||
optional_ray_files += ray_autoscaler_files
|
||||
|
||||
extras = {
|
||||
"rllib": [
|
||||
"tensorflow", "pyyaml", "gym[atari]", "opencv-python",
|
||||
"lz4", "scipy"]
|
||||
"rllib":
|
||||
["tensorflow", "pyyaml", "gym[atari]", "opencv-python", "lz4", "scipy"]
|
||||
}
|
||||
|
||||
|
||||
|
@ -73,8 +70,9 @@ class build_ext(_build_ext.build_ext):
|
|||
pyarrow_files = [
|
||||
os.path.join("ray/pyarrow_files/pyarrow", filename)
|
||||
for filename in os.listdir("./ray/pyarrow_files/pyarrow")
|
||||
if not os.path.isdir(os.path.join("ray/pyarrow_files/pyarrow",
|
||||
filename))]
|
||||
if not os.path.isdir(
|
||||
os.path.join("ray/pyarrow_files/pyarrow", filename))
|
||||
]
|
||||
|
||||
files_to_include = ray_files + pyarrow_files
|
||||
|
||||
|
@ -84,8 +82,8 @@ class build_ext(_build_ext.build_ext):
|
|||
generated_python_directory = "ray/core/generated"
|
||||
for filename in os.listdir(generated_python_directory):
|
||||
if filename[-3:] == ".py":
|
||||
self.move_file(os.path.join(generated_python_directory,
|
||||
filename))
|
||||
self.move_file(
|
||||
os.path.join(generated_python_directory, filename))
|
||||
|
||||
# Try to copy over the optional files.
|
||||
for filename in optional_ray_files:
|
||||
|
@ -114,14 +112,16 @@ class BinaryDistribution(Distribution):
|
|||
return True
|
||||
|
||||
|
||||
setup(name="ray",
|
||||
setup(
|
||||
name="ray",
|
||||
# The version string is also in __init__.py. TODO(pcm): Fix this.
|
||||
version="0.4.0",
|
||||
packages=find_packages(),
|
||||
cmdclass={"build_ext": build_ext},
|
||||
# The BinaryDistribution argument triggers build_ext.
|
||||
distclass=BinaryDistribution,
|
||||
install_requires=["numpy",
|
||||
install_requires=[
|
||||
"numpy",
|
||||
"funcsigs",
|
||||
"click",
|
||||
"colorama",
|
||||
|
@ -131,7 +131,8 @@ setup(name="ray",
|
|||
"redis",
|
||||
# The six module is required by pyarrow.
|
||||
"six >= 1.0.0",
|
||||
"flatbuffers"],
|
||||
"flatbuffers"
|
||||
],
|
||||
setup_requires=["cython >= 0.23"],
|
||||
extras_require=extras,
|
||||
entry_points={"console_scripts": ["ray=ray.scripts.scripts:main"]},
|
||||
|
|
|
@ -15,7 +15,6 @@ import ray.test.test_utils
|
|||
|
||||
|
||||
class ActorAPI(unittest.TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
ray.worker.cleanup()
|
||||
|
||||
|
@ -39,19 +38,21 @@ class ActorAPI(unittest.TestCase):
|
|||
self.assertEqual(ray.get(actor.get_values.remote(2, 3)), (3, 5, "ab"))
|
||||
|
||||
actor = Actor.remote(1, 2, "c")
|
||||
self.assertEqual(ray.get(actor.get_values.remote(2, 3, "d")),
|
||||
(3, 5, "cd"))
|
||||
self.assertEqual(
|
||||
ray.get(actor.get_values.remote(2, 3, "d")), (3, 5, "cd"))
|
||||
|
||||
actor = Actor.remote(1, arg2="c")
|
||||
self.assertEqual(ray.get(actor.get_values.remote(0, arg2="d")),
|
||||
(1, 3, "cd"))
|
||||
self.assertEqual(ray.get(actor.get_values.remote(0, arg2="d", arg1=0)),
|
||||
self.assertEqual(
|
||||
ray.get(actor.get_values.remote(0, arg2="d")), (1, 3, "cd"))
|
||||
self.assertEqual(
|
||||
ray.get(actor.get_values.remote(0, arg2="d", arg1=0)),
|
||||
(1, 1, "cd"))
|
||||
|
||||
actor = Actor.remote(1, arg2="c", arg1=2)
|
||||
self.assertEqual(ray.get(actor.get_values.remote(0, arg2="d")),
|
||||
(1, 4, "cd"))
|
||||
self.assertEqual(ray.get(actor.get_values.remote(0, arg2="d", arg1=0)),
|
||||
self.assertEqual(
|
||||
ray.get(actor.get_values.remote(0, arg2="d")), (1, 4, "cd"))
|
||||
self.assertEqual(
|
||||
ray.get(actor.get_values.remote(0, arg2="d", arg1=0)),
|
||||
(1, 2, "cd"))
|
||||
|
||||
# Make sure we get an exception if the constructor is called
|
||||
|
@ -84,15 +85,17 @@ class ActorAPI(unittest.TestCase):
|
|||
self.assertEqual(ray.get(actor.get_values.remote(1)), (1, 3, (), ()))
|
||||
|
||||
actor = Actor.remote(1, 2)
|
||||
self.assertEqual(ray.get(actor.get_values.remote(2, 3)),
|
||||
(3, 5, (), ()))
|
||||
self.assertEqual(
|
||||
ray.get(actor.get_values.remote(2, 3)), (3, 5, (), ()))
|
||||
|
||||
actor = Actor.remote(1, 2, "c")
|
||||
self.assertEqual(ray.get(actor.get_values.remote(2, 3, "d")),
|
||||
(3, 5, ("c",), ("d",)))
|
||||
self.assertEqual(
|
||||
ray.get(actor.get_values.remote(2, 3, "d")), (3, 5, ("c", ),
|
||||
("d", )))
|
||||
|
||||
actor = Actor.remote(1, 2, "a", "b", "c", "d")
|
||||
self.assertEqual(ray.get(actor.get_values.remote(2, 3, 1, 2, 3, 4)),
|
||||
self.assertEqual(
|
||||
ray.get(actor.get_values.remote(2, 3, 1, 2, 3, 4)),
|
||||
(3, 5, ("a", "b", "c", "d"), (1, 2, 3, 4)))
|
||||
|
||||
@ray.remote
|
||||
|
@ -106,7 +109,7 @@ class ActorAPI(unittest.TestCase):
|
|||
a = Actor.remote()
|
||||
self.assertEqual(ray.get(a.get_values.remote()), ((), ()))
|
||||
a = Actor.remote(1)
|
||||
self.assertEqual(ray.get(a.get_values.remote(2)), ((1,), (2,)))
|
||||
self.assertEqual(ray.get(a.get_values.remote(2)), ((1, ), (2, )))
|
||||
a = Actor.remote(1, 2)
|
||||
self.assertEqual(ray.get(a.get_values.remote(3, 4)), ((1, 2), (3, 4)))
|
||||
|
||||
|
@ -191,6 +194,7 @@ class ActorAPI(unittest.TestCase):
|
|||
|
||||
# This is an invalid way of using the actor decorator.
|
||||
with self.assertRaises(Exception):
|
||||
|
||||
@ray.remote()
|
||||
class Actor(object):
|
||||
def __init__(self):
|
||||
|
@ -198,6 +202,7 @@ class ActorAPI(unittest.TestCase):
|
|||
|
||||
# This is an invalid way of using the actor decorator.
|
||||
with self.assertRaises(Exception):
|
||||
|
||||
@ray.remote(invalid_kwarg=0) # noqa: F811
|
||||
class Actor(object):
|
||||
def __init__(self):
|
||||
|
@ -205,6 +210,7 @@ class ActorAPI(unittest.TestCase):
|
|||
|
||||
# This is an invalid way of using the actor decorator.
|
||||
with self.assertRaises(Exception):
|
||||
|
||||
@ray.remote(num_cpus=0, invalid_kwarg=0) # noqa: F811
|
||||
class Actor(object):
|
||||
def __init__(self):
|
||||
|
@ -300,7 +306,6 @@ class ActorAPI(unittest.TestCase):
|
|||
|
||||
|
||||
class ActorMethods(unittest.TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
ray.worker.cleanup()
|
||||
|
||||
|
@ -417,8 +422,9 @@ class ActorMethods(unittest.TestCase):
|
|||
results = []
|
||||
# Call each actor's method a bunch of times.
|
||||
for i in range(num_actors):
|
||||
results += [actors[i].increase.remote()
|
||||
for _ in range(num_increases)]
|
||||
results += [
|
||||
actors[i].increase.remote() for _ in range(num_increases)
|
||||
]
|
||||
result_values = ray.get(results)
|
||||
for i in range(num_actors):
|
||||
self.assertEqual(
|
||||
|
@ -440,7 +446,6 @@ class ActorMethods(unittest.TestCase):
|
|||
|
||||
|
||||
class ActorNesting(unittest.TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
ray.worker.cleanup()
|
||||
|
||||
|
@ -510,6 +515,7 @@ class ActorNesting(unittest.TestCase):
|
|||
|
||||
def get_value(self):
|
||||
return self.x
|
||||
|
||||
self.actor2 = Actor2.remote(z)
|
||||
|
||||
def get_values(self, z):
|
||||
|
@ -556,11 +562,13 @@ class ActorNesting(unittest.TestCase):
|
|||
|
||||
def get_value(self):
|
||||
return self.x
|
||||
|
||||
actor = Actor1.remote(x)
|
||||
return ray.get([actor.get_value.remote() for _ in range(n)])
|
||||
|
||||
self.assertEqual(ray.get(f.remote(3, 1)), [3])
|
||||
self.assertEqual(ray.get([f.remote(i, 20) for i in range(10)]),
|
||||
self.assertEqual(
|
||||
ray.get([f.remote(i, 20) for i in range(10)]),
|
||||
[20 * [i] for i in range(10)])
|
||||
|
||||
def testUseActorWithinRemoteFunction(self):
|
||||
|
@ -591,6 +599,7 @@ class ActorNesting(unittest.TestCase):
|
|||
# Export a bunch of remote functions.
|
||||
num_remote_functions = 50
|
||||
for i in range(num_remote_functions):
|
||||
|
||||
@ray.remote
|
||||
def f():
|
||||
return i
|
||||
|
@ -613,7 +622,6 @@ class ActorNesting(unittest.TestCase):
|
|||
|
||||
|
||||
class ActorInheritance(unittest.TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
ray.worker.cleanup()
|
||||
|
||||
|
@ -646,7 +654,6 @@ class ActorInheritance(unittest.TestCase):
|
|||
|
||||
|
||||
class ActorSchedulingProperties(unittest.TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
ray.worker.cleanup()
|
||||
|
||||
|
@ -674,7 +681,6 @@ class ActorSchedulingProperties(unittest.TestCase):
|
|||
|
||||
|
||||
class ActorsOnMultipleNodes(unittest.TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
ray.worker.cleanup()
|
||||
|
||||
|
@ -692,7 +698,9 @@ class ActorsOnMultipleNodes(unittest.TestCase):
|
|||
|
||||
def testActorLoadBalancing(self):
|
||||
num_local_schedulers = 3
|
||||
ray.worker._init(start_ray_local=True, num_workers=0,
|
||||
ray.worker._init(
|
||||
start_ray_local=True,
|
||||
num_workers=0,
|
||||
num_local_schedulers=num_local_schedulers)
|
||||
|
||||
@ray.remote
|
||||
|
@ -712,13 +720,13 @@ class ActorsOnMultipleNodes(unittest.TestCase):
|
|||
attempts = 0
|
||||
while attempts < num_attempts:
|
||||
actors = [Actor1.remote() for _ in range(num_actors)]
|
||||
locations = ray.get([actor.get_location.remote()
|
||||
for actor in actors])
|
||||
locations = ray.get(
|
||||
[actor.get_location.remote() for actor in actors])
|
||||
names = set(locations)
|
||||
counts = [locations.count(name) for name in names]
|
||||
print("Counts are {}.".format(counts))
|
||||
if (len(names) == num_local_schedulers and
|
||||
all([count >= minimum_count for count in counts])):
|
||||
if (len(names) == num_local_schedulers
|
||||
and all([count >= minimum_count for count in counts])):
|
||||
break
|
||||
attempts += 1
|
||||
self.assertLess(attempts, num_attempts)
|
||||
|
@ -732,18 +740,17 @@ class ActorsOnMultipleNodes(unittest.TestCase):
|
|||
|
||||
|
||||
class ActorsWithGPUs(unittest.TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
ray.worker.cleanup()
|
||||
|
||||
@unittest.skipIf(
|
||||
os.environ.get('RAY_USE_NEW_GCS', False),
|
||||
"Crashing with new GCS API.")
|
||||
os.environ.get('RAY_USE_NEW_GCS', False), "Crashing with new GCS API.")
|
||||
def testActorGPUs(self):
|
||||
num_local_schedulers = 3
|
||||
num_gpus_per_scheduler = 4
|
||||
ray.worker._init(
|
||||
start_ray_local=True, num_workers=0,
|
||||
start_ray_local=True,
|
||||
num_workers=0,
|
||||
num_local_schedulers=num_local_schedulers,
|
||||
num_cpus=(num_local_schedulers * [10 * num_gpus_per_scheduler]),
|
||||
num_gpus=(num_local_schedulers * [num_gpus_per_scheduler]))
|
||||
|
@ -760,19 +767,21 @@ class ActorsWithGPUs(unittest.TestCase):
|
|||
tuple(self.gpu_ids))
|
||||
|
||||
# Create one actor per GPU.
|
||||
actors = [Actor1.remote() for _
|
||||
in range(num_local_schedulers * num_gpus_per_scheduler)]
|
||||
actors = [
|
||||
Actor1.remote()
|
||||
for _ in range(num_local_schedulers * num_gpus_per_scheduler)
|
||||
]
|
||||
# Make sure that no two actors are assigned to the same GPU.
|
||||
locations_and_ids = ray.get([actor.get_location_and_ids.remote()
|
||||
for actor in actors])
|
||||
locations_and_ids = ray.get(
|
||||
[actor.get_location_and_ids.remote() for actor in actors])
|
||||
node_names = set([location for location, gpu_id in locations_and_ids])
|
||||
self.assertEqual(len(node_names), num_local_schedulers)
|
||||
location_actor_combinations = []
|
||||
for node_name in node_names:
|
||||
for gpu_id in range(num_gpus_per_scheduler):
|
||||
location_actor_combinations.append((node_name, (gpu_id,)))
|
||||
self.assertEqual(set(locations_and_ids),
|
||||
set(location_actor_combinations))
|
||||
location_actor_combinations.append((node_name, (gpu_id, )))
|
||||
self.assertEqual(
|
||||
set(locations_and_ids), set(location_actor_combinations))
|
||||
|
||||
# Creating a new actor should fail because all of the GPUs are being
|
||||
# used.
|
||||
|
@ -784,7 +793,8 @@ class ActorsWithGPUs(unittest.TestCase):
|
|||
num_local_schedulers = 3
|
||||
num_gpus_per_scheduler = 5
|
||||
ray.worker._init(
|
||||
start_ray_local=True, num_workers=0,
|
||||
start_ray_local=True,
|
||||
num_workers=0,
|
||||
num_local_schedulers=num_local_schedulers,
|
||||
num_cpus=(num_local_schedulers * [10 * num_gpus_per_scheduler]),
|
||||
num_gpus=(num_local_schedulers * [num_gpus_per_scheduler]))
|
||||
|
@ -803,8 +813,8 @@ class ActorsWithGPUs(unittest.TestCase):
|
|||
# Create some actors.
|
||||
actors1 = [Actor1.remote() for _ in range(num_local_schedulers * 2)]
|
||||
# Make sure that no two actors are assigned to the same GPU.
|
||||
locations_and_ids = ray.get([actor.get_location_and_ids.remote()
|
||||
for actor in actors1])
|
||||
locations_and_ids = ray.get(
|
||||
[actor.get_location_and_ids.remote() for actor in actors1])
|
||||
node_names = set([location for location, gpu_id in locations_and_ids])
|
||||
self.assertEqual(len(node_names), num_local_schedulers)
|
||||
|
||||
|
@ -835,11 +845,11 @@ class ActorsWithGPUs(unittest.TestCase):
|
|||
# Create some actors.
|
||||
actors2 = [Actor2.remote() for _ in range(num_local_schedulers)]
|
||||
# Make sure that no two actors are assigned to the same GPU.
|
||||
locations_and_ids = ray.get([actor.get_location_and_ids.remote()
|
||||
for actor in actors2])
|
||||
self.assertEqual(node_names,
|
||||
set([location for location, gpu_id
|
||||
in locations_and_ids]))
|
||||
locations_and_ids = ray.get(
|
||||
[actor.get_location_and_ids.remote() for actor in actors2])
|
||||
self.assertEqual(
|
||||
node_names,
|
||||
set([location for location, gpu_id in locations_and_ids]))
|
||||
for location, gpu_ids in locations_and_ids:
|
||||
gpus_in_use[location].extend(gpu_ids)
|
||||
for node_name in node_names:
|
||||
|
@ -855,8 +865,11 @@ class ActorsWithGPUs(unittest.TestCase):
|
|||
def testActorDifferentNumbersOfGPUs(self):
|
||||
# Test that we can create actors on two nodes that have different
|
||||
# numbers of GPUs.
|
||||
ray.worker._init(start_ray_local=True, num_workers=0,
|
||||
num_local_schedulers=3, num_cpus=[10, 10, 10],
|
||||
ray.worker._init(
|
||||
start_ray_local=True,
|
||||
num_workers=0,
|
||||
num_local_schedulers=3,
|
||||
num_cpus=[10, 10, 10],
|
||||
num_gpus=[0, 5, 10])
|
||||
|
||||
@ray.remote(num_gpus=1)
|
||||
|
@ -872,16 +885,19 @@ class ActorsWithGPUs(unittest.TestCase):
|
|||
# Create some actors.
|
||||
actors = [Actor1.remote() for _ in range(0 + 5 + 10)]
|
||||
# Make sure that no two actors are assigned to the same GPU.
|
||||
locations_and_ids = ray.get([actor.get_location_and_ids.remote()
|
||||
for actor in actors])
|
||||
locations_and_ids = ray.get(
|
||||
[actor.get_location_and_ids.remote() for actor in actors])
|
||||
node_names = set([location for location, gpu_id in locations_and_ids])
|
||||
self.assertEqual(len(node_names), 2)
|
||||
for node_name in node_names:
|
||||
node_gpu_ids = [gpu_id for location, gpu_id in locations_and_ids
|
||||
if location == node_name]
|
||||
node_gpu_ids = [
|
||||
gpu_id for location, gpu_id in locations_and_ids
|
||||
if location == node_name
|
||||
]
|
||||
self.assertIn(len(node_gpu_ids), [5, 10])
|
||||
self.assertEqual(set(node_gpu_ids),
|
||||
set([(i,) for i in range(len(node_gpu_ids))]))
|
||||
self.assertEqual(
|
||||
set(node_gpu_ids),
|
||||
set([(i, ) for i in range(len(node_gpu_ids))]))
|
||||
|
||||
# Creating a new actor should fail because all of the GPUs are being
|
||||
# used.
|
||||
|
@ -893,8 +909,10 @@ class ActorsWithGPUs(unittest.TestCase):
|
|||
num_local_schedulers = 10
|
||||
num_gpus_per_scheduler = 10
|
||||
ray.worker._init(
|
||||
start_ray_local=True, num_workers=0,
|
||||
num_local_schedulers=num_local_schedulers, redirect_output=True,
|
||||
start_ray_local=True,
|
||||
num_workers=0,
|
||||
num_local_schedulers=num_local_schedulers,
|
||||
redirect_output=True,
|
||||
num_cpus=(num_local_schedulers * [10 * num_gpus_per_scheduler]),
|
||||
num_gpus=(num_local_schedulers * [num_gpus_per_scheduler]))
|
||||
|
||||
|
@ -906,15 +924,17 @@ class ActorsWithGPUs(unittest.TestCase):
|
|||
self.gpu_ids = ray.get_gpu_ids()
|
||||
|
||||
def get_location_and_ids(self):
|
||||
return ((ray.worker.global_worker.plasma_client
|
||||
.store_socket_name),
|
||||
tuple(self.gpu_ids))
|
||||
return ((ray.worker.global_worker.plasma_client.
|
||||
store_socket_name), tuple(self.gpu_ids))
|
||||
|
||||
# Create n actors.
|
||||
for _ in range(n):
|
||||
Actor.remote()
|
||||
|
||||
ray.get([create_actors.remote(num_gpus_per_scheduler)
|
||||
for _ in range(num_local_schedulers)])
|
||||
ray.get([
|
||||
create_actors.remote(num_gpus_per_scheduler)
|
||||
for _ in range(num_local_schedulers)
|
||||
])
|
||||
|
||||
@ray.remote(num_gpus=1)
|
||||
class Actor(object):
|
||||
|
@ -936,7 +956,8 @@ class ActorsWithGPUs(unittest.TestCase):
|
|||
num_local_schedulers = 3
|
||||
num_gpus_per_scheduler = 6
|
||||
ray.worker._init(
|
||||
start_ray_local=True, num_workers=0,
|
||||
start_ray_local=True,
|
||||
num_workers=0,
|
||||
num_local_schedulers=num_local_schedulers,
|
||||
num_cpus=num_gpus_per_scheduler,
|
||||
num_gpus=(num_local_schedulers * [num_gpus_per_scheduler]))
|
||||
|
@ -951,11 +972,11 @@ class ActorsWithGPUs(unittest.TestCase):
|
|||
self.assertLess(first_interval[0], first_interval[1])
|
||||
self.assertLess(second_interval[0], second_interval[1])
|
||||
intervals_nonoverlapping = (
|
||||
first_interval[1] <= second_interval[0] or
|
||||
second_interval[1] <= first_interval[0])
|
||||
first_interval[1] <= second_interval[0]
|
||||
or second_interval[1] <= first_interval[0])
|
||||
assert intervals_nonoverlapping, (
|
||||
"Intervals {} and {} are overlapping."
|
||||
.format(first_interval, second_interval))
|
||||
"Intervals {} and {} are overlapping.".format(
|
||||
first_interval, second_interval))
|
||||
|
||||
@ray.remote(num_gpus=1)
|
||||
def f1():
|
||||
|
@ -995,13 +1016,16 @@ class ActorsWithGPUs(unittest.TestCase):
|
|||
|
||||
def locations_to_intervals_for_many_tasks():
|
||||
# Launch a bunch of GPU tasks.
|
||||
locations_ids_and_intervals = ray.get(
|
||||
[f1.remote() for _
|
||||
in range(5 * num_local_schedulers * num_gpus_per_scheduler)] +
|
||||
[f2.remote() for _
|
||||
in range(5 * num_local_schedulers * num_gpus_per_scheduler)] +
|
||||
[f1.remote() for _
|
||||
in range(5 * num_local_schedulers * num_gpus_per_scheduler)])
|
||||
locations_ids_and_intervals = ray.get([
|
||||
f1.remote() for _ in range(
|
||||
5 * num_local_schedulers * num_gpus_per_scheduler)
|
||||
] + [
|
||||
f2.remote() for _ in range(
|
||||
5 * num_local_schedulers * num_gpus_per_scheduler)
|
||||
] + [
|
||||
f1.remote() for _ in range(
|
||||
5 * num_local_schedulers * num_gpus_per_scheduler)
|
||||
])
|
||||
|
||||
locations_to_intervals = collections.defaultdict(lambda: [])
|
||||
for location, gpu_ids, interval in locations_ids_and_intervals:
|
||||
|
@ -1012,7 +1036,8 @@ class ActorsWithGPUs(unittest.TestCase):
|
|||
# Run a bunch of GPU tasks.
|
||||
locations_to_intervals = locations_to_intervals_for_many_tasks()
|
||||
# Make sure that all GPUs were used.
|
||||
self.assertEqual(len(locations_to_intervals),
|
||||
self.assertEqual(
|
||||
len(locations_to_intervals),
|
||||
num_local_schedulers * num_gpus_per_scheduler)
|
||||
# For each GPU, verify that the set of tasks that used this specific
|
||||
# GPU did not overlap in time.
|
||||
|
@ -1030,7 +1055,8 @@ class ActorsWithGPUs(unittest.TestCase):
|
|||
# Run a bunch of GPU tasks.
|
||||
locations_to_intervals = locations_to_intervals_for_many_tasks()
|
||||
# Make sure that all but one of the GPUs were used.
|
||||
self.assertEqual(len(locations_to_intervals),
|
||||
self.assertEqual(
|
||||
len(locations_to_intervals),
|
||||
num_local_schedulers * num_gpus_per_scheduler - 1)
|
||||
# For each GPU, verify that the set of tasks that used this specific
|
||||
# GPU did not overlap in time.
|
||||
|
@ -1041,13 +1067,14 @@ class ActorsWithGPUs(unittest.TestCase):
|
|||
|
||||
# Create several more actors that use GPUs.
|
||||
actors = [Actor1.remote() for _ in range(3)]
|
||||
actor_locations = ray.get([actor.get_location_and_ids.remote()
|
||||
for actor in actors])
|
||||
actor_locations = ray.get(
|
||||
[actor.get_location_and_ids.remote() for actor in actors])
|
||||
|
||||
# Run a bunch of GPU tasks.
|
||||
locations_to_intervals = locations_to_intervals_for_many_tasks()
|
||||
# Make sure that all but 11 of the GPUs were used.
|
||||
self.assertEqual(len(locations_to_intervals),
|
||||
self.assertEqual(
|
||||
len(locations_to_intervals),
|
||||
num_local_schedulers * num_gpus_per_scheduler - 1 - 3)
|
||||
# For each GPU, verify that the set of tasks that used this specific
|
||||
# GPU did not overlap in time.
|
||||
|
@ -1059,9 +1086,10 @@ class ActorsWithGPUs(unittest.TestCase):
|
|||
self.assertNotIn(location, locations_to_intervals)
|
||||
|
||||
# Create more actors to fill up all the GPUs.
|
||||
more_actors = [Actor1.remote() for _ in
|
||||
range(num_local_schedulers *
|
||||
num_gpus_per_scheduler - 1 - 3)]
|
||||
more_actors = [
|
||||
Actor1.remote() for _ in range(
|
||||
num_local_schedulers * num_gpus_per_scheduler - 1 - 3)
|
||||
]
|
||||
# Wait for the actors to finish being created.
|
||||
ray.get([actor.get_location_and_ids.remote() for actor in more_actors])
|
||||
|
||||
|
@ -1195,16 +1223,17 @@ class ActorsWithGPUs(unittest.TestCase):
|
|||
|
||||
|
||||
class ActorReconstruction(unittest.TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
ray.worker.cleanup()
|
||||
|
||||
@unittest.skipIf(
|
||||
os.environ.get('RAY_USE_NEW_GCS', False),
|
||||
"Hanging with new GCS API.")
|
||||
os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
|
||||
def testLocalSchedulerDying(self):
|
||||
ray.worker._init(start_ray_local=True, num_local_schedulers=2,
|
||||
num_workers=0, redirect_output=True)
|
||||
ray.worker._init(
|
||||
start_ray_local=True,
|
||||
num_local_schedulers=2,
|
||||
num_workers=0,
|
||||
redirect_output=True)
|
||||
|
||||
@ray.remote
|
||||
class Counter(object):
|
||||
|
@ -1243,8 +1272,7 @@ class ActorReconstruction(unittest.TestCase):
|
|||
self.assertEqual(results, list(range(1, 1 + len(results))))
|
||||
|
||||
@unittest.skipIf(
|
||||
os.environ.get('RAY_USE_NEW_GCS', False),
|
||||
"Hanging with new GCS API.")
|
||||
os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
|
||||
def testManyLocalSchedulersDying(self):
|
||||
# This test can be made more stressful by increasing the numbers below.
|
||||
# The total number of actors created will be
|
||||
|
@ -1253,9 +1281,11 @@ class ActorReconstruction(unittest.TestCase):
|
|||
num_actors_at_a_time = 3
|
||||
num_function_calls_at_a_time = 10
|
||||
|
||||
ray.worker._init(start_ray_local=True,
|
||||
ray.worker._init(
|
||||
start_ray_local=True,
|
||||
num_local_schedulers=num_local_schedulers,
|
||||
num_workers=0, redirect_output=True)
|
||||
num_workers=0,
|
||||
redirect_output=True)
|
||||
|
||||
@ray.remote
|
||||
class SlowCounter(object):
|
||||
|
@ -1281,14 +1311,13 @@ class ActorReconstruction(unittest.TestCase):
|
|||
# a local scheduler, and run some more methods.
|
||||
for i in range(num_local_schedulers - 1):
|
||||
# Create some actors.
|
||||
actors.extend([SlowCounter.remote()
|
||||
for _ in range(num_actors_at_a_time)])
|
||||
actors.extend(
|
||||
[SlowCounter.remote() for _ in range(num_actors_at_a_time)])
|
||||
# Run some methods.
|
||||
for j in range(len(actors)):
|
||||
actor = actors[j]
|
||||
for _ in range(num_function_calls_at_a_time):
|
||||
result_ids[actor].append(
|
||||
actor.inc.remote(j ** 2 * 0.000001))
|
||||
result_ids[actor].append(actor.inc.remote(j**2 * 0.000001))
|
||||
# Kill a plasma store to get rid of the cached objects and trigger
|
||||
# exit of the corresponding local scheduler. Don't kill the first
|
||||
# local scheduler since that is the one that the driver is
|
||||
|
@ -1302,18 +1331,24 @@ class ActorReconstruction(unittest.TestCase):
|
|||
for j in range(len(actors)):
|
||||
actor = actors[j]
|
||||
for _ in range(num_function_calls_at_a_time):
|
||||
result_ids[actor].append(
|
||||
actor.inc.remote(j ** 2 * 0.000001))
|
||||
result_ids[actor].append(actor.inc.remote(j**2 * 0.000001))
|
||||
|
||||
# Get the results and check that they have the correct values.
|
||||
for _, result_id_list in result_ids.items():
|
||||
self.assertEqual(ray.get(result_id_list),
|
||||
list(range(1, len(result_id_list) + 1)))
|
||||
self.assertEqual(
|
||||
ray.get(result_id_list), list(
|
||||
range(1,
|
||||
len(result_id_list) + 1)))
|
||||
|
||||
def setup_counter_actor(self, test_checkpoint=False, save_exception=False,
|
||||
def setup_counter_actor(self,
|
||||
test_checkpoint=False,
|
||||
save_exception=False,
|
||||
resume_exception=False):
|
||||
ray.worker._init(start_ray_local=True, num_local_schedulers=2,
|
||||
num_workers=0, redirect_output=True)
|
||||
ray.worker._init(
|
||||
start_ray_local=True,
|
||||
num_local_schedulers=2,
|
||||
num_workers=0,
|
||||
redirect_output=True)
|
||||
|
||||
# Only set the checkpoint interval if we're testing with checkpointing.
|
||||
checkpoint_interval = -1
|
||||
|
@ -1371,8 +1406,7 @@ class ActorReconstruction(unittest.TestCase):
|
|||
return actor, ids
|
||||
|
||||
@unittest.skipIf(
|
||||
os.environ.get('RAY_USE_NEW_GCS', False),
|
||||
"Hanging with new GCS API.")
|
||||
os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
|
||||
def testCheckpointing(self):
|
||||
actor, ids = self.setup_counter_actor(test_checkpoint=True)
|
||||
# Wait for the last task to finish running.
|
||||
|
@ -1397,8 +1431,7 @@ class ActorReconstruction(unittest.TestCase):
|
|||
self.assertLess(num_inc_calls, x)
|
||||
|
||||
@unittest.skipIf(
|
||||
os.environ.get('RAY_USE_NEW_GCS', False),
|
||||
"Hanging with new GCS API.")
|
||||
os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
|
||||
def testRemoteCheckpoint(self):
|
||||
actor, ids = self.setup_counter_actor(test_checkpoint=True)
|
||||
|
||||
|
@ -1424,8 +1457,7 @@ class ActorReconstruction(unittest.TestCase):
|
|||
self.assertEqual(x, 101)
|
||||
|
||||
@unittest.skipIf(
|
||||
os.environ.get('RAY_USE_NEW_GCS', False),
|
||||
"Hanging with new GCS API.")
|
||||
os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
|
||||
def testLostCheckpoint(self):
|
||||
actor, ids = self.setup_counter_actor(test_checkpoint=True)
|
||||
# Wait for the first fraction of tasks to finish running.
|
||||
|
@ -1451,11 +1483,10 @@ class ActorReconstruction(unittest.TestCase):
|
|||
self.assertLess(5, num_inc_calls)
|
||||
|
||||
@unittest.skipIf(
|
||||
os.environ.get('RAY_USE_NEW_GCS', False),
|
||||
"Hanging with new GCS API.")
|
||||
os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
|
||||
def testCheckpointException(self):
|
||||
actor, ids = self.setup_counter_actor(test_checkpoint=True,
|
||||
save_exception=True)
|
||||
actor, ids = self.setup_counter_actor(
|
||||
test_checkpoint=True, save_exception=True)
|
||||
# Wait for the last task to finish running.
|
||||
ray.get(ids[-1])
|
||||
|
||||
|
@ -1481,11 +1512,10 @@ class ActorReconstruction(unittest.TestCase):
|
|||
self.assertEqual(error[b"type"], b"checkpoint")
|
||||
|
||||
@unittest.skipIf(
|
||||
os.environ.get('RAY_USE_NEW_GCS', False),
|
||||
"Hanging with new GCS API.")
|
||||
os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
|
||||
def testCheckpointResumeException(self):
|
||||
actor, ids = self.setup_counter_actor(test_checkpoint=True,
|
||||
resume_exception=True)
|
||||
actor, ids = self.setup_counter_actor(
|
||||
test_checkpoint=True, resume_exception=True)
|
||||
# Wait for the last task to finish running.
|
||||
ray.get(ids[-1])
|
||||
|
||||
|
@ -1527,8 +1557,9 @@ class ActorReconstruction(unittest.TestCase):
|
|||
count = ray.get(ids[-1])
|
||||
num_incs = 100
|
||||
num_iters = 10
|
||||
forks = [fork_many_incs.remote(counter, num_incs) for _ in
|
||||
range(num_iters)]
|
||||
forks = [
|
||||
fork_many_incs.remote(counter, num_incs) for _ in range(num_iters)
|
||||
]
|
||||
ray.wait(forks, num_returns=len(forks))
|
||||
count += num_incs * num_iters
|
||||
|
||||
|
@ -1547,8 +1578,7 @@ class ActorReconstruction(unittest.TestCase):
|
|||
self.assertEqual(x, count + 1)
|
||||
|
||||
@unittest.skipIf(
|
||||
os.environ.get('RAY_USE_NEW_GCS', False),
|
||||
"Hanging with new GCS API.")
|
||||
os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
|
||||
def testRemoteCheckpointDistributedHandle(self):
|
||||
counter, ids = self.setup_counter_actor(test_checkpoint=True)
|
||||
|
||||
|
@ -1564,8 +1594,9 @@ class ActorReconstruction(unittest.TestCase):
|
|||
count = ray.get(ids[-1])
|
||||
num_incs = 100
|
||||
num_iters = 10
|
||||
forks = [fork_many_incs.remote(counter, num_incs) for _ in
|
||||
range(num_iters)]
|
||||
forks = [
|
||||
fork_many_incs.remote(counter, num_incs) for _ in range(num_iters)
|
||||
]
|
||||
ray.wait(forks, num_returns=len(forks))
|
||||
ray.wait([counter.__ray_checkpoint__.remote()])
|
||||
count += num_incs * num_iters
|
||||
|
@ -1605,8 +1636,9 @@ class ActorReconstruction(unittest.TestCase):
|
|||
count = ray.get(ids[-1])
|
||||
num_incs = 100
|
||||
num_iters = 10
|
||||
forks = [fork_many_incs.remote(counter, num_incs) for _ in
|
||||
range(num_iters)]
|
||||
forks = [
|
||||
fork_many_incs.remote(counter, num_incs) for _ in range(num_iters)
|
||||
]
|
||||
ray.wait(forks, num_returns=len(forks))
|
||||
count += num_incs * num_iters
|
||||
|
||||
|
@ -1624,11 +1656,13 @@ class ActorReconstruction(unittest.TestCase):
|
|||
x = ray.get(counter.inc.remote())
|
||||
self.assertEqual(x, count + 1)
|
||||
|
||||
def _testNondeterministicReconstruction(self, num_forks,
|
||||
num_items_per_fork,
|
||||
num_forks_to_wait):
|
||||
ray.worker._init(start_ray_local=True, num_local_schedulers=2,
|
||||
num_workers=0, redirect_output=True)
|
||||
def _testNondeterministicReconstruction(
|
||||
self, num_forks, num_items_per_fork, num_forks_to_wait):
|
||||
ray.worker._init(
|
||||
start_ray_local=True,
|
||||
num_local_schedulers=2,
|
||||
num_workers=0,
|
||||
redirect_output=True)
|
||||
|
||||
# Make a shared queue.
|
||||
@ray.remote
|
||||
|
@ -1668,8 +1702,9 @@ class ActorReconstruction(unittest.TestCase):
|
|||
# unique objects to push onto the shared queue.
|
||||
enqueue_tasks = []
|
||||
for fork in range(num_forks):
|
||||
enqueue_tasks.append(enqueue.remote(
|
||||
actor, [(fork, i) for i in range(num_items_per_fork)]))
|
||||
enqueue_tasks.append(
|
||||
enqueue.remote(actor,
|
||||
[(fork, i) for i in range(num_items_per_fork)]))
|
||||
# Wait for the forks to complete their tasks.
|
||||
enqueue_tasks = ray.get(enqueue_tasks)
|
||||
enqueue_tasks = [fork_ids[0] for fork_ids in enqueue_tasks]
|
||||
|
@ -1689,8 +1724,8 @@ class ActorReconstruction(unittest.TestCase):
|
|||
ray.get(enqueue_tasks)
|
||||
reconstructed_queue = ray.get(actor.read.remote())
|
||||
# Make sure the final queue has all items from all forks.
|
||||
self.assertEqual(len(reconstructed_queue), num_forks *
|
||||
num_items_per_fork)
|
||||
self.assertEqual(
|
||||
len(reconstructed_queue), num_forks * num_items_per_fork)
|
||||
# Make sure that the prefix of the final queue matches the queue from
|
||||
# the initial execution.
|
||||
self.assertEqual(queue, reconstructed_queue[:len(queue)])
|
||||
|
@ -1709,7 +1744,6 @@ class ActorReconstruction(unittest.TestCase):
|
|||
|
||||
|
||||
class DistributedActorHandles(unittest.TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
ray.worker.cleanup()
|
||||
|
||||
|
@ -1757,8 +1791,9 @@ class DistributedActorHandles(unittest.TestCase):
|
|||
# Fork num_iters times.
|
||||
num_forks = 10
|
||||
num_items_per_fork = 100
|
||||
ray.get([fork.remote(queue, i, num_items_per_fork) for i in
|
||||
range(num_forks)])
|
||||
ray.get([
|
||||
fork.remote(queue, i, num_items_per_fork) for i in range(num_forks)
|
||||
])
|
||||
items = ray.get(queue.read.remote())
|
||||
for i in range(num_forks):
|
||||
filtered_items = [item[1] for item in items if item[0] == i]
|
||||
|
@ -1812,7 +1847,6 @@ class DistributedActorHandles(unittest.TestCase):
|
|||
|
||||
|
||||
class ActorPlacementAndResources(unittest.TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
ray.worker.cleanup()
|
||||
|
||||
|
@ -1836,14 +1870,20 @@ class ActorPlacementAndResources(unittest.TestCase):
|
|||
|
||||
actor2s = [Actor2.remote() for _ in range(2)]
|
||||
results = [a.method.remote() for a in actor2s]
|
||||
ready_ids, remaining_ids = ray.wait(results, num_returns=len(results),
|
||||
timeout=1000)
|
||||
ready_ids, remaining_ids = ray.wait(
|
||||
results, num_returns=len(results), timeout=1000)
|
||||
self.assertEqual(len(ready_ids), 1)
|
||||
|
||||
def testCustomLabelPlacement(self):
|
||||
ray.worker._init(start_ray_local=True, num_local_schedulers=2,
|
||||
num_workers=0, resources=[{"CustomResource1": 2},
|
||||
{"CustomResource2": 2}])
|
||||
ray.worker._init(
|
||||
start_ray_local=True,
|
||||
num_local_schedulers=2,
|
||||
num_workers=0,
|
||||
resources=[{
|
||||
"CustomResource1": 2
|
||||
}, {
|
||||
"CustomResource2": 2
|
||||
}])
|
||||
|
||||
@ray.remote(resources={"CustomResource1": 1})
|
||||
class ResourceActor1(object):
|
||||
|
@ -1868,7 +1908,10 @@ class ActorPlacementAndResources(unittest.TestCase):
|
|||
self.assertNotEqual(location, local_plasma)
|
||||
|
||||
def testCreatingMoreActorsThanResources(self):
|
||||
ray.init(num_workers=0, num_cpus=10, num_gpus=2,
|
||||
ray.init(
|
||||
num_workers=0,
|
||||
num_cpus=10,
|
||||
num_gpus=2,
|
||||
resources={"CustomResource1": 1})
|
||||
|
||||
@ray.remote(num_gpus=1)
|
||||
|
|
|
@ -20,8 +20,9 @@ class RemoteArrayTest(unittest.TestCase):
|
|||
ray.worker.cleanup()
|
||||
|
||||
def testMethods(self):
|
||||
for module in [ra.core, ra.random, ra.linalg, da.core, da.random,
|
||||
da.linalg]:
|
||||
for module in [
|
||||
ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg
|
||||
]:
|
||||
reload(module)
|
||||
ray.init()
|
||||
|
||||
|
@ -56,8 +57,9 @@ class DistributedArrayTest(unittest.TestCase):
|
|||
ray.worker.cleanup()
|
||||
|
||||
def testAssemble(self):
|
||||
for module in [ra.core, ra.random, ra.linalg, da.core, da.random,
|
||||
da.linalg]:
|
||||
for module in [
|
||||
ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg
|
||||
]:
|
||||
reload(module)
|
||||
ray.init()
|
||||
|
||||
|
@ -66,15 +68,18 @@ class DistributedArrayTest(unittest.TestCase):
|
|||
x = da.DistArray([2 * da.BLOCK_SIZE, da.BLOCK_SIZE],
|
||||
np.array([[a], [b]]))
|
||||
assert_equal(x.assemble(),
|
||||
np.vstack([np.ones([da.BLOCK_SIZE, da.BLOCK_SIZE]),
|
||||
np.zeros([da.BLOCK_SIZE, da.BLOCK_SIZE])]))
|
||||
np.vstack([
|
||||
np.ones([da.BLOCK_SIZE, da.BLOCK_SIZE]),
|
||||
np.zeros([da.BLOCK_SIZE, da.BLOCK_SIZE])
|
||||
]))
|
||||
|
||||
def testMethods(self):
|
||||
for module in [ra.core, ra.random, ra.linalg, da.core, da.random,
|
||||
da.linalg]:
|
||||
for module in [
|
||||
ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg
|
||||
]:
|
||||
reload(module)
|
||||
ray.worker._init(start_ray_local=True, num_local_schedulers=2,
|
||||
num_cpus=[10, 10])
|
||||
ray.worker._init(
|
||||
start_ray_local=True, num_local_schedulers=2, num_cpus=[10, 10])
|
||||
|
||||
x = da.zeros.remote([9, 25, 51], "float")
|
||||
assert_equal(ray.get(da.assemble.remote(x)), np.zeros([9, 25, 51]))
|
||||
|
@ -84,20 +89,22 @@ class DistributedArrayTest(unittest.TestCase):
|
|||
|
||||
x = da.random.normal.remote([11, 25, 49])
|
||||
y = da.copy.remote(x)
|
||||
assert_equal(ray.get(da.assemble.remote(x)),
|
||||
ray.get(da.assemble.remote(y)))
|
||||
assert_equal(
|
||||
ray.get(da.assemble.remote(x)), ray.get(da.assemble.remote(y)))
|
||||
|
||||
x = da.eye.remote(25, dtype_name="float")
|
||||
assert_equal(ray.get(da.assemble.remote(x)), np.eye(25))
|
||||
|
||||
x = da.random.normal.remote([25, 49])
|
||||
y = da.triu.remote(x)
|
||||
assert_equal(ray.get(da.assemble.remote(y)),
|
||||
assert_equal(
|
||||
ray.get(da.assemble.remote(y)),
|
||||
np.triu(ray.get(da.assemble.remote(x))))
|
||||
|
||||
x = da.random.normal.remote([25, 49])
|
||||
y = da.tril.remote(x)
|
||||
assert_equal(ray.get(da.assemble.remote(y)),
|
||||
assert_equal(
|
||||
ray.get(da.assemble.remote(y)),
|
||||
np.tril(ray.get(da.assemble.remote(x))))
|
||||
|
||||
x = da.random.normal.remote([25, 49])
|
||||
|
@ -113,31 +120,31 @@ class DistributedArrayTest(unittest.TestCase):
|
|||
x = da.random.normal.remote([23, 42])
|
||||
y = da.random.normal.remote([23, 42])
|
||||
z = da.add.remote(x, y)
|
||||
assert_almost_equal(ray.get(da.assemble.remote(z)),
|
||||
ray.get(da.assemble.remote(x)) +
|
||||
ray.get(da.assemble.remote(y)))
|
||||
assert_almost_equal(
|
||||
ray.get(da.assemble.remote(z)),
|
||||
ray.get(da.assemble.remote(x)) + ray.get(da.assemble.remote(y)))
|
||||
|
||||
# test subtract
|
||||
x = da.random.normal.remote([33, 40])
|
||||
y = da.random.normal.remote([33, 40])
|
||||
z = da.subtract.remote(x, y)
|
||||
assert_almost_equal(ray.get(da.assemble.remote(z)),
|
||||
ray.get(da.assemble.remote(x)) -
|
||||
ray.get(da.assemble.remote(y)))
|
||||
assert_almost_equal(
|
||||
ray.get(da.assemble.remote(z)),
|
||||
ray.get(da.assemble.remote(x)) - ray.get(da.assemble.remote(y)))
|
||||
|
||||
# test transpose
|
||||
x = da.random.normal.remote([234, 432])
|
||||
y = da.transpose.remote(x)
|
||||
assert_equal(ray.get(da.assemble.remote(x)).T,
|
||||
ray.get(da.assemble.remote(y)))
|
||||
assert_equal(
|
||||
ray.get(da.assemble.remote(x)).T, ray.get(da.assemble.remote(y)))
|
||||
|
||||
# test numpy_to_dist
|
||||
x = da.random.normal.remote([23, 45])
|
||||
y = da.assemble.remote(x)
|
||||
z = da.numpy_to_dist.remote(y)
|
||||
w = da.assemble.remote(z)
|
||||
assert_equal(ray.get(da.assemble.remote(x)),
|
||||
ray.get(da.assemble.remote(z)))
|
||||
assert_equal(
|
||||
ray.get(da.assemble.remote(x)), ray.get(da.assemble.remote(z)))
|
||||
assert_equal(ray.get(y), ray.get(w))
|
||||
|
||||
# test da.tsqr
|
||||
|
@ -157,8 +164,8 @@ class DistributedArrayTest(unittest.TestCase):
|
|||
|
||||
# test da.linalg.modified_lu
|
||||
def test_modified_lu(d1, d2):
|
||||
print("testing dist_modified_lu with d1 = " + str(d1) +
|
||||
", d2 = " + str(d2))
|
||||
print("testing dist_modified_lu with d1 = " + str(d1) + ", d2 = " +
|
||||
str(d2))
|
||||
assert d1 >= d2
|
||||
m = ra.random.normal.remote([d1, d2])
|
||||
q, r = ra.linalg.qr.remote(m)
|
||||
|
@ -178,8 +185,8 @@ class DistributedArrayTest(unittest.TestCase):
|
|||
# Check that l is lower triangular.
|
||||
assert_equal(np.tril(l_val), l_val)
|
||||
|
||||
for d1, d2 in [(100, 100), (99, 98), (7, 5), (7, 7), (20, 7),
|
||||
(20, 10)]:
|
||||
for d1, d2 in [(100, 100), (99, 98), (7, 5), (7, 7), (20, 7), (20,
|
||||
10)]:
|
||||
test_modified_lu(d1, d2)
|
||||
|
||||
# test dist_tsqr_hr
|
||||
|
|
|
@ -56,7 +56,8 @@ class MockProvider(NodeProvider):
|
|||
raise Exception("oops")
|
||||
return [
|
||||
n.node_id for n in self.mock_nodes.values()
|
||||
if n.matches(tag_filters) and n.state != "terminated"]
|
||||
if n.matches(tag_filters) and n.state != "terminated"
|
||||
]
|
||||
|
||||
def is_running(self, node_id):
|
||||
return self.mock_nodes[node_id].state == "running"
|
||||
|
@ -101,7 +102,6 @@ SMALL_CLUSTER = {
|
|||
"docker": {
|
||||
"image": "example",
|
||||
"container_name": "mock",
|
||||
|
||||
},
|
||||
"auth": {
|
||||
"ssh_user": "ubuntu",
|
||||
|
@ -269,8 +269,11 @@ class AutoscalingTest(unittest.TestCase):
|
|||
config_path = self.write_config(SMALL_CLUSTER)
|
||||
self.provider = MockProvider()
|
||||
autoscaler = StandardAutoscaler(
|
||||
config_path, LoadMetrics(), max_concurrent_launches=5,
|
||||
max_failures=0, update_interval_s=0)
|
||||
config_path,
|
||||
LoadMetrics(),
|
||||
max_concurrent_launches=5,
|
||||
max_failures=0,
|
||||
update_interval_s=0)
|
||||
self.assertEqual(len(self.provider.nodes({})), 0)
|
||||
autoscaler.update()
|
||||
self.assertEqual(len(self.provider.nodes({})), 2)
|
||||
|
@ -295,8 +298,11 @@ class AutoscalingTest(unittest.TestCase):
|
|||
config_path = self.write_config(SMALL_CLUSTER)
|
||||
self.provider = MockProvider()
|
||||
autoscaler = StandardAutoscaler(
|
||||
config_path, LoadMetrics(), max_concurrent_launches=5,
|
||||
max_failures=0, update_interval_s=10)
|
||||
config_path,
|
||||
LoadMetrics(),
|
||||
max_concurrent_launches=5,
|
||||
max_failures=0,
|
||||
update_interval_s=10)
|
||||
autoscaler.update()
|
||||
self.assertEqual(len(self.provider.nodes({})), 2)
|
||||
new_config = SMALL_CLUSTER.copy()
|
||||
|
@ -328,8 +334,11 @@ class AutoscalingTest(unittest.TestCase):
|
|||
config_path = self.write_config(SMALL_CLUSTER)
|
||||
self.provider = MockProvider()
|
||||
autoscaler = StandardAutoscaler(
|
||||
config_path, LoadMetrics(), max_concurrent_launches=10,
|
||||
max_failures=0, update_interval_s=0)
|
||||
config_path,
|
||||
LoadMetrics(),
|
||||
max_concurrent_launches=10,
|
||||
max_failures=0,
|
||||
update_interval_s=0)
|
||||
autoscaler.update()
|
||||
|
||||
# Write a corrupted config
|
||||
|
@ -383,16 +392,22 @@ class AutoscalingTest(unittest.TestCase):
|
|||
self.provider = MockProvider()
|
||||
runner = MockProcessRunner()
|
||||
autoscaler = StandardAutoscaler(
|
||||
config_path, LoadMetrics(), max_failures=0, process_runner=runner,
|
||||
verbose_updates=True, node_updater_cls=NodeUpdaterThread,
|
||||
config_path,
|
||||
LoadMetrics(),
|
||||
max_failures=0,
|
||||
process_runner=runner,
|
||||
verbose_updates=True,
|
||||
node_updater_cls=NodeUpdaterThread,
|
||||
update_interval_s=0)
|
||||
autoscaler.update()
|
||||
autoscaler.update()
|
||||
self.assertEqual(len(self.provider.nodes({})), 2)
|
||||
for node in self.provider.mock_nodes.values():
|
||||
node.state = "running"
|
||||
assert len(self.provider.nodes(
|
||||
{TAG_RAY_NODE_STATUS: "Uninitialized"})) == 2
|
||||
assert len(
|
||||
self.provider.nodes({
|
||||
TAG_RAY_NODE_STATUS: "Uninitialized"
|
||||
})) == 2
|
||||
autoscaler.update()
|
||||
self.waitFor(
|
||||
lambda: len(self.provider.nodes(
|
||||
|
@ -403,16 +418,22 @@ class AutoscalingTest(unittest.TestCase):
|
|||
self.provider = MockProvider()
|
||||
runner = MockProcessRunner(fail_cmds=["cmd1"])
|
||||
autoscaler = StandardAutoscaler(
|
||||
config_path, LoadMetrics(), max_failures=0, process_runner=runner,
|
||||
verbose_updates=True, node_updater_cls=NodeUpdaterThread,
|
||||
config_path,
|
||||
LoadMetrics(),
|
||||
max_failures=0,
|
||||
process_runner=runner,
|
||||
verbose_updates=True,
|
||||
node_updater_cls=NodeUpdaterThread,
|
||||
update_interval_s=0)
|
||||
autoscaler.update()
|
||||
autoscaler.update()
|
||||
self.assertEqual(len(self.provider.nodes({})), 2)
|
||||
for node in self.provider.mock_nodes.values():
|
||||
node.state = "running"
|
||||
assert len(self.provider.nodes(
|
||||
{TAG_RAY_NODE_STATUS: "Uninitialized"})) == 2
|
||||
assert len(
|
||||
self.provider.nodes({
|
||||
TAG_RAY_NODE_STATUS: "Uninitialized"
|
||||
})) == 2
|
||||
autoscaler.update()
|
||||
self.waitFor(
|
||||
lambda: len(self.provider.nodes(
|
||||
|
@ -423,8 +444,12 @@ class AutoscalingTest(unittest.TestCase):
|
|||
self.provider = MockProvider()
|
||||
runner = MockProcessRunner()
|
||||
autoscaler = StandardAutoscaler(
|
||||
config_path, LoadMetrics(), max_failures=0, process_runner=runner,
|
||||
verbose_updates=True, node_updater_cls=NodeUpdaterThread,
|
||||
config_path,
|
||||
LoadMetrics(),
|
||||
max_failures=0,
|
||||
process_runner=runner,
|
||||
verbose_updates=True,
|
||||
node_updater_cls=NodeUpdaterThread,
|
||||
update_interval_s=0)
|
||||
autoscaler.update()
|
||||
autoscaler.update()
|
||||
|
@ -490,8 +515,12 @@ class AutoscalingTest(unittest.TestCase):
|
|||
runner = MockProcessRunner()
|
||||
lm = LoadMetrics()
|
||||
autoscaler = StandardAutoscaler(
|
||||
config_path, lm, max_failures=0, process_runner=runner,
|
||||
verbose_updates=True, node_updater_cls=NodeUpdaterThread,
|
||||
config_path,
|
||||
lm,
|
||||
max_failures=0,
|
||||
process_runner=runner,
|
||||
verbose_updates=True,
|
||||
node_updater_cls=NodeUpdaterThread,
|
||||
update_interval_s=0)
|
||||
autoscaler.update()
|
||||
for node in self.provider.mock_nodes.values():
|
||||
|
|
|
@ -11,7 +11,6 @@ import pyarrow as pa
|
|||
|
||||
|
||||
class ComponentFailureTest(unittest.TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
ray.worker.cleanup()
|
||||
|
||||
|
@ -24,7 +23,8 @@ class ComponentFailureTest(unittest.TestCase):
|
|||
def f():
|
||||
ray.worker.global_worker.plasma_client.get(obj_id)
|
||||
|
||||
ray.worker._init(num_workers=1,
|
||||
ray.worker._init(
|
||||
num_workers=1,
|
||||
driver_mode=ray.SILENT_MODE,
|
||||
start_workers_from_local_scheduler=False,
|
||||
start_ray_local=True,
|
||||
|
@ -35,8 +35,8 @@ class ComponentFailureTest(unittest.TestCase):
|
|||
|
||||
# Kill the worker.
|
||||
time.sleep(1)
|
||||
(ray.services
|
||||
.all_processes[ray.services.PROCESS_TYPE_WORKER][0].terminate())
|
||||
(ray.services.all_processes[ray.services.PROCESS_TYPE_WORKER][0]
|
||||
.terminate())
|
||||
time.sleep(0.1)
|
||||
|
||||
# Seal the object so the store attempts to notify the worker that the
|
||||
|
@ -47,7 +47,8 @@ class ComponentFailureTest(unittest.TestCase):
|
|||
time.sleep(0.1)
|
||||
|
||||
# Make sure that nothing has died.
|
||||
self.assertTrue(ray.services.all_processes_alive(
|
||||
self.assertTrue(
|
||||
ray.services.all_processes_alive(
|
||||
exclude=[ray.services.PROCESS_TYPE_WORKER]))
|
||||
|
||||
# This test checks that when a worker dies in the middle of a wait, the
|
||||
|
@ -59,7 +60,8 @@ class ComponentFailureTest(unittest.TestCase):
|
|||
def f():
|
||||
ray.worker.global_worker.plasma_client.wait([obj_id])
|
||||
|
||||
ray.worker._init(num_workers=1,
|
||||
ray.worker._init(
|
||||
num_workers=1,
|
||||
driver_mode=ray.SILENT_MODE,
|
||||
start_workers_from_local_scheduler=False,
|
||||
start_ray_local=True,
|
||||
|
@ -70,8 +72,8 @@ class ComponentFailureTest(unittest.TestCase):
|
|||
|
||||
# Kill the worker.
|
||||
time.sleep(1)
|
||||
(ray.services
|
||||
.all_processes[ray.services.PROCESS_TYPE_WORKER][0].terminate())
|
||||
(ray.services.all_processes[ray.services.PROCESS_TYPE_WORKER][0]
|
||||
.terminate())
|
||||
time.sleep(0.1)
|
||||
|
||||
# Seal the object so the store attempts to notify the worker that the
|
||||
|
@ -82,7 +84,8 @@ class ComponentFailureTest(unittest.TestCase):
|
|||
time.sleep(0.1)
|
||||
|
||||
# Make sure that nothing has died.
|
||||
self.assertTrue(ray.services.all_processes_alive(
|
||||
self.assertTrue(
|
||||
ray.services.all_processes_alive(
|
||||
exclude=[ray.services.PROCESS_TYPE_WORKER]))
|
||||
|
||||
def _testWorkerFailed(self, num_local_schedulers):
|
||||
|
@ -92,8 +95,8 @@ class ComponentFailureTest(unittest.TestCase):
|
|||
return x
|
||||
|
||||
num_initial_workers = 4
|
||||
ray.worker._init(num_workers=(num_initial_workers *
|
||||
num_local_schedulers),
|
||||
ray.worker._init(
|
||||
num_workers=(num_initial_workers * num_local_schedulers),
|
||||
num_local_schedulers=num_local_schedulers,
|
||||
start_workers_from_local_scheduler=False,
|
||||
start_ray_local=True,
|
||||
|
@ -101,14 +104,16 @@ class ComponentFailureTest(unittest.TestCase):
|
|||
redirect_output=True)
|
||||
# Submit more tasks than there are workers so that all workers and
|
||||
# cores are utilized.
|
||||
object_ids = [f.remote(i) for i
|
||||
in range(num_initial_workers * num_local_schedulers)]
|
||||
object_ids = [
|
||||
f.remote(i)
|
||||
for i in range(num_initial_workers * num_local_schedulers)
|
||||
]
|
||||
object_ids += [f.remote(object_id) for object_id in object_ids]
|
||||
# Allow the tasks some time to begin executing.
|
||||
time.sleep(0.1)
|
||||
# Kill the workers as the tasks execute.
|
||||
for worker in (ray.services
|
||||
.all_processes[ray.services.PROCESS_TYPE_WORKER]):
|
||||
for worker in (
|
||||
ray.services.all_processes[ray.services.PROCESS_TYPE_WORKER]):
|
||||
worker.terminate()
|
||||
time.sleep(0.1)
|
||||
# Make sure that we can still get the objects after the executing tasks
|
||||
|
@ -123,6 +128,7 @@ class ComponentFailureTest(unittest.TestCase):
|
|||
|
||||
def _testComponentFailed(self, component_type):
|
||||
"""Kill a component on all worker nodes and check workload succeeds."""
|
||||
|
||||
@ray.remote
|
||||
def f(x, j):
|
||||
time.sleep(0.2)
|
||||
|
@ -140,9 +146,10 @@ class ComponentFailureTest(unittest.TestCase):
|
|||
|
||||
# Submit more tasks than there are workers so that all workers and
|
||||
# cores are utilized.
|
||||
object_ids = [f.remote(i, 0) for i
|
||||
in range(num_workers_per_scheduler *
|
||||
num_local_schedulers)]
|
||||
object_ids = [
|
||||
f.remote(i, 0)
|
||||
for i in range(num_workers_per_scheduler * num_local_schedulers)
|
||||
]
|
||||
object_ids += [f.remote(object_id, 1) for object_id in object_ids]
|
||||
object_ids += [f.remote(object_id, 2) for object_id in object_ids]
|
||||
|
||||
|
@ -162,8 +169,8 @@ class ComponentFailureTest(unittest.TestCase):
|
|||
# Make sure that we can still get the objects after the executing tasks
|
||||
# died.
|
||||
results = ray.get(object_ids)
|
||||
expected_results = 4 * list(range(
|
||||
num_workers_per_scheduler * num_local_schedulers))
|
||||
expected_results = 4 * list(
|
||||
range(num_workers_per_scheduler * num_local_schedulers))
|
||||
self.assertEqual(results, expected_results)
|
||||
|
||||
def check_components_alive(self, component_type, check_component_alive):
|
||||
|
@ -182,8 +189,7 @@ class ComponentFailureTest(unittest.TestCase):
|
|||
self.assertTrue(not component.poll() is None)
|
||||
|
||||
@unittest.skipIf(
|
||||
os.environ.get('RAY_USE_NEW_GCS', False),
|
||||
"Hanging with new GCS API.")
|
||||
os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
|
||||
def testLocalSchedulerFailed(self):
|
||||
# Kill all local schedulers on worker nodes.
|
||||
self._testComponentFailed(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER)
|
||||
|
@ -198,8 +204,7 @@ class ComponentFailureTest(unittest.TestCase):
|
|||
False)
|
||||
|
||||
@unittest.skipIf(
|
||||
os.environ.get('RAY_USE_NEW_GCS', False),
|
||||
"Hanging with new GCS API.")
|
||||
os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
|
||||
def testPlasmaManagerFailed(self):
|
||||
# Kill all plasma managers on worker nodes.
|
||||
self._testComponentFailed(ray.services.PROCESS_TYPE_PLASMA_MANAGER)
|
||||
|
@ -214,8 +219,7 @@ class ComponentFailureTest(unittest.TestCase):
|
|||
False)
|
||||
|
||||
@unittest.skipIf(
|
||||
os.environ.get('RAY_USE_NEW_GCS', False),
|
||||
"Hanging with new GCS API.")
|
||||
os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
|
||||
def testPlasmaStoreFailed(self):
|
||||
# Kill all plasma stores on worker nodes.
|
||||
self._testComponentFailed(ray.services.PROCESS_TYPE_PLASMA_STORE)
|
||||
|
@ -235,7 +239,8 @@ class ComponentFailureTest(unittest.TestCase):
|
|||
all_processes[ray.services.PROCESS_TYPE_PLASMA_STORE][0],
|
||||
all_processes[ray.services.PROCESS_TYPE_PLASMA_MANAGER][0],
|
||||
all_processes[ray.services.PROCESS_TYPE_LOCAL_SCHEDULER][0],
|
||||
all_processes[ray.services.PROCESS_TYPE_GLOBAL_SCHEDULER][0]]
|
||||
all_processes[ray.services.PROCESS_TYPE_GLOBAL_SCHEDULER][0]
|
||||
]
|
||||
|
||||
# Kill all the components sequentially.
|
||||
for process in processes:
|
||||
|
@ -253,7 +258,8 @@ class ComponentFailureTest(unittest.TestCase):
|
|||
all_processes[ray.services.PROCESS_TYPE_PLASMA_STORE][0],
|
||||
all_processes[ray.services.PROCESS_TYPE_PLASMA_MANAGER][0],
|
||||
all_processes[ray.services.PROCESS_TYPE_LOCAL_SCHEDULER][0],
|
||||
all_processes[ray.services.PROCESS_TYPE_GLOBAL_SCHEDULER][0]]
|
||||
all_processes[ray.services.PROCESS_TYPE_GLOBAL_SCHEDULER][0]
|
||||
]
|
||||
|
||||
# Kill all the components in parallel.
|
||||
for process in processes:
|
||||
|
|
|
@ -9,8 +9,7 @@ import unittest
|
|||
import ray
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not os.environ.get('RAY_USE_NEW_GCS', False),
|
||||
@unittest.skipIf(not os.environ.get('RAY_USE_NEW_GCS', False),
|
||||
"Tests functionality of the new GCS.")
|
||||
class CredisTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
@ -22,8 +21,8 @@ class CredisTest(unittest.TestCase):
|
|||
def test_credis_started(self):
|
||||
assert "credis_address" in self.config
|
||||
credis_address, credis_port = self.config["credis_address"].split(":")
|
||||
credis_client = redis.StrictRedis(host=credis_address,
|
||||
port=credis_port)
|
||||
credis_client = redis.StrictRedis(
|
||||
host=credis_address, port=credis_port)
|
||||
assert credis_client.ping() is True
|
||||
|
||||
redis_client = ray.worker.global_state.redis_client
|
||||
|
|
|
@ -108,6 +108,7 @@ def temporary_helper_function():
|
|||
def f(worker):
|
||||
if ray.worker.global_worker.mode == ray.WORKER_MODE:
|
||||
raise Exception("Function to run failed.")
|
||||
|
||||
ray.worker.global_worker.run_function_on_all_workers(f)
|
||||
wait_for_errors(b"function_to_run", 2)
|
||||
# Check that the error message is in the task info.
|
||||
|
@ -348,12 +349,14 @@ class PutErrorTest(unittest.TestCase):
|
|||
ray.worker.cleanup()
|
||||
|
||||
def testPutError1(self):
|
||||
store_size = 10 ** 6
|
||||
ray.worker._init(start_ray_local=True, driver_mode=ray.SILENT_MODE,
|
||||
store_size = 10**6
|
||||
ray.worker._init(
|
||||
start_ray_local=True,
|
||||
driver_mode=ray.SILENT_MODE,
|
||||
object_store_memory=store_size)
|
||||
|
||||
num_objects = 3
|
||||
object_size = 4 * 10 ** 5
|
||||
object_size = 4 * 10**5
|
||||
|
||||
# Define a task with a single dependency, a numpy array, that returns
|
||||
# another array.
|
||||
|
@ -369,8 +372,9 @@ class PutErrorTest(unittest.TestCase):
|
|||
# on the one before it. The result of the first task should get
|
||||
# evicted.
|
||||
args = []
|
||||
arg = single_dependency.remote(0, np.zeros(object_size,
|
||||
dtype=np.uint8))
|
||||
arg = single_dependency.remote(0,
|
||||
np.zeros(
|
||||
object_size, dtype=np.uint8))
|
||||
for i in range(num_objects):
|
||||
arg = single_dependency.remote(i, arg)
|
||||
args.append(arg)
|
||||
|
@ -393,12 +397,14 @@ class PutErrorTest(unittest.TestCase):
|
|||
|
||||
def testPutError2(self):
|
||||
# This is the same as the previous test, but it calls ray.put directly.
|
||||
store_size = 10 ** 6
|
||||
ray.worker._init(start_ray_local=True, driver_mode=ray.SILENT_MODE,
|
||||
store_size = 10**6
|
||||
ray.worker._init(
|
||||
start_ray_local=True,
|
||||
driver_mode=ray.SILENT_MODE,
|
||||
object_store_memory=store_size)
|
||||
|
||||
num_objects = 3
|
||||
object_size = 4 * 10 ** 5
|
||||
object_size = 4 * 10**5
|
||||
|
||||
# Define a task with a single dependency, a numpy array, that returns
|
||||
# another array.
|
||||
|
|
|
@ -58,6 +58,7 @@ class DockerRunner(object):
|
|||
head_container_ip: The IP address of the docker container that runs the
|
||||
head node.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the DockerRunner."""
|
||||
self.head_container_id = None
|
||||
|
@ -91,11 +92,14 @@ class DockerRunner(object):
|
|||
Returns:
|
||||
The IP address of the container.
|
||||
"""
|
||||
proc = subprocess.Popen(["docker", "inspect",
|
||||
proc = subprocess.Popen(
|
||||
[
|
||||
"docker", "inspect",
|
||||
"--format={{.NetworkSettings.Networks.bridge"
|
||||
".IPAddress}}",
|
||||
container_id],
|
||||
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
".IPAddress}}", container_id
|
||||
],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE)
|
||||
stdout_data, _ = wait_for_output(proc)
|
||||
p = re.compile("([0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})")
|
||||
m = p.match(stdout_data)
|
||||
|
@ -110,23 +114,23 @@ class DockerRunner(object):
|
|||
"""Start the Ray head node inside a docker container."""
|
||||
mem_arg = ["--memory=" + mem_size] if mem_size else []
|
||||
shm_arg = ["--shm-size=" + shm_size] if shm_size else []
|
||||
volume_arg = (["-v",
|
||||
"{}:{}".format(os.path.dirname(
|
||||
os.path.realpath(__file__)),
|
||||
"/ray/test/jenkins_tests")]
|
||||
if development_mode else [])
|
||||
volume_arg = ([
|
||||
"-v", "{}:{}".format(
|
||||
os.path.dirname(os.path.realpath(__file__)),
|
||||
"/ray/test/jenkins_tests")
|
||||
] if development_mode else [])
|
||||
|
||||
command = (["docker", "run", "-d"] + mem_arg + shm_arg + volume_arg +
|
||||
[docker_image, "ray", "start", "--head", "--block",
|
||||
command = (["docker", "run", "-d"] + mem_arg + shm_arg + volume_arg + [
|
||||
docker_image, "ray", "start", "--head", "--block",
|
||||
"--redis-port=6379",
|
||||
"--num-redis-shards={}".format(num_redis_shards),
|
||||
"--num-cpus={}".format(num_cpus),
|
||||
"--num-gpus={}".format(num_gpus),
|
||||
"--no-ui"])
|
||||
"--num-cpus={}".format(num_cpus), "--num-gpus={}".format(num_gpus),
|
||||
"--no-ui"
|
||||
])
|
||||
print("Starting head node with command:{}".format(command))
|
||||
|
||||
proc = subprocess.Popen(command,
|
||||
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
proc = subprocess.Popen(
|
||||
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
stdout_data, _ = wait_for_output(proc)
|
||||
container_id = self._get_container_id(stdout_data)
|
||||
if container_id is None:
|
||||
|
@ -139,29 +143,34 @@ class DockerRunner(object):
|
|||
"""Start a Ray worker node inside a docker container."""
|
||||
mem_arg = ["--memory=" + mem_size] if mem_size else []
|
||||
shm_arg = ["--shm-size=" + shm_size] if shm_size else []
|
||||
volume_arg = (["-v",
|
||||
"{}:{}".format(os.path.dirname(
|
||||
os.path.realpath(__file__)),
|
||||
"/ray/test/jenkins_tests")]
|
||||
if development_mode else [])
|
||||
command = (["docker", "run", "-d"] + mem_arg + shm_arg + volume_arg +
|
||||
["--shm-size=" + shm_size, docker_image,
|
||||
"ray", "start", "--block",
|
||||
volume_arg = ([
|
||||
"-v", "{}:{}".format(
|
||||
os.path.dirname(os.path.realpath(__file__)),
|
||||
"/ray/test/jenkins_tests")
|
||||
] if development_mode else [])
|
||||
command = (["docker", "run", "-d"] + mem_arg + shm_arg + volume_arg + [
|
||||
"--shm-size=" + shm_size, docker_image, "ray", "start", "--block",
|
||||
"--redis-address={:s}:6379".format(self.head_container_ip),
|
||||
"--num-cpus={}".format(num_cpus),
|
||||
"--num-gpus={}".format(num_gpus)])
|
||||
"--num-cpus={}".format(num_cpus), "--num-gpus={}".format(num_gpus)
|
||||
])
|
||||
print("Starting worker node with command:{}".format(command))
|
||||
proc = subprocess.Popen(command, stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE)
|
||||
proc = subprocess.Popen(
|
||||
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
stdout_data, _ = wait_for_output(proc)
|
||||
container_id = self._get_container_id(stdout_data)
|
||||
if container_id is None:
|
||||
raise RuntimeError("Failed to find container id")
|
||||
self.worker_container_ids.append(container_id)
|
||||
|
||||
def start_ray(self, docker_image=None, mem_size=None, shm_size=None,
|
||||
num_nodes=None, num_redis_shards=1, num_cpus=None,
|
||||
num_gpus=None, development_mode=None):
|
||||
def start_ray(self,
|
||||
docker_image=None,
|
||||
mem_size=None,
|
||||
shm_size=None,
|
||||
num_nodes=None,
|
||||
num_redis_shards=1,
|
||||
num_cpus=None,
|
||||
num_gpus=None,
|
||||
development_mode=None):
|
||||
"""Start a Ray cluster within docker.
|
||||
|
||||
This starts one docker container running the head node and
|
||||
|
@ -200,24 +209,31 @@ class DockerRunner(object):
|
|||
|
||||
def _stop_node(self, container_id):
|
||||
"""Stop a node in the Ray cluster."""
|
||||
proc = subprocess.Popen(["docker", "kill", container_id],
|
||||
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
proc = subprocess.Popen(
|
||||
["docker", "kill", container_id],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE)
|
||||
stdout_data, _ = wait_for_output(proc)
|
||||
stopped_container_id = self._get_container_id(stdout_data)
|
||||
if not container_id == stopped_container_id:
|
||||
raise Exception("Failed to stop container {}."
|
||||
.format(container_id))
|
||||
|
||||
proc = subprocess.Popen(["docker", "rm", "-f", container_id],
|
||||
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
proc = subprocess.Popen(
|
||||
["docker", "rm", "-f", container_id],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE)
|
||||
stdout_data, _ = wait_for_output(proc)
|
||||
removed_container_id = self._get_container_id(stdout_data)
|
||||
if not container_id == removed_container_id:
|
||||
raise Exception("Failed to remove container {}."
|
||||
.format(container_id))
|
||||
|
||||
print("stop_node", {"container_id": container_id,
|
||||
"is_head": container_id == self.head_container_id})
|
||||
print(
|
||||
"stop_node", {
|
||||
"container_id": container_id,
|
||||
"is_head": container_id == self.head_container_id
|
||||
})
|
||||
|
||||
def stop_ray(self):
|
||||
"""Stop the Ray cluster."""
|
||||
|
@ -236,7 +252,10 @@ class DockerRunner(object):
|
|||
|
||||
return success
|
||||
|
||||
def run_test(self, test_script, num_drivers, driver_locations=None,
|
||||
def run_test(self,
|
||||
test_script,
|
||||
num_drivers,
|
||||
driver_locations=None,
|
||||
timeout_seconds=600):
|
||||
"""Run a test script.
|
||||
|
||||
|
@ -258,11 +277,13 @@ class DockerRunner(object):
|
|||
Raises:
|
||||
Exception: An exception is raised if the timeout expires.
|
||||
"""
|
||||
all_container_ids = ([self.head_container_id] +
|
||||
self.worker_container_ids)
|
||||
all_container_ids = (
|
||||
[self.head_container_id] + self.worker_container_ids)
|
||||
if driver_locations is None:
|
||||
driver_locations = [np.random.randint(0, len(all_container_ids))
|
||||
for _ in range(num_drivers)]
|
||||
driver_locations = [
|
||||
np.random.randint(0, len(all_container_ids))
|
||||
for _ in range(num_drivers)
|
||||
]
|
||||
|
||||
# Define a signal handler and set an alarm to go off in
|
||||
# timeout_seconds.
|
||||
|
@ -278,13 +299,15 @@ class DockerRunner(object):
|
|||
for i in range(len(driver_locations)):
|
||||
# Get the container ID to run the ith driver in.
|
||||
container_id = all_container_ids[driver_locations[i]]
|
||||
command = ["docker", "exec", container_id, "/bin/bash", "-c",
|
||||
command = [
|
||||
"docker", "exec", container_id, "/bin/bash", "-c",
|
||||
("RAY_REDIS_ADDRESS={}:6379 RAY_DRIVER_INDEX={} python "
|
||||
"{}".format(self.head_container_ip, i, test_script))]
|
||||
"{}".format(self.head_container_ip, i, test_script))
|
||||
]
|
||||
print("Starting driver with command {}.".format(test_script))
|
||||
# Start the driver.
|
||||
p = subprocess.Popen(command, stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE)
|
||||
p = subprocess.Popen(
|
||||
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
driver_processes.append(p)
|
||||
|
||||
# Wait for the drivers to finish.
|
||||
|
@ -295,8 +318,10 @@ class DockerRunner(object):
|
|||
print(stdout_data)
|
||||
print("STDERR:")
|
||||
print(stderr_data)
|
||||
results.append({"success": p.returncode == 0,
|
||||
"return_code": p.returncode})
|
||||
results.append({
|
||||
"success": p.returncode == 0,
|
||||
"return_code": p.returncode
|
||||
})
|
||||
|
||||
# Disable the alarm.
|
||||
signal.alarm(0)
|
||||
|
@ -307,28 +332,42 @@ class DockerRunner(object):
|
|||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run multinode tests in Docker.")
|
||||
parser.add_argument("--docker-image", default="ray-project/deploy",
|
||||
help="docker image")
|
||||
parser.add_argument(
|
||||
"--docker-image", default="ray-project/deploy", help="docker image")
|
||||
parser.add_argument("--mem-size", help="memory size")
|
||||
parser.add_argument("--shm-size", default="1G", help="shared memory size")
|
||||
parser.add_argument("--num-nodes", default=1, type=int,
|
||||
parser.add_argument(
|
||||
"--num-nodes",
|
||||
default=1,
|
||||
type=int,
|
||||
help="number of nodes to use in the cluster")
|
||||
parser.add_argument("--num-redis-shards", default=1, type=int,
|
||||
parser.add_argument(
|
||||
"--num-redis-shards",
|
||||
default=1,
|
||||
type=int,
|
||||
help=("the number of Redis shards to start on the "
|
||||
"head node"))
|
||||
parser.add_argument("--num-cpus", type=str,
|
||||
parser.add_argument(
|
||||
"--num-cpus",
|
||||
type=str,
|
||||
help=("a comma separated list of values representing "
|
||||
"the number of CPUs to start each node with"))
|
||||
parser.add_argument("--num-gpus", type=str,
|
||||
parser.add_argument(
|
||||
"--num-gpus",
|
||||
type=str,
|
||||
help=("a comma separated list of values representing "
|
||||
"the number of GPUs to start each node with"))
|
||||
parser.add_argument("--num-drivers", default=1, type=int,
|
||||
help="number of drivers to run")
|
||||
parser.add_argument("--driver-locations", type=str,
|
||||
parser.add_argument(
|
||||
"--num-drivers", default=1, type=int, help="number of drivers to run")
|
||||
parser.add_argument(
|
||||
"--driver-locations",
|
||||
type=str,
|
||||
help=("a comma separated list of indices of the "
|
||||
"containers to run the drivers in"))
|
||||
parser.add_argument("--test-script", required=True, help="test script")
|
||||
parser.add_argument("--development-mode", action="store_true",
|
||||
parser.add_argument(
|
||||
"--development-mode",
|
||||
action="store_true",
|
||||
help="use local copies of the test scripts")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -340,17 +379,23 @@ if __name__ == "__main__":
|
|||
if args.num_gpus is not None else num_nodes * [0])
|
||||
|
||||
# Parse the driver locations.
|
||||
driver_locations = (None if args.driver_locations is None
|
||||
else [int(i) for i
|
||||
in args.driver_locations.split(",")])
|
||||
driver_locations = (None if args.driver_locations is None else
|
||||
[int(i) for i in args.driver_locations.split(",")])
|
||||
|
||||
d = DockerRunner()
|
||||
d.start_ray(docker_image=args.docker_image, mem_size=args.mem_size,
|
||||
shm_size=args.shm_size, num_nodes=num_nodes,
|
||||
num_redis_shards=args.num_redis_shards, num_cpus=num_cpus,
|
||||
num_gpus=num_gpus, development_mode=args.development_mode)
|
||||
d.start_ray(
|
||||
docker_image=args.docker_image,
|
||||
mem_size=args.mem_size,
|
||||
shm_size=args.shm_size,
|
||||
num_nodes=num_nodes,
|
||||
num_redis_shards=args.num_redis_shards,
|
||||
num_cpus=num_cpus,
|
||||
num_gpus=num_gpus,
|
||||
development_mode=args.development_mode)
|
||||
try:
|
||||
run_results = d.run_test(args.test_script, args.num_drivers,
|
||||
run_results = d.run_test(
|
||||
args.test_script,
|
||||
args.num_drivers,
|
||||
driver_locations=driver_locations)
|
||||
finally:
|
||||
successfully_stopped = d.stop_ray()
|
||||
|
|
|
@ -6,19 +6,17 @@ import numpy as np
|
|||
|
||||
import ray
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ray.init(num_workers=0)
|
||||
|
||||
A = np.ones(2 ** 31 + 1, dtype="int8")
|
||||
A = np.ones(2**31 + 1, dtype="int8")
|
||||
a = ray.put(A)
|
||||
assert np.sum(ray.get(a)) == np.sum(A)
|
||||
del A
|
||||
del a
|
||||
print("Successfully put A.")
|
||||
|
||||
B = {"hello": np.zeros(2 ** 30 + 1),
|
||||
"world": np.ones(2 ** 30 + 1)}
|
||||
B = {"hello": np.zeros(2**30 + 1), "world": np.ones(2**30 + 1)}
|
||||
b = ray.put(B)
|
||||
assert np.sum(ray.get(b)["hello"]) == np.sum(B["hello"])
|
||||
assert np.sum(ray.get(b)["world"]) == np.sum(B["world"])
|
||||
|
@ -26,7 +24,7 @@ if __name__ == "__main__":
|
|||
del b
|
||||
print("Successfully put B.")
|
||||
|
||||
C = [np.ones(2 ** 30 + 1), 42.0 * np.ones(2 ** 30 + 1)]
|
||||
C = [np.ones(2**30 + 1), 42.0 * np.ones(2**30 + 1)]
|
||||
c = ray.put(C)
|
||||
assert np.sum(ray.get(c)[0]) == np.sum(C[0])
|
||||
assert np.sum(ray.get(c)[1]) == np.sum(C[1])
|
||||
|
|
|
@ -6,8 +6,7 @@ import os
|
|||
import time
|
||||
|
||||
import ray
|
||||
from ray.test.test_utils import (_wait_for_nodes_to_join,
|
||||
_broadcast_event,
|
||||
from ray.test.test_utils import (_wait_for_nodes_to_join, _broadcast_event,
|
||||
_wait_for_event)
|
||||
|
||||
# This test should be run with 5 nodes, which have 0, 0, 5, 6, and 50 GPUs for
|
||||
|
|
|
@ -6,10 +6,8 @@ import os
|
|||
import time
|
||||
|
||||
import ray
|
||||
from ray.test.test_utils import (_wait_for_nodes_to_join,
|
||||
_broadcast_event,
|
||||
_wait_for_event,
|
||||
wait_for_pid_to_exit)
|
||||
from ray.test.test_utils import (_wait_for_nodes_to_join, _broadcast_event,
|
||||
_wait_for_event, wait_for_pid_to_exit)
|
||||
|
||||
# This test should be run with 5 nodes, which have 0, 1, 2, 3, and 4 GPUs for a
|
||||
# total of 10 GPUs. It should be run with 7 drivers. Drivers 2 through 6 must
|
||||
|
@ -28,7 +26,8 @@ def remote_function_event_name(driver_index, task_index):
|
|||
|
||||
@ray.remote
|
||||
def long_running_task(driver_index, task_index, redis_address):
|
||||
_broadcast_event(remote_function_event_name(driver_index, task_index),
|
||||
_broadcast_event(
|
||||
remote_function_event_name(driver_index, task_index),
|
||||
redis_address,
|
||||
data=(ray.services.get_node_ip_address(), os.getpid()))
|
||||
# Loop forever.
|
||||
|
@ -42,10 +41,10 @@ num_long_running_tasks_per_driver = 2
|
|||
@ray.remote
|
||||
class Actor0(object):
|
||||
def __init__(self, driver_index, actor_index, redis_address):
|
||||
_broadcast_event(actor_event_name(driver_index, actor_index),
|
||||
_broadcast_event(
|
||||
actor_event_name(driver_index, actor_index),
|
||||
redis_address,
|
||||
data=(ray.services.get_node_ip_address(),
|
||||
os.getpid()))
|
||||
data=(ray.services.get_node_ip_address(), os.getpid()))
|
||||
assert len(ray.get_gpu_ids()) == 0
|
||||
|
||||
def check_ids(self):
|
||||
|
@ -60,10 +59,10 @@ class Actor0(object):
|
|||
@ray.remote(num_gpus=1)
|
||||
class Actor1(object):
|
||||
def __init__(self, driver_index, actor_index, redis_address):
|
||||
_broadcast_event(actor_event_name(driver_index, actor_index),
|
||||
_broadcast_event(
|
||||
actor_event_name(driver_index, actor_index),
|
||||
redis_address,
|
||||
data=(ray.services.get_node_ip_address(),
|
||||
os.getpid()))
|
||||
data=(ray.services.get_node_ip_address(), os.getpid()))
|
||||
assert len(ray.get_gpu_ids()) == 1
|
||||
|
||||
def check_ids(self):
|
||||
|
@ -78,10 +77,10 @@ class Actor1(object):
|
|||
@ray.remote(num_gpus=2)
|
||||
class Actor2(object):
|
||||
def __init__(self, driver_index, actor_index, redis_address):
|
||||
_broadcast_event(actor_event_name(driver_index, actor_index),
|
||||
_broadcast_event(
|
||||
actor_event_name(driver_index, actor_index),
|
||||
redis_address,
|
||||
data=(ray.services.get_node_ip_address(),
|
||||
os.getpid()))
|
||||
data=(ray.services.get_node_ip_address(), os.getpid()))
|
||||
assert len(ray.get_gpu_ids()) == 2
|
||||
|
||||
def check_ids(self):
|
||||
|
@ -110,11 +109,13 @@ def driver_0(redis_address, driver_index):
|
|||
long_running_task.remote(driver_index, i, redis_address)
|
||||
|
||||
# Create some actors that require one GPU.
|
||||
actors_one_gpu = [Actor1.remote(driver_index, i, redis_address)
|
||||
for i in range(5)]
|
||||
actors_one_gpu = [
|
||||
Actor1.remote(driver_index, i, redis_address) for i in range(5)
|
||||
]
|
||||
# Create some actors that don't require any GPUs.
|
||||
actors_no_gpus = [Actor0.remote(driver_index, 5 + i, redis_address)
|
||||
for i in range(5)]
|
||||
actors_no_gpus = [
|
||||
Actor0.remote(driver_index, 5 + i, redis_address) for i in range(5)
|
||||
]
|
||||
|
||||
for _ in range(1000):
|
||||
ray.get([actor.check_ids.remote() for actor in actors_one_gpu])
|
||||
|
@ -145,14 +146,17 @@ def driver_1(redis_address, driver_index):
|
|||
long_running_task.remote(driver_index, i, redis_address)
|
||||
|
||||
# Create an actor that requires two GPUs.
|
||||
actors_two_gpus = [Actor2.remote(driver_index, i, redis_address)
|
||||
for i in range(1)]
|
||||
actors_two_gpus = [
|
||||
Actor2.remote(driver_index, i, redis_address) for i in range(1)
|
||||
]
|
||||
# Create some actors that require one GPU.
|
||||
actors_one_gpu = [Actor1.remote(driver_index, 1 + i, redis_address)
|
||||
for i in range(3)]
|
||||
actors_one_gpu = [
|
||||
Actor1.remote(driver_index, 1 + i, redis_address) for i in range(3)
|
||||
]
|
||||
# Create some actors that don't require any GPUs.
|
||||
actors_no_gpus = [Actor0.remote(driver_index, 1 + 3 + i, redis_address)
|
||||
for i in range(5)]
|
||||
actors_no_gpus = [
|
||||
Actor0.remote(driver_index, 1 + 3 + i, redis_address) for i in range(5)
|
||||
]
|
||||
|
||||
for _ in range(1000):
|
||||
ray.get([actor.check_ids.remote() for actor in actors_two_gpus])
|
||||
|
@ -179,8 +183,9 @@ def cleanup_driver(redis_address, driver_index):
|
|||
# We go ahead and create some actors that don't require any GPUs. We
|
||||
# don't need to wait for the other drivers to finish. We call methods
|
||||
# on these actors later to make sure they haven't been killed.
|
||||
actors_no_gpus = [Actor0.remote(driver_index, i, redis_address)
|
||||
for i in range(10)]
|
||||
actors_no_gpus = [
|
||||
Actor0.remote(driver_index, i, redis_address) for i in range(10)
|
||||
]
|
||||
|
||||
_wait_for_event("DRIVER_0_DONE", redis_address)
|
||||
_wait_for_event("DRIVER_1_DONE", redis_address)
|
||||
|
@ -206,13 +211,13 @@ def cleanup_driver(redis_address, driver_index):
|
|||
# Create some actors that require two GPUs.
|
||||
actors_two_gpus = []
|
||||
for i in range(3):
|
||||
actors_two_gpus.append(try_to_create_actor(Actor2, driver_index,
|
||||
10 + i))
|
||||
actors_two_gpus.append(
|
||||
try_to_create_actor(Actor2, driver_index, 10 + i))
|
||||
# Create some actors that require one GPU.
|
||||
actors_one_gpu = []
|
||||
for i in range(4):
|
||||
actors_one_gpu.append(try_to_create_actor(Actor1, driver_index,
|
||||
10 + 3 + i))
|
||||
actors_one_gpu.append(
|
||||
try_to_create_actor(Actor1, driver_index, 10 + 3 + i))
|
||||
|
||||
removed_workers = 0
|
||||
|
||||
|
@ -233,14 +238,14 @@ def cleanup_driver(redis_address, driver_index):
|
|||
# Make sure that the PIDs for the actors from driver 0 and driver 1 have
|
||||
# been killed.
|
||||
for i in range(10):
|
||||
node_ip_address, pid = _wait_for_event(actor_event_name(0, i),
|
||||
redis_address)
|
||||
node_ip_address, pid = _wait_for_event(
|
||||
actor_event_name(0, i), redis_address)
|
||||
if node_ip_address == ray.services.get_node_ip_address():
|
||||
wait_for_pid_to_exit(pid)
|
||||
removed_workers += 1
|
||||
for i in range(9):
|
||||
node_ip_address, pid = _wait_for_event(actor_event_name(1, i),
|
||||
redis_address)
|
||||
node_ip_address, pid = _wait_for_event(
|
||||
actor_event_name(1, i), redis_address)
|
||||
if node_ip_address == ray.services.get_node_ip_address():
|
||||
wait_for_pid_to_exit(pid)
|
||||
removed_workers += 1
|
||||
|
|
|
@ -25,8 +25,9 @@ if __name__ == "__main__":
|
|||
for i in range(num_attempts):
|
||||
ip_addresses = ray.get([f.remote() for i in range(1000)])
|
||||
distinct_addresses = set(ip_addresses)
|
||||
counts = [ip_addresses.count(address) for address
|
||||
in distinct_addresses]
|
||||
counts = [
|
||||
ip_addresses.count(address) for address in distinct_addresses
|
||||
]
|
||||
print("Counts are {}".format(counts))
|
||||
if len(counts) == 5:
|
||||
break
|
||||
|
|
|
@ -25,19 +25,17 @@ def run_string_as_driver(driver_script):
|
|||
with tempfile.NamedTemporaryFile() as f:
|
||||
f.write(driver_script.encode("ascii"))
|
||||
f.flush()
|
||||
out = subprocess.check_output([sys.executable,
|
||||
f.name]).decode("ascii")
|
||||
out = subprocess.check_output([sys.executable, f.name]).decode("ascii")
|
||||
return out
|
||||
|
||||
|
||||
class MultiNodeTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
out = run_and_get_output(["ray", "start", "--head"])
|
||||
# Get the redis address from the output.
|
||||
redis_substring_prefix = "redis_address=\""
|
||||
redis_address_location = (out.find(redis_substring_prefix) +
|
||||
len(redis_substring_prefix))
|
||||
redis_address_location = (
|
||||
out.find(redis_substring_prefix) + len(redis_substring_prefix))
|
||||
redis_address = out[redis_address_location:]
|
||||
self.redis_address = redis_address.split("\"")[0]
|
||||
|
||||
|
@ -196,7 +194,6 @@ print("success")
|
|||
|
||||
|
||||
class StartRayScriptTest(unittest.TestCase):
|
||||
|
||||
def testCallingStartRayHead(self):
|
||||
# Test that we can call start-ray.sh with various command line
|
||||
# parameters. TODO(rkn): This test only tests the --head code path. We
|
||||
|
@ -207,69 +204,65 @@ class StartRayScriptTest(unittest.TestCase):
|
|||
subprocess.Popen(["ray", "stop"]).wait()
|
||||
|
||||
# Test starting Ray with a number of workers specified.
|
||||
run_and_get_output(["ray", "start", "--head", "--num-workers",
|
||||
"20"])
|
||||
run_and_get_output(["ray", "start", "--head", "--num-workers", "20"])
|
||||
subprocess.Popen(["ray", "stop"]).wait()
|
||||
|
||||
# Test starting Ray with a redis port specified.
|
||||
run_and_get_output(["ray", "start", "--head",
|
||||
"--redis-port", "6379"])
|
||||
run_and_get_output(["ray", "start", "--head", "--redis-port", "6379"])
|
||||
subprocess.Popen(["ray", "stop"]).wait()
|
||||
|
||||
# Test starting Ray with redis shard ports specified.
|
||||
run_and_get_output(["ray", "start", "--head",
|
||||
"--redis-shard-ports", "6380,6381,6382"])
|
||||
run_and_get_output([
|
||||
"ray", "start", "--head", "--redis-shard-ports", "6380,6381,6382"
|
||||
])
|
||||
subprocess.Popen(["ray", "stop"]).wait()
|
||||
|
||||
# Test starting Ray with a node IP address specified.
|
||||
run_and_get_output(["ray", "start", "--head",
|
||||
"--node-ip-address", "127.0.0.1"])
|
||||
run_and_get_output(
|
||||
["ray", "start", "--head", "--node-ip-address", "127.0.0.1"])
|
||||
subprocess.Popen(["ray", "stop"]).wait()
|
||||
|
||||
# Test starting Ray with an object manager port specified.
|
||||
run_and_get_output(["ray", "start", "--head",
|
||||
"--object-manager-port", "12345"])
|
||||
run_and_get_output(
|
||||
["ray", "start", "--head", "--object-manager-port", "12345"])
|
||||
subprocess.Popen(["ray", "stop"]).wait()
|
||||
|
||||
# Test starting Ray with the number of CPUs specified.
|
||||
run_and_get_output(["ray", "start", "--head",
|
||||
"--num-cpus", "100"])
|
||||
run_and_get_output(["ray", "start", "--head", "--num-cpus", "100"])
|
||||
subprocess.Popen(["ray", "stop"]).wait()
|
||||
|
||||
# Test starting Ray with the number of GPUs specified.
|
||||
run_and_get_output(["ray", "start", "--head",
|
||||
"--num-gpus", "100"])
|
||||
run_and_get_output(["ray", "start", "--head", "--num-gpus", "100"])
|
||||
subprocess.Popen(["ray", "stop"]).wait()
|
||||
|
||||
# Test starting Ray with the max redis clients specified.
|
||||
run_and_get_output(["ray", "start", "--head",
|
||||
"--redis-max-clients", "100"])
|
||||
run_and_get_output(
|
||||
["ray", "start", "--head", "--redis-max-clients", "100"])
|
||||
subprocess.Popen(["ray", "stop"]).wait()
|
||||
|
||||
# Test starting Ray with all arguments specified.
|
||||
run_and_get_output(["ray", "start", "--head",
|
||||
"--num-workers", "20",
|
||||
"--redis-port", "6379",
|
||||
"--redis-shard-ports", "6380,6381,6382",
|
||||
"--object-manager-port", "12345",
|
||||
"--num-cpus", "100",
|
||||
"--num-gpus", "0",
|
||||
"--redis-max-clients", "100",
|
||||
"--resources", "{\"Custom\": 1}"])
|
||||
run_and_get_output([
|
||||
"ray", "start", "--head", "--num-workers", "20", "--redis-port",
|
||||
"6379", "--redis-shard-ports", "6380,6381,6382",
|
||||
"--object-manager-port", "12345", "--num-cpus", "100",
|
||||
"--num-gpus", "0", "--redis-max-clients", "100", "--resources",
|
||||
"{\"Custom\": 1}"
|
||||
])
|
||||
subprocess.Popen(["ray", "stop"]).wait()
|
||||
|
||||
# Test starting Ray with invalid arguments.
|
||||
with self.assertRaises(Exception):
|
||||
run_and_get_output(["ray", "start", "--head",
|
||||
"--redis-address", "127.0.0.1:6379"])
|
||||
run_and_get_output([
|
||||
"ray", "start", "--head", "--redis-address", "127.0.0.1:6379"
|
||||
])
|
||||
subprocess.Popen(["ray", "stop"]).wait()
|
||||
|
||||
def testUsingHostnames(self):
|
||||
# Start the Ray processes on this machine.
|
||||
run_and_get_output(
|
||||
["ray", "start", "--head",
|
||||
"--node-ip-address=localhost",
|
||||
"--redis-port=6379"])
|
||||
run_and_get_output([
|
||||
"ray", "start", "--head", "--node-ip-address=localhost",
|
||||
"--redis-port=6379"
|
||||
])
|
||||
|
||||
ray.init(node_ip_address="localhost", redis_address="localhost:6379")
|
||||
|
||||
|
|
120
test/runtest.py
120
test/runtest.py
|
@ -184,9 +184,9 @@ DICT_OBJECTS = (
|
|||
[{
|
||||
obj: obj
|
||||
} for obj in PRIMITIVE_OBJECTS
|
||||
if (obj.__hash__ is not None and type(obj).__module__ != "numpy")] + [{
|
||||
0:
|
||||
obj
|
||||
if (obj.__hash__ is not None and type(obj).__module__ != "numpy")] +
|
||||
[{
|
||||
0: obj
|
||||
} for obj in BASE_OBJECTS] + [{
|
||||
Foo(123): Foo(456)
|
||||
}])
|
||||
|
@ -359,25 +359,29 @@ class APITest(unittest.TestCase):
|
|||
def custom_deserializer(serialized_obj):
|
||||
return serialized_obj, "string2"
|
||||
|
||||
ray.register_custom_serializer(Foo, serializer=custom_serializer,
|
||||
ray.register_custom_serializer(
|
||||
Foo,
|
||||
serializer=custom_serializer,
|
||||
deserializer=custom_deserializer)
|
||||
|
||||
self.assertEqual(ray.get(ray.put(Foo())),
|
||||
((3, "string1", Foo.__name__), "string2"))
|
||||
self.assertEqual(
|
||||
ray.get(ray.put(Foo())), ((3, "string1", Foo.__name__), "string2"))
|
||||
|
||||
class Bar(object):
|
||||
def __init__(self):
|
||||
self.x = 3
|
||||
|
||||
ray.register_custom_serializer(Bar, serializer=custom_serializer,
|
||||
ray.register_custom_serializer(
|
||||
Bar,
|
||||
serializer=custom_serializer,
|
||||
deserializer=custom_deserializer)
|
||||
|
||||
@ray.remote
|
||||
def f():
|
||||
return Bar()
|
||||
|
||||
self.assertEqual(ray.get(f.remote()),
|
||||
((3, "string1", Bar.__name__), "string2"))
|
||||
self.assertEqual(
|
||||
ray.get(f.remote()), ((3, "string1", Bar.__name__), "string2"))
|
||||
|
||||
def testRegisterClass(self):
|
||||
self.init_ray(num_workers=2)
|
||||
|
@ -700,10 +704,10 @@ class APITest(unittest.TestCase):
|
|||
assert ray.get(f._submit(args=[1], num_return_vals=1)) == [0]
|
||||
assert ray.get(f._submit(args=[2], num_return_vals=2)) == [0, 1]
|
||||
assert ray.get(f._submit(args=[3], num_return_vals=3)) == [0, 1, 2]
|
||||
assert ray.get(g._submit(args=[],
|
||||
num_cpus=1,
|
||||
num_gpus=1,
|
||||
resources={"Custom": 1})) == [0]
|
||||
assert ray.get(
|
||||
g._submit(
|
||||
args=[], num_cpus=1, num_gpus=1, resources={"Custom":
|
||||
1})) == [0]
|
||||
|
||||
def testGetMultiple(self):
|
||||
self.init_ray()
|
||||
|
@ -1234,8 +1238,8 @@ class ResourcesTest(unittest.TestCase):
|
|||
time.sleep(0.1)
|
||||
gpu_ids = ray.get_gpu_ids()
|
||||
assert len(gpu_ids) == 0
|
||||
assert (os.environ["CUDA_VISIBLE_DEVICES"] ==
|
||||
",".join([str(i) for i in gpu_ids]))
|
||||
assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join(
|
||||
[str(i) for i in gpu_ids]))
|
||||
for gpu_id in gpu_ids:
|
||||
assert gpu_id in range(num_gpus)
|
||||
return gpu_ids
|
||||
|
@ -1245,8 +1249,8 @@ class ResourcesTest(unittest.TestCase):
|
|||
time.sleep(0.1)
|
||||
gpu_ids = ray.get_gpu_ids()
|
||||
assert len(gpu_ids) == 1
|
||||
assert (os.environ["CUDA_VISIBLE_DEVICES"] ==
|
||||
",".join([str(i) for i in gpu_ids]))
|
||||
assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join(
|
||||
[str(i) for i in gpu_ids]))
|
||||
for gpu_id in gpu_ids:
|
||||
assert gpu_id in range(num_gpus)
|
||||
return gpu_ids
|
||||
|
@ -1256,8 +1260,8 @@ class ResourcesTest(unittest.TestCase):
|
|||
time.sleep(0.1)
|
||||
gpu_ids = ray.get_gpu_ids()
|
||||
assert len(gpu_ids) == 2
|
||||
assert (os.environ["CUDA_VISIBLE_DEVICES"] ==
|
||||
",".join([str(i) for i in gpu_ids]))
|
||||
assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join(
|
||||
[str(i) for i in gpu_ids]))
|
||||
for gpu_id in gpu_ids:
|
||||
assert gpu_id in range(num_gpus)
|
||||
return gpu_ids
|
||||
|
@ -1267,8 +1271,8 @@ class ResourcesTest(unittest.TestCase):
|
|||
time.sleep(0.1)
|
||||
gpu_ids = ray.get_gpu_ids()
|
||||
assert len(gpu_ids) == 3
|
||||
assert (os.environ["CUDA_VISIBLE_DEVICES"] ==
|
||||
",".join([str(i) for i in gpu_ids]))
|
||||
assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join(
|
||||
[str(i) for i in gpu_ids]))
|
||||
for gpu_id in gpu_ids:
|
||||
assert gpu_id in range(num_gpus)
|
||||
return gpu_ids
|
||||
|
@ -1278,8 +1282,8 @@ class ResourcesTest(unittest.TestCase):
|
|||
time.sleep(0.1)
|
||||
gpu_ids = ray.get_gpu_ids()
|
||||
assert len(gpu_ids) == 4
|
||||
assert (os.environ["CUDA_VISIBLE_DEVICES"] ==
|
||||
",".join([str(i) for i in gpu_ids]))
|
||||
assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join(
|
||||
[str(i) for i in gpu_ids]))
|
||||
for gpu_id in gpu_ids:
|
||||
assert gpu_id in range(num_gpus)
|
||||
return gpu_ids
|
||||
|
@ -1289,8 +1293,8 @@ class ResourcesTest(unittest.TestCase):
|
|||
time.sleep(0.1)
|
||||
gpu_ids = ray.get_gpu_ids()
|
||||
assert len(gpu_ids) == 5
|
||||
assert (os.environ["CUDA_VISIBLE_DEVICES"] ==
|
||||
",".join([str(i) for i in gpu_ids]))
|
||||
assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join(
|
||||
[str(i) for i in gpu_ids]))
|
||||
for gpu_id in gpu_ids:
|
||||
assert gpu_id in range(num_gpus)
|
||||
return gpu_ids
|
||||
|
@ -1342,16 +1346,16 @@ class ResourcesTest(unittest.TestCase):
|
|||
def __init__(self):
|
||||
gpu_ids = ray.get_gpu_ids()
|
||||
assert len(gpu_ids) == 0
|
||||
assert (os.environ["CUDA_VISIBLE_DEVICES"] ==
|
||||
",".join([str(i) for i in gpu_ids]))
|
||||
assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join(
|
||||
[str(i) for i in gpu_ids]))
|
||||
# Set self.x to make sure that we got here.
|
||||
self.x = 1
|
||||
|
||||
def test(self):
|
||||
gpu_ids = ray.get_gpu_ids()
|
||||
assert len(gpu_ids) == 0
|
||||
assert (os.environ["CUDA_VISIBLE_DEVICES"] ==
|
||||
",".join([str(i) for i in gpu_ids]))
|
||||
assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join(
|
||||
[str(i) for i in gpu_ids]))
|
||||
return self.x
|
||||
|
||||
@ray.remote(num_gpus=1)
|
||||
|
@ -1359,16 +1363,16 @@ class ResourcesTest(unittest.TestCase):
|
|||
def __init__(self):
|
||||
gpu_ids = ray.get_gpu_ids()
|
||||
assert len(gpu_ids) == 1
|
||||
assert (os.environ["CUDA_VISIBLE_DEVICES"] ==
|
||||
",".join([str(i) for i in gpu_ids]))
|
||||
assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join(
|
||||
[str(i) for i in gpu_ids]))
|
||||
# Set self.x to make sure that we got here.
|
||||
self.x = 1
|
||||
|
||||
def test(self):
|
||||
gpu_ids = ray.get_gpu_ids()
|
||||
assert len(gpu_ids) == 1
|
||||
assert (os.environ["CUDA_VISIBLE_DEVICES"] ==
|
||||
",".join([str(i) for i in gpu_ids]))
|
||||
assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join(
|
||||
[str(i) for i in gpu_ids]))
|
||||
return self.x
|
||||
|
||||
a0 = Actor0.remote()
|
||||
|
@ -1379,9 +1383,7 @@ class ResourcesTest(unittest.TestCase):
|
|||
|
||||
def testZeroCPUs(self):
|
||||
ray.worker._init(
|
||||
start_ray_local=True,
|
||||
num_local_schedulers=2,
|
||||
num_cpus=[0, 2])
|
||||
start_ray_local=True, num_local_schedulers=2, num_cpus=[0, 2])
|
||||
|
||||
local_plasma = ray.worker.global_worker.plasma_client.store_socket_name
|
||||
|
||||
|
@ -1484,9 +1486,9 @@ class ResourcesTest(unittest.TestCase):
|
|||
elif name == "run_on_2":
|
||||
self.assertIn(result, [store_names[2]])
|
||||
elif name == "run_on_0_1_2":
|
||||
self.assertIn(result, [
|
||||
store_names[0], store_names[1], store_names[2]
|
||||
])
|
||||
self.assertIn(
|
||||
result,
|
||||
[store_names[0], store_names[1], store_names[2]])
|
||||
elif name == "run_on_1_2":
|
||||
self.assertIn(result, [store_names[1], store_names[2]])
|
||||
elif name == "run_on_0_2":
|
||||
|
@ -1518,7 +1520,11 @@ class ResourcesTest(unittest.TestCase):
|
|||
start_ray_local=True,
|
||||
num_local_schedulers=2,
|
||||
num_cpus=[3, 3],
|
||||
resources=[{"CustomResource": 0}, {"CustomResource": 1}])
|
||||
resources=[{
|
||||
"CustomResource": 0
|
||||
}, {
|
||||
"CustomResource": 1
|
||||
}])
|
||||
|
||||
@ray.remote
|
||||
def f():
|
||||
|
@ -1554,8 +1560,13 @@ class ResourcesTest(unittest.TestCase):
|
|||
start_ray_local=True,
|
||||
num_local_schedulers=2,
|
||||
num_cpus=[3, 3],
|
||||
resources=[{"CustomResource1": 1, "CustomResource2": 2},
|
||||
{"CustomResource1": 3, "CustomResource2": 4}])
|
||||
resources=[{
|
||||
"CustomResource1": 1,
|
||||
"CustomResource2": 2
|
||||
}, {
|
||||
"CustomResource1": 3,
|
||||
"CustomResource2": 4
|
||||
}])
|
||||
|
||||
@ray.remote(resources={"CustomResource1": 1})
|
||||
def f():
|
||||
|
@ -1595,14 +1606,16 @@ class ResourcesTest(unittest.TestCase):
|
|||
|
||||
# Make sure that tasks with unsatisfied custom resource requirements do
|
||||
# not get scheduled.
|
||||
ready_ids, remaining_ids = ray.wait([j.remote(), k.remote()],
|
||||
timeout=500)
|
||||
ready_ids, remaining_ids = ray.wait(
|
||||
[j.remote(), k.remote()], timeout=500)
|
||||
self.assertEqual(ready_ids, [])
|
||||
|
||||
def testManyCustomResources(self):
|
||||
num_custom_resources = 10000
|
||||
total_resources = {str(i): np.random.randint(1, 7)
|
||||
for i in range(num_custom_resources)}
|
||||
total_resources = {
|
||||
str(i): np.random.randint(1, 7)
|
||||
for i in range(num_custom_resources)
|
||||
}
|
||||
ray.init(num_cpus=5, resources=total_resources)
|
||||
|
||||
def f():
|
||||
|
@ -1613,8 +1626,10 @@ class ResourcesTest(unittest.TestCase):
|
|||
num_resources = np.random.randint(0, num_custom_resources + 1)
|
||||
permuted_resources = np.random.permutation(
|
||||
num_custom_resources)[:num_resources]
|
||||
random_resources = {str(i): total_resources[str(i)]
|
||||
for i in permuted_resources}
|
||||
random_resources = {
|
||||
str(i): total_resources[str(i)]
|
||||
for i in permuted_resources
|
||||
}
|
||||
remote_function = ray.remote(resources=random_resources)(f)
|
||||
remote_functions.append(remote_function)
|
||||
|
||||
|
@ -1634,8 +1649,7 @@ class CudaVisibleDevicesTest(unittest.TestCase):
|
|||
def setUp(self):
|
||||
# Record the curent value of this environment variable so that we can
|
||||
# reset it after the test.
|
||||
self.original_gpu_ids = os.environ.get(
|
||||
"CUDA_VISIBLE_DEVICES", None)
|
||||
self.original_gpu_ids = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
||||
|
||||
def tearDown(self):
|
||||
ray.worker.cleanup()
|
||||
|
@ -2095,9 +2109,9 @@ class GlobalStateAPI(unittest.TestCase):
|
|||
for object_info in object_table.values():
|
||||
if len(object_info) != 5:
|
||||
tables_ready = False
|
||||
if (object_info["ManagerIDs"] is None or
|
||||
object_info["DataSize"] == -1 or
|
||||
object_info["Hash"] == ""):
|
||||
if (object_info["ManagerIDs"] is None
|
||||
or object_info["DataSize"] == -1
|
||||
or object_info["Hash"] == ""):
|
||||
tables_ready = False
|
||||
|
||||
if len(task_table) != 10 + 1:
|
||||
|
|
|
@ -10,12 +10,13 @@ import time
|
|||
|
||||
|
||||
class TaskTests(unittest.TestCase):
|
||||
|
||||
def testSubmittingTasks(self):
|
||||
for num_local_schedulers in [1, 4]:
|
||||
for num_workers_per_scheduler in [4]:
|
||||
num_workers = num_local_schedulers * num_workers_per_scheduler
|
||||
ray.worker._init(start_ray_local=True, num_workers=num_workers,
|
||||
ray.worker._init(
|
||||
start_ray_local=True,
|
||||
num_workers=num_workers,
|
||||
num_local_schedulers=num_local_schedulers,
|
||||
num_cpus=100)
|
||||
|
||||
|
@ -42,7 +43,9 @@ class TaskTests(unittest.TestCase):
|
|||
for num_local_schedulers in [1, 4]:
|
||||
for num_workers_per_scheduler in [4]:
|
||||
num_workers = num_local_schedulers * num_workers_per_scheduler
|
||||
ray.worker._init(start_ray_local=True, num_workers=num_workers,
|
||||
ray.worker._init(
|
||||
start_ray_local=True,
|
||||
num_workers=num_workers,
|
||||
num_local_schedulers=num_local_schedulers,
|
||||
num_cpus=100)
|
||||
|
||||
|
@ -89,7 +92,7 @@ class TaskTests(unittest.TestCase):
|
|||
ray.init(num_workers=1)
|
||||
|
||||
for n in range(8):
|
||||
x = np.zeros(10 ** n)
|
||||
x = np.zeros(10**n)
|
||||
|
||||
for _ in range(100):
|
||||
ray.put(x)
|
||||
|
@ -108,7 +111,7 @@ class TaskTests(unittest.TestCase):
|
|||
def f():
|
||||
return 1
|
||||
|
||||
n = 10 ** 4 # TODO(pcm): replace by 10 ** 5 once this is faster.
|
||||
n = 10**4 # TODO(pcm): replace by 10 ** 5 once this is faster.
|
||||
lst = ray.get([f.remote() for _ in range(n)])
|
||||
self.assertEqual(lst, n * [1])
|
||||
|
||||
|
@ -119,7 +122,9 @@ class TaskTests(unittest.TestCase):
|
|||
for num_local_schedulers in [1, 4]:
|
||||
for num_workers_per_scheduler in [4]:
|
||||
num_workers = num_local_schedulers * num_workers_per_scheduler
|
||||
ray.worker._init(start_ray_local=True, num_workers=num_workers,
|
||||
ray.worker._init(
|
||||
start_ray_local=True,
|
||||
num_workers=num_workers,
|
||||
num_local_schedulers=num_local_schedulers,
|
||||
num_cpus=100)
|
||||
|
||||
|
@ -138,8 +143,10 @@ class TaskTests(unittest.TestCase):
|
|||
time.sleep(x)
|
||||
|
||||
for i in range(1, 5):
|
||||
x_ids = [g.remote(np.random.uniform(0, i))
|
||||
for _ in range(2 * num_workers)]
|
||||
x_ids = [
|
||||
g.remote(np.random.uniform(0, i))
|
||||
for _ in range(2 * num_workers)
|
||||
]
|
||||
ray.wait(x_ids, num_returns=len(x_ids))
|
||||
|
||||
self.assertTrue(ray.services.all_processes_alive())
|
||||
|
@ -159,18 +166,20 @@ class ReconstructionTests(unittest.TestCase):
|
|||
time.sleep(0.1)
|
||||
|
||||
# Start the Plasma store instances with a total of 1GB memory.
|
||||
self.plasma_store_memory = 10 ** 9
|
||||
self.plasma_store_memory = 10**9
|
||||
plasma_addresses = []
|
||||
objstore_memory = (self.plasma_store_memory //
|
||||
self.num_local_schedulers)
|
||||
objstore_memory = (
|
||||
self.plasma_store_memory // self.num_local_schedulers)
|
||||
for i in range(self.num_local_schedulers):
|
||||
store_stdout_file, store_stderr_file = ray.services.new_log_files(
|
||||
"plasma_store_{}".format(i), True)
|
||||
manager_stdout_file, manager_stderr_file = (
|
||||
ray.services.new_log_files("plasma_manager_{}"
|
||||
.format(i), True))
|
||||
plasma_addresses.append(ray.services.start_objstore(
|
||||
node_ip_address, redis_address,
|
||||
ray.services.new_log_files("plasma_manager_{}".format(i),
|
||||
True))
|
||||
plasma_addresses.append(
|
||||
ray.services.start_objstore(
|
||||
node_ip_address,
|
||||
redis_address,
|
||||
objstore_memory=objstore_memory,
|
||||
store_stdout_file=store_stdout_file,
|
||||
store_stderr_file=store_stderr_file,
|
||||
|
@ -178,10 +187,14 @@ class ReconstructionTests(unittest.TestCase):
|
|||
manager_stderr_file=manager_stderr_file))
|
||||
|
||||
# Start the rest of the services in the Ray cluster.
|
||||
address_info = {"redis_address": redis_address,
|
||||
address_info = {
|
||||
"redis_address": redis_address,
|
||||
"redis_shards": redis_shards,
|
||||
"object_store_addresses": plasma_addresses}
|
||||
ray.worker._init(address_info=address_info, start_ray_local=True,
|
||||
"object_store_addresses": plasma_addresses
|
||||
}
|
||||
ray.worker._init(
|
||||
address_info=address_info,
|
||||
start_ray_local=True,
|
||||
num_workers=1,
|
||||
num_local_schedulers=self.num_local_schedulers,
|
||||
num_cpus=[1] * self.num_local_schedulers,
|
||||
|
@ -197,8 +210,8 @@ class ReconstructionTests(unittest.TestCase):
|
|||
state._initialize_global_state(self.redis_ip_address, self.redis_port)
|
||||
if os.environ.get('RAY_USE_NEW_GCS', False):
|
||||
tasks = state.task_table()
|
||||
local_scheduler_ids = set(task["LocalSchedulerID"] for task in
|
||||
tasks.values())
|
||||
local_scheduler_ids = set(
|
||||
task["LocalSchedulerID"] for task in tasks.values())
|
||||
|
||||
# Make sure that all nodes in the cluster were used by checking that
|
||||
# the set of local scheduler IDs that had a task scheduled or submitted
|
||||
|
@ -208,8 +221,8 @@ class ReconstructionTests(unittest.TestCase):
|
|||
# with the driver task, since it is not scheduled by a particular local
|
||||
# scheduler.
|
||||
if os.environ.get('RAY_USE_NEW_GCS', False):
|
||||
self.assertEqual(len(local_scheduler_ids),
|
||||
self.num_local_schedulers + 1)
|
||||
self.assertEqual(
|
||||
len(local_scheduler_ids), self.num_local_schedulers + 1)
|
||||
|
||||
# Clean up the Ray cluster.
|
||||
ray.worker.cleanup()
|
||||
|
@ -254,8 +267,7 @@ class ReconstructionTests(unittest.TestCase):
|
|||
del values
|
||||
|
||||
@unittest.skipIf(
|
||||
os.environ.get('RAY_USE_NEW_GCS', False),
|
||||
"Failing with new GCS API.")
|
||||
os.environ.get('RAY_USE_NEW_GCS', False), "Failing with new GCS API.")
|
||||
def testRecursive(self):
|
||||
# Define the size of one task's return argument so that the combined
|
||||
# sum of all objects' sizes is at least twice the plasma stores'
|
||||
|
@ -308,8 +320,7 @@ class ReconstructionTests(unittest.TestCase):
|
|||
del values
|
||||
|
||||
@unittest.skipIf(
|
||||
os.environ.get('RAY_USE_NEW_GCS', False),
|
||||
"Failing with new GCS API.")
|
||||
os.environ.get('RAY_USE_NEW_GCS', False), "Failing with new GCS API.")
|
||||
def testMultipleRecursive(self):
|
||||
# Define the size of one task's return argument so that the combined
|
||||
# sum of all objects' sizes is at least twice the plasma stores'
|
||||
|
@ -375,8 +386,7 @@ class ReconstructionTests(unittest.TestCase):
|
|||
return errors
|
||||
|
||||
@unittest.skipIf(
|
||||
os.environ.get('RAY_USE_NEW_GCS', False),
|
||||
"Hanging with new GCS API.")
|
||||
os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
|
||||
def testNondeterministicTask(self):
|
||||
# Define the size of one task's return argument so that the combined
|
||||
# sum of all objects' sizes is at least twice the plasma stores'
|
||||
|
@ -432,14 +442,14 @@ class ReconstructionTests(unittest.TestCase):
|
|||
# reexecuted.
|
||||
min_errors = 1
|
||||
return len(errors) >= min_errors
|
||||
|
||||
errors = self.wait_for_errors(error_check)
|
||||
# Make sure all the errors have the correct type.
|
||||
self.assertTrue(all(error[b"type"] == b"object_hash_mismatch"
|
||||
for error in errors))
|
||||
self.assertTrue(
|
||||
all(error[b"type"] == b"object_hash_mismatch" for error in errors))
|
||||
|
||||
@unittest.skipIf(
|
||||
os.environ.get('RAY_USE_NEW_GCS', False),
|
||||
"Hanging with new GCS API.")
|
||||
os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
|
||||
def testDriverPutErrors(self):
|
||||
# Define the size of one task's return argument so that the combined
|
||||
# sum of all objects' sizes is at least twice the plasma stores'
|
||||
|
@ -479,9 +489,10 @@ class ReconstructionTests(unittest.TestCase):
|
|||
|
||||
def error_check(errors):
|
||||
return len(errors) > 1
|
||||
|
||||
errors = self.wait_for_errors(error_check)
|
||||
self.assertTrue(all(error[b"type"] == b"put_reconstruction"
|
||||
for error in errors))
|
||||
self.assertTrue(
|
||||
all(error[b"type"] == b"put_reconstruction" for error in errors))
|
||||
|
||||
|
||||
class ReconstructionTestsMultinode(ReconstructionTests):
|
||||
|
@ -490,6 +501,7 @@ class ReconstructionTestsMultinode(ReconstructionTests):
|
|||
# one worker each.
|
||||
num_local_schedulers = 4
|
||||
|
||||
|
||||
# NOTE(swang): This test tries to launch 1000 workers and breaks.
|
||||
# class WorkerPoolTests(unittest.TestCase):
|
||||
#
|
||||
|
@ -512,6 +524,5 @@ class ReconstructionTestsMultinode(ReconstructionTests):
|
|||
# ray.get([g.remote(i) for i in range(1000)])
|
||||
# ray.worker.cleanup()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
|
|
@ -23,7 +23,6 @@ def make_linear_network(w_name=None, b_name=None):
|
|||
|
||||
|
||||
class LossActor(object):
|
||||
|
||||
def __init__(self, use_loss=True):
|
||||
# Uses a separate graph for each network.
|
||||
with tf.Graph().as_default():
|
||||
|
@ -32,10 +31,8 @@ class LossActor(object):
|
|||
loss, init, _, _ = make_linear_network()
|
||||
sess = tf.Session()
|
||||
# Additional code for setting and getting the weights.
|
||||
weights = ray.experimental.TensorFlowVariables(loss if use_loss
|
||||
else None,
|
||||
sess,
|
||||
input_variables=var)
|
||||
weights = ray.experimental.TensorFlowVariables(
|
||||
loss if use_loss else None, sess, input_variables=var)
|
||||
# Return all of the data needed to use the network.
|
||||
self.values = [weights, init, sess]
|
||||
sess.run(init)
|
||||
|
@ -49,7 +46,6 @@ class LossActor(object):
|
|||
|
||||
|
||||
class NetActor(object):
|
||||
|
||||
def __init__(self):
|
||||
# Uses a separate graph for each network.
|
||||
with tf.Graph().as_default():
|
||||
|
@ -71,7 +67,6 @@ class NetActor(object):
|
|||
|
||||
|
||||
class TrainActor(object):
|
||||
|
||||
def __init__(self):
|
||||
# Almost the same as above, but now returns the placeholders and
|
||||
# gradient.
|
||||
|
@ -82,16 +77,17 @@ class TrainActor(object):
|
|||
optimizer = tf.train.GradientDescentOptimizer(0.9)
|
||||
grads = optimizer.compute_gradients(loss)
|
||||
train = optimizer.apply_gradients(grads)
|
||||
self.values = [loss, variables, init, sess, grads, train,
|
||||
[x_data, y_data]]
|
||||
self.values = [
|
||||
loss, variables, init, sess, grads, train, [x_data, y_data]
|
||||
]
|
||||
sess.run(init)
|
||||
|
||||
def training_step(self, weights):
|
||||
_, variables, _, sess, grads, _, placeholders = self.values
|
||||
variables.set_weights(weights)
|
||||
return sess.run([grad[0] for grad in grads],
|
||||
feed_dict=dict(zip(placeholders,
|
||||
[[1] * 100, [2] * 100])))
|
||||
return sess.run(
|
||||
[grad[0] for grad in grads],
|
||||
feed_dict=dict(zip(placeholders, [[1] * 100, [2] * 100])))
|
||||
|
||||
def get_weights(self):
|
||||
return self.values[1].get_weights()
|
||||
|
@ -216,8 +212,8 @@ class TensorFlowTest(unittest.TestCase):
|
|||
net2 = ray.remote(NetActor).remote()
|
||||
weights2 = ray.get(net2.get_weights.remote())
|
||||
|
||||
new_weights2 = ray.get(net2.set_and_get_weights.remote(
|
||||
net2.get_weights.remote()))
|
||||
new_weights2 = ray.get(
|
||||
net2.set_and_get_weights.remote(net2.get_weights.remote()))
|
||||
self.assertEqual(weights2, new_weights2)
|
||||
|
||||
def testVariablesControlDependencies(self):
|
||||
|
@ -247,22 +243,26 @@ class TensorFlowTest(unittest.TestCase):
|
|||
net_values = TrainActor().values
|
||||
loss, variables, _, sess, grads, train, placeholders = net_values
|
||||
|
||||
before_acc = sess.run(loss, feed_dict=dict(zip(placeholders,
|
||||
[[2] * 100,
|
||||
[4] * 100])))
|
||||
before_acc = sess.run(
|
||||
loss, feed_dict=dict(zip(placeholders, [[2] * 100, [4] * 100])))
|
||||
|
||||
for _ in range(3):
|
||||
gradients_list = ray.get(
|
||||
[net.training_step.remote(variables.get_weights())
|
||||
for _ in range(2)])
|
||||
mean_grads = [sum([gradients[i] for gradients in gradients_list]) /
|
||||
len(gradients_list) for i
|
||||
in range(len(gradients_list[0]))]
|
||||
feed_dict = {grad[0]: mean_grad for (grad, mean_grad)
|
||||
in zip(grads, mean_grads)}
|
||||
gradients_list = ray.get([
|
||||
net.training_step.remote(variables.get_weights())
|
||||
for _ in range(2)
|
||||
])
|
||||
mean_grads = [
|
||||
sum([gradients[i]
|
||||
for gradients in gradients_list]) / len(gradients_list)
|
||||
for i in range(len(gradients_list[0]))
|
||||
]
|
||||
feed_dict = {
|
||||
grad[0]: mean_grad
|
||||
for (grad, mean_grad) in zip(grads, mean_grads)
|
||||
}
|
||||
sess.run(train, feed_dict=feed_dict)
|
||||
after_acc = sess.run(loss, feed_dict=dict(zip(placeholders,
|
||||
[[2] * 100, [4] * 100])))
|
||||
after_acc = sess.run(
|
||||
loss, feed_dict=dict(zip(placeholders, [[2] * 100, [4] * 100])))
|
||||
self.assertTrue(before_acc < after_acc)
|
||||
|
||||
|
||||
|
|
|
@ -57,12 +57,11 @@ def test_put_api(ray_start):
|
|||
|
||||
# Test putting object IDs.
|
||||
x_id = ray.put(0)
|
||||
for obj in [[x_id], (x_id,), {x_id: x_id}]:
|
||||
for obj in [[x_id], (x_id, ), {x_id: x_id}]:
|
||||
assert ray.get(ray.put(obj)) == obj
|
||||
|
||||
|
||||
def test_actor_api(ray_start):
|
||||
|
||||
@ray.remote
|
||||
class Foo(object):
|
||||
def __init__(self, val):
|
||||
|
|
Loading…
Add table
Reference in a new issue