[tune] move force_on_current_node to ml_utils (#20211)

matthewdeng 2021-11-10 10:21:24 -08:00 committed by GitHub
parent 143d23a278
commit 790e22f9ad
10 changed files with 58 additions and 60 deletions


@@ -362,8 +362,6 @@ Utilities

 .. autofunction:: ray.tune.utils.validate_save_restore

-.. autofunction:: ray.tune.utils.force_on_current_node
-
 .. _tune-ddp-doc:


@@ -226,7 +226,7 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
         # happens on the server. So we wrap `test_best_model` in a Ray task.
         # We have to make sure it gets executed on the same node that
         # ``tune.run`` is called on.
-        from ray.tune.utils.util import force_on_current_node
+        from ray.util.ml_utils.node import force_on_current_node
         remote_fn = force_on_current_node(ray.remote(test_best_model))
         ray.get(remote_fn.remote(best_trial))
     else:
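
The call sites do not change; only the import path moves out of `ray.tune.utils`. For reference, a minimal self-contained sketch of the Ray Client pattern these examples rely on (the `evaluate_best_model` function, its argument, and the connection address are illustrative placeholders, not part of this commit):

import ray
from ray.util.ml_utils.node import force_on_current_node


def evaluate_best_model(best_trial):
    # Placeholder for the examples' model-evaluation logic.
    print(f"evaluating {best_trial}")


# Hypothetical Ray Client address; any ray://<head-host>:<port> works.
ray.init("ray://localhost:10001")

# Pin the evaluation task to the node running the Ray Client server,
# i.e. the same node that executed ``tune.run``.
remote_fn = force_on_current_node(ray.remote(evaluate_best_model))
ray.get(remote_fn.remote("best_trial_id"))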


@@ -141,7 +141,7 @@ if __name__ == "__main__":
         # happens on the server. So we wrap `test_best_model` in a Ray task.
         # We have to make sure it gets executed on the same node that
         # ``tune.run`` is called on.
-        from ray.tune.utils.util import force_on_current_node
+        from ray.util.ml_utils.node import force_on_current_node
         remote_fn = force_on_current_node(ray.remote(test_best_model))
         ray.get(remote_fn.remote(analysis))


@@ -297,7 +297,7 @@ if __name__ == "__main__":
         # should be wrapped in a task so it will execute on the server.
         # We have to make sure it gets executed on the same node that
         # ``tune.run`` is called on.
-        from ray.tune.utils import force_on_current_node
+        from ray.util.ml_utils.node import force_on_current_node
         remote_fn = force_on_current_node(
             ray.remote(get_best_model_checkpoint))
         best_bst = ray.get(remote_fn.remote(analysis))


@@ -90,7 +90,7 @@ if __name__ == "__main__":
         # should be wrapped in a task so it will execute on the server.
         # We have to make sure it gets executed on the same node that
         # ``tune.run`` is called on.
-        from ray.tune.utils import force_on_current_node
+        from ray.util.ml_utils.node import force_on_current_node
         remote_fn = force_on_current_node(
             ray.remote(get_best_model_checkpoint))
         best_bst = ray.get(remote_fn.remote(analysis))


@@ -11,6 +11,7 @@ import warnings

 import ray
 from ray.util.annotations import PublicAPI
+from ray.util.ml_utils.node import force_on_current_node
 from ray.util.queue import Queue, Empty

 from ray.tune.analysis import ExperimentAnalysis
@@ -35,7 +36,6 @@ from ray.tune.trial import Trial
 from ray.tune.trial_runner import TrialRunner
 from ray.tune.utils.callback import create_default_callbacks
 from ray.tune.utils.log import Verbosity, has_verbosity, set_verbosity
-from ray.tune.utils import force_on_current_node

 # Must come last to avoid circular imports
 from ray.tune.schedulers import FIFOScheduler, TrialScheduler
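
`ray.tune.tune` is the main in-tree consumer: when `tune.run` is invoked through Ray Client, it wraps the experiment driver in a remote task pinned to the client server node. Roughly, as a simplified sketch and not the literal `tune.py` code (the hypothetical `_run_driver` helper stands in for the real experiment loop):

import ray
from ray.util.ml_utils.node import force_on_current_node


def run(*args, **kwargs):
    if ray.util.client.ray.is_connected():
        # Re-dispatch: run the experiment driver on the node hosting
        # the Ray Client server instead of on the local client.
        remote_run = force_on_current_node(ray.remote(_run_driver))
        return ray.get(remote_run.remote(*args, **kwargs))
    return _run_driver(*args, **kwargs)


def _run_driver(*args, **kwargs):
    ...  # the actual trial-execution loop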


@@ -1,15 +1,14 @@
 from ray.tune.utils.util import (
-    deep_update, date_str, find_free_port, flatten_dict, force_on_current_node,
-    get_pinned_object, merge_dicts, pin_in_object_store, unflattened_lookup,
-    UtilMonitor, validate_save_restore, warn_if_slow, diagnose_serialization,
+    deep_update, date_str, find_free_port, flatten_dict, get_pinned_object,
+    merge_dicts, pin_in_object_store, unflattened_lookup, UtilMonitor,
+    validate_save_restore, warn_if_slow, diagnose_serialization,
     detect_checkpoint_function, detect_reporter, detect_config_single,
     wait_for_gpu)

 __all__ = [
     "deep_update", "date_str", "find_free_port", "flatten_dict",
-    "force_on_current_node", "get_pinned_object", "merge_dicts",
-    "pin_in_object_store", "unflattened_lookup", "UtilMonitor",
-    "validate_save_restore", "warn_if_slow", "diagnose_serialization",
-    "detect_checkpoint_function", "detect_reporter", "detect_config_single",
-    "wait_for_gpu"
+    "get_pinned_object", "merge_dicts", "pin_in_object_store",
+    "unflattened_lookup", "UtilMonitor", "validate_save_restore",
+    "warn_if_slow", "diagnose_serialization", "detect_checkpoint_function",
+    "detect_reporter", "detect_config_single", "wait_for_gpu"
 ]
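
After this hunk, `from ray.tune.utils import force_on_current_node` raises an ImportError. Downstream code migrates the same way the examples above do:

# Before this commit (no longer works):
# from ray.tune.utils import force_on_current_node

# After this commit:
from ray.util.ml_utils.node import force_on_current_node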


@@ -624,50 +624,6 @@ def validate_warmstart(parameter_names: List[str],
                 " do not match.")


-def get_current_node_resource_key() -> str:
-    """Get the Ray resource key for current node.
-    It can be used for actor placement.
-
-    If using Ray Client, this will return the resource key for the node that
-    is running the client server.
-
-    Returns:
-        (str) A string of the format node:<CURRENT-NODE-IP-ADDRESS>
-    """
-    current_node_id = ray.get_runtime_context().node_id.hex()
-    for node in ray.nodes():
-        if node["NodeID"] == current_node_id:
-            # Found the node.
-            for key in node["Resources"].keys():
-                if key.startswith("node:"):
-                    return key
-    else:
-        raise ValueError("Cannot find the node dictionary for current node.")
-
-
-def force_on_current_node(task_or_actor=None):
-    """Given a task or actor, place it on the current node.
-
-    If using Ray Client, the current node is the client server node.
-
-    Args:
-        task_or_actor: A Ray remote function or class to place on the
-            current node. If None, returns the options dict to pass to
-            another actor.
-
-    Returns:
-        The provided task or actor, but with options modified to force
-        placement on the current node.
-    """
-    node_resource_key = get_current_node_resource_key()
-    options = {"resources": {node_resource_key: 0.01}}
-
-    if task_or_actor is None:
-        return options
-
-    return task_or_actor.options(**options)
-
-
 if __name__ == "__main__":
     ray.init()
     X = pin_in_object_store("hello")


@@ -0,0 +1,45 @@
+import ray
+
+
+def get_current_node_resource_key() -> str:
+    """Get the Ray resource key for current node.
+    It can be used for actor placement.
+
+    If using Ray Client, this will return the resource key for the node that
+    is running the client server.
+
+    Returns:
+        (str) A string of the format node:<CURRENT-NODE-IP-ADDRESS>
+    """
+    current_node_id = ray.get_runtime_context().node_id.hex()
+    for node in ray.nodes():
+        if node["NodeID"] == current_node_id:
+            # Found the node.
+            for key in node["Resources"].keys():
+                if key.startswith("node:"):
+                    return key
+    else:
+        raise ValueError("Cannot find the node dictionary for current node.")
+
+
+def force_on_current_node(task_or_actor=None):
+    """Given a task or actor, place it on the current node.
+
+    If using Ray Client, the current node is the client server node.
+
+    Args:
+        task_or_actor: A Ray remote function or class to place on the
+            current node. If None, returns the options dict to pass to
+            another actor.
+
+    Returns:
+        The provided task or actor, but with options modified to force
+        placement on the current node.
+    """
+    node_resource_key = get_current_node_resource_key()
+    options = {"resources": {node_resource_key: 0.01}}
+
+    if task_or_actor is None:
+        return options
+
+    return task_or_actor.options(**options)
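
The trick behind the helper is that every Ray node advertises a custom resource named node:<its-IP-address>; requesting a tiny fraction of it (0.01) forces placement on that node without meaningfully consuming its capacity. A minimal usage sketch against a running cluster (the `whoami` task and `Counter` actor are illustrative, not part of this commit):

import ray
from ray.util.ml_utils.node import force_on_current_node

ray.init()


@ray.remote
def whoami():
    # Report which node actually ran the task.
    return ray.util.get_node_ip_address()


# Wrap a task: it is now constrained to the current node.
pinned = force_on_current_node(whoami)
print(ray.get(pinned.remote()))


@ray.remote
class Counter:
    def ping(self):
        return "pong"


# Or take the bare options dict (task_or_actor=None) and apply it yourself.
options = force_on_current_node(None)
counter = Counter.options(**options).remote()
print(ray.get(counter.ping.remote()))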


@@ -10,7 +10,7 @@ import torch.nn as nn
 import torchvision.transforms as transforms
 from filelock import FileLock
 from ray import serve, tune
-from ray.tune.utils import force_on_current_node
+from ray.util.ml_utils.node import force_on_current_node
 from ray.util.sgd.torch import TorchTrainer, TrainingOperator
 from ray.util.sgd.torch.resnet import ResNet18
 from ray.util.sgd.utils import override