[tune] move force_on_current_node to ml_utils (#20211)
parent 143d23a278
commit 790e22f9ad
10 changed files with 58 additions and 60 deletions
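For user code, the move amounts to an import-path change; the helper's behavior is unchanged. A minimal before/after sketch (the `train_fn` task is a hypothetical stand-in, assuming a running Ray cluster at or after this commit):

import ray

# Old import path, removed by this commit:
# from ray.tune.utils import force_on_current_node
# New import path:
from ray.util.ml_utils.node import force_on_current_node

ray.init()


@ray.remote
def train_fn():
    # Reports the ID of the node the task actually landed on.
    return ray.get_runtime_context().node_id.hex()


# Wrap the remote function so it is scheduled on the driver's node
# (or, under Ray Client, on the node running the client server).
pinned_fn = force_on_current_node(train_fn)
driver_node = ray.get_runtime_context().node_id.hex()
assert ray.get(pinned_fn.remote()) == driver_node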
@@ -362,8 +362,6 @@ Utilities

 .. autofunction:: ray.tune.utils.validate_save_restore

-.. autofunction:: ray.tune.utils.force_on_current_node
-
 .. _tune-ddp-doc:

@@ -226,7 +226,7 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
         # happens on the server. So we wrap `test_best_model` in a Ray task.
         # We have to make sure it gets executed on the same node that
         # ``tune.run`` is called on.
-        from ray.tune.utils.util import force_on_current_node
+        from ray.util.ml_utils.node import force_on_current_node
         remote_fn = force_on_current_node(ray.remote(test_best_model))
         ray.get(remote_fn.remote(best_trial))
     else:
@@ -141,7 +141,7 @@ if __name__ == "__main__":
         # happens on the server. So we wrap `test_best_model` in a Ray task.
         # We have to make sure it gets executed on the same node that
         # ``tune.run`` is called on.
-        from ray.tune.utils.util import force_on_current_node
+        from ray.util.ml_utils.node import force_on_current_node

         remote_fn = force_on_current_node(ray.remote(test_best_model))
         ray.get(remote_fn.remote(analysis))
@@ -297,7 +297,7 @@ if __name__ == "__main__":
         # should be wrapped in a task so it will execute on the server.
         # We have to make sure it gets executed on the same node that
         # ``tune.run`` is called on.
-        from ray.tune.utils import force_on_current_node
+        from ray.util.ml_utils.node import force_on_current_node
         remote_fn = force_on_current_node(
             ray.remote(get_best_model_checkpoint))
         best_bst = ray.get(remote_fn.remote(analysis))
@@ -90,7 +90,7 @@ if __name__ == "__main__":
         # should be wrapped in a task so it will execute on the server.
         # We have to make sure it gets executed on the same node that
         # ``tune.run`` is called on.
-        from ray.tune.utils import force_on_current_node
+        from ray.util.ml_utils.node import force_on_current_node
         remote_fn = force_on_current_node(
             ray.remote(get_best_model_checkpoint))
         best_bst = ray.get(remote_fn.remote(analysis))
@@ -11,6 +11,7 @@ import warnings

 import ray
 from ray.util.annotations import PublicAPI
+from ray.util.ml_utils.node import force_on_current_node
 from ray.util.queue import Queue, Empty

 from ray.tune.analysis import ExperimentAnalysis
@@ -35,7 +36,6 @@ from ray.tune.trial import Trial
 from ray.tune.trial_runner import TrialRunner
 from ray.tune.utils.callback import create_default_callbacks
 from ray.tune.utils.log import Verbosity, has_verbosity, set_verbosity
-from ray.tune.utils import force_on_current_node

 # Must come last to avoid circular imports
 from ray.tune.schedulers import FIFOScheduler, TrialScheduler
@@ -1,15 +1,14 @@
 from ray.tune.utils.util import (
-    deep_update, date_str, find_free_port, flatten_dict, force_on_current_node,
-    get_pinned_object, merge_dicts, pin_in_object_store, unflattened_lookup,
-    UtilMonitor, validate_save_restore, warn_if_slow, diagnose_serialization,
+    deep_update, date_str, find_free_port, flatten_dict, get_pinned_object,
+    merge_dicts, pin_in_object_store, unflattened_lookup, UtilMonitor,
+    validate_save_restore, warn_if_slow, diagnose_serialization,
     detect_checkpoint_function, detect_reporter, detect_config_single,
     wait_for_gpu)

 __all__ = [
     "deep_update", "date_str", "find_free_port", "flatten_dict",
-    "force_on_current_node", "get_pinned_object", "merge_dicts",
-    "pin_in_object_store", "unflattened_lookup", "UtilMonitor",
-    "validate_save_restore", "warn_if_slow", "diagnose_serialization",
-    "detect_checkpoint_function", "detect_reporter", "detect_config_single",
-    "wait_for_gpu"
+    "get_pinned_object", "merge_dicts", "pin_in_object_store",
+    "unflattened_lookup", "UtilMonitor", "validate_save_restore",
+    "warn_if_slow", "diagnose_serialization", "detect_checkpoint_function",
+    "detect_reporter", "detect_config_single", "wait_for_gpu"
 ]
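Because `ray.tune.utils` no longer re-exports the helper, downstream code using the old path fails with an ImportError once it picks up this commit. A hedged compatibility sketch for code that must run on Ray versions from both sides of the move:

# Prefer the new location; fall back to the old one on older Ray releases.
try:
    from ray.util.ml_utils.node import force_on_current_node
except ImportError:
    from ray.tune.utils import force_on_current_node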
@@ -624,50 +624,6 @@ def validate_warmstart(parameter_names: List[str],
             " do not match.")


-def get_current_node_resource_key() -> str:
-    """Get the Ray resource key for current node.
-    It can be used for actor placement.
-
-    If using Ray Client, this will return the resource key for the node that
-    is running the client server.
-
-    Returns:
-        (str) A string of the format node:<CURRENT-NODE-IP-ADDRESS>
-    """
-    current_node_id = ray.get_runtime_context().node_id.hex()
-    for node in ray.nodes():
-        if node["NodeID"] == current_node_id:
-            # Found the node.
-            for key in node["Resources"].keys():
-                if key.startswith("node:"):
-                    return key
-    else:
-        raise ValueError("Cannot found the node dictionary for current node.")
-
-
-def force_on_current_node(task_or_actor=None):
-    """Given a task or actor, place it on the current node.
-
-    If using Ray Client, the current node is the client server node.
-
-    Args:
-        task_or_actor: A Ray remote function or class to place on the
-            current node. If None, returns the options dict to pass to
-            another actor.
-
-    Returns:
-        The provided task or actor, but with options modified to force
-        placement on the current node.
-    """
-    node_resource_key = get_current_node_resource_key()
-    options = {"resources": {node_resource_key: 0.01}}
-
-    if task_or_actor is None:
-        return options
-
-    return task_or_actor.options(**options)


 if __name__ == "__main__":
     ray.init()
     X = pin_in_object_store("hello")
python/ray/util/ml_utils/node.py (new file, 45 lines)
@@ -0,0 +1,45 @@
+import ray
+
+
+def get_current_node_resource_key() -> str:
+    """Get the Ray resource key for current node.
+    It can be used for actor placement.
+
+    If using Ray Client, this will return the resource key for the node that
+    is running the client server.
+
+    Returns:
+        (str) A string of the format node:<CURRENT-NODE-IP-ADDRESS>
+    """
+    current_node_id = ray.get_runtime_context().node_id.hex()
+    for node in ray.nodes():
+        if node["NodeID"] == current_node_id:
+            # Found the node.
+            for key in node["Resources"].keys():
+                if key.startswith("node:"):
+                    return key
+    else:
+        raise ValueError("Cannot found the node dictionary for current node.")
+
+
+def force_on_current_node(task_or_actor=None):
+    """Given a task or actor, place it on the current node.
+
+    If using Ray Client, the current node is the client server node.
+
+    Args:
+        task_or_actor: A Ray remote function or class to place on the
+            current node. If None, returns the options dict to pass to
+            another actor.
+
+    Returns:
+        The provided task or actor, but with options modified to force
+        placement on the current node.
+    """
+    node_resource_key = get_current_node_resource_key()
+    options = {"resources": {node_resource_key: 0.01}}
+
+    if task_or_actor is None:
+        return options
+
+    return task_or_actor.options(**options)
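The mechanism the new module relies on: every Ray node advertises a custom resource named node:<ip>, and requesting a sliver of it (0.01) pins a task or actor to that node without meaningfully consuming its capacity. A short sketch of the options-dict form of the helper (assuming a running cluster; the `Reporter` actor is illustrative):

import ray
from ray.util.ml_utils.node import force_on_current_node

ray.init()


@ray.remote
class Reporter:
    def node_id(self):
        return ray.get_runtime_context().node_id.hex()


# Called with no argument, the helper returns just the options dict,
# e.g. {"resources": {"node:10.0.0.1": 0.01}} (IP is illustrative).
options = force_on_current_node()
reporter = Reporter.options(**options).remote()

driver_node = ray.get_runtime_context().node_id.hex()
assert ray.get(reporter.node_id.remote()) == driver_node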
@@ -10,7 +10,7 @@ import torch.nn as nn
 import torchvision.transforms as transforms
 from filelock import FileLock
 from ray import serve, tune
-from ray.tune.utils import force_on_current_node
+from ray.util.ml_utils.node import force_on_current_node
 from ray.util.sgd.torch import TorchTrainer, TrainingOperator
 from ray.util.sgd.torch.resnet import ResNet18
 from ray.util.sgd.utils import override