mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
Make fake node provider thread safe (#20591)
We may have multiple NodeLauncher threads access the same node provider so it should be thread safe.
This commit is contained in:
parent
12c11894e8
commit
255bdc8fb1
1 changed file with 53 additions and 42 deletions
|
@ -1,6 +1,7 @@
|
|||
import logging
|
||||
import os
|
||||
import json
|
||||
from threading import RLock
|
||||
|
||||
import ray
|
||||
from ray.autoscaler.node_provider import NodeProvider
|
||||
|
@ -25,6 +26,7 @@ class FakeMultiNodeProvider(NodeProvider):
|
|||
|
||||
def __init__(self, provider_config, cluster_name):
|
||||
NodeProvider.__init__(self, provider_config, cluster_name)
|
||||
self.lock = RLock()
|
||||
if "RAY_FAKE_CLUSTER" not in os.environ:
|
||||
raise RuntimeError(
|
||||
"FakeMultiNodeProvider requires ray to be started with "
|
||||
|
@ -47,25 +49,29 @@ class FakeMultiNodeProvider(NodeProvider):
|
|||
return base + str(self._next_node_id).zfill(5)
|
||||
|
||||
def non_terminated_nodes(self, tag_filters):
    """Return the ids of all live (non-terminated) nodes whose tags match.

    Args:
        tag_filters: Mapping of tag key -> required value. A node matches
            only if every pair is present in its tags; an empty mapping
            matches all nodes.

    Returns:
        List of matching node ids, in node-registration order.
    """
    # The whole scan runs under the lock so the snapshot is consistent
    # with concurrent create/terminate calls from NodeLauncher threads.
    # node_tags() re-acquires self.lock, which is fine: it is an RLock.
    with self.lock:
        matching = []
        for node_id in self._nodes:
            tags = self.node_tags(node_id)
            if all(tags.get(k) == v for k, v in tag_filters.items()):
                matching.append(node_id)
        return matching
|
||||
|
||||
def is_running(self, node_id):
    """Whether ``node_id`` refers to a node that has not been terminated."""
    with self.lock:
        running = node_id in self._nodes
    return running
|
||||
|
||||
def is_terminated(self, node_id):
    """Whether ``node_id`` is unknown, i.e. terminated or never created."""
    with self.lock:
        terminated = node_id not in self._nodes
        return terminated
|
||||
|
||||
def node_tags(self, node_id):
    """Return the tag dict recorded for ``node_id``.

    Raises KeyError if the node is unknown (e.g. already terminated).
    """
    with self.lock:
        entry = self._nodes[node_id]
        return entry["tags"]
|
||||
|
||||
def external_ip(self, node_id):
    """In the fake provider the node id doubles as the node's address."""
    return node_id
|
||||
|
@ -77,37 +83,42 @@ class FakeMultiNodeProvider(NodeProvider):
|
|||
raise AssertionError("Readonly node provider cannot be updated")
|
||||
|
||||
def create_node_with_resources(self, node_config, tags, count, resources):
    # Launch one fake worker as an in-process ray.node.Node and record it
    # in self._nodes so the other provider methods can find it.
    # NOTE(review): ``node_config`` and ``count`` are not used here — exactly
    # one node is created per call regardless of ``count``; confirm callers
    # invoke this once per node.
    with self.lock:
        node_type = tags[TAG_RAY_USER_NODE_TYPE]
        next_id = self._next_hex_node_id()
        # ``pop`` mutates ``resources`` in place: CPU/GPU/object store memory
        # are promoted to dedicated RayParams fields; whatever remains is
        # passed through (and exported via RAY_OVERRIDE_RESOURCES) as custom
        # resources.
        ray_params = ray._private.parameter.RayParams(
            min_worker_port=0,
            max_worker_port=0,
            dashboard_port=None,
            num_cpus=resources.pop("CPU", 0),
            num_gpus=resources.pop("GPU", 0),
            object_store_memory=resources.pop("object_store_memory", None),
            resources=resources,
            redis_address="{}:6379".format(
                ray._private.services.get_node_ip_address()),
            env_vars={
                "RAY_OVERRIDE_NODE_ID_FOR_TESTING": next_id,
                "RAY_OVERRIDE_RESOURCES": json.dumps(resources),
            })
        # spawn_reaper=False / shutdown_at_exit=False: the provider manages
        # this node's lifetime itself via terminate_node().
        node = ray.node.Node(
            ray_params,
            head=False,
            shutdown_at_exit=False,
            spawn_reaper=False)
        # Bookkeeping read by non_terminated_nodes()/node_tags()/
        # terminate_node().
        self._nodes[next_id] = {
            "tags": {
                TAG_RAY_NODE_KIND: NODE_KIND_WORKER,
                TAG_RAY_USER_NODE_TYPE: node_type,
                TAG_RAY_NODE_NAME: next_id,
                TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE,
            },
            "node": node
        }
|
||||
|
||||
def terminate_node(self, node_id):
    """Forget ``node_id`` and kill the ray processes backing it.

    Raises KeyError if the node is unknown.
    """
    with self.lock:
        entry = self._nodes.pop(node_id)
        self._kill_ray_processes(entry["node"])
|
||||
|
||||
def _kill_ray_processes(self, node):
|
||||
node.kill_all_processes(check_alive=False, allow_graceful=True)
|
||||
|
|
Loading…
Add table
Reference in a new issue