Make fake node provider thread safe (#20591)

We may have multiple NodeLauncher threads accessing the same node provider, so it should be thread safe (a standalone sketch of the locking pattern follows the diff below).
Authored by Jiajun Yao on 2021-11-19 18:59:38 -08:00, committed by GitHub
parent 12c11894e8
commit 255bdc8fb1

@@ -1,6 +1,7 @@
 import logging
 import os
 import json
+from threading import RLock
 
 import ray
 from ray.autoscaler.node_provider import NodeProvider
@@ -25,6 +26,7 @@ class FakeMultiNodeProvider(NodeProvider):
 
     def __init__(self, provider_config, cluster_name):
         NodeProvider.__init__(self, provider_config, cluster_name)
+        self.lock = RLock()
         if "RAY_FAKE_CLUSTER" not in os.environ:
             raise RuntimeError(
                 "FakeMultiNodeProvider requires ray to be started with "
@@ -47,25 +49,29 @@ class FakeMultiNodeProvider(NodeProvider):
         return base + str(self._next_node_id).zfill(5)
 
     def non_terminated_nodes(self, tag_filters):
-        nodes = []
-        for node_id in self._nodes:
-            tags = self.node_tags(node_id)
-            ok = True
-            for k, v in tag_filters.items():
-                if tags.get(k) != v:
-                    ok = False
-            if ok:
-                nodes.append(node_id)
-        return nodes
+        with self.lock:
+            nodes = []
+            for node_id in self._nodes:
+                tags = self.node_tags(node_id)
+                ok = True
+                for k, v in tag_filters.items():
+                    if tags.get(k) != v:
+                        ok = False
+                if ok:
+                    nodes.append(node_id)
+            return nodes
 
     def is_running(self, node_id):
-        return node_id in self._nodes
+        with self.lock:
+            return node_id in self._nodes
 
     def is_terminated(self, node_id):
-        return node_id not in self._nodes
+        with self.lock:
+            return node_id not in self._nodes
 
     def node_tags(self, node_id):
-        return self._nodes[node_id]["tags"]
+        with self.lock:
+            return self._nodes[node_id]["tags"]
 
     def external_ip(self, node_id):
         return node_id
@@ -77,37 +83,42 @@ class FakeMultiNodeProvider(NodeProvider):
         raise AssertionError("Readonly node provider cannot be updated")
 
     def create_node_with_resources(self, node_config, tags, count, resources):
-        node_type = tags[TAG_RAY_USER_NODE_TYPE]
-        next_id = self._next_hex_node_id()
-        ray_params = ray._private.parameter.RayParams(
-            min_worker_port=0,
-            max_worker_port=0,
-            dashboard_port=None,
-            num_cpus=resources.pop("CPU", 0),
-            num_gpus=resources.pop("GPU", 0),
-            object_store_memory=resources.pop("object_store_memory", None),
-            resources=resources,
-            redis_address="{}:6379".format(
-                ray._private.services.get_node_ip_address()),
-            env_vars={
-                "RAY_OVERRIDE_NODE_ID_FOR_TESTING": next_id,
-                "RAY_OVERRIDE_RESOURCES": json.dumps(resources),
-            })
-        node = ray.node.Node(
-            ray_params, head=False, shutdown_at_exit=False, spawn_reaper=False)
-        self._nodes[next_id] = {
-            "tags": {
-                TAG_RAY_NODE_KIND: NODE_KIND_WORKER,
-                TAG_RAY_USER_NODE_TYPE: node_type,
-                TAG_RAY_NODE_NAME: next_id,
-                TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE,
-            },
-            "node": node
-        }
+        with self.lock:
+            node_type = tags[TAG_RAY_USER_NODE_TYPE]
+            next_id = self._next_hex_node_id()
+            ray_params = ray._private.parameter.RayParams(
+                min_worker_port=0,
+                max_worker_port=0,
+                dashboard_port=None,
+                num_cpus=resources.pop("CPU", 0),
+                num_gpus=resources.pop("GPU", 0),
+                object_store_memory=resources.pop("object_store_memory", None),
+                resources=resources,
+                redis_address="{}:6379".format(
+                    ray._private.services.get_node_ip_address()),
+                env_vars={
+                    "RAY_OVERRIDE_NODE_ID_FOR_TESTING": next_id,
+                    "RAY_OVERRIDE_RESOURCES": json.dumps(resources),
+                })
+            node = ray.node.Node(
+                ray_params,
+                head=False,
+                shutdown_at_exit=False,
+                spawn_reaper=False)
+            self._nodes[next_id] = {
+                "tags": {
+                    TAG_RAY_NODE_KIND: NODE_KIND_WORKER,
+                    TAG_RAY_USER_NODE_TYPE: node_type,
+                    TAG_RAY_NODE_NAME: next_id,
+                    TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE,
+                },
+                "node": node
+            }
 
     def terminate_node(self, node_id):
-        node = self._nodes.pop(node_id)["node"]
-        self._kill_ray_processes(node)
+        with self.lock:
+            node = self._nodes.pop(node_id)["node"]
+            self._kill_ray_processes(node)
 
     def _kill_ray_processes(self, node):
         node.kill_all_processes(check_alive=False, allow_graceful=True)
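
For context, a minimal self-contained sketch of the pattern this commit applies: one reentrant lock serializes every access to a shared node table that several launcher-like threads hit concurrently. An RLock (rather than a plain Lock) matters because a locked method such as non_terminated_nodes calls another locked method (node_tags) on the same thread. Everything below (the FakeProvider class, the launcher helper, the tag names) is a hypothetical illustration of the locking pattern, not Ray's actual code.

    import threading


    class FakeProvider:
        """Toy provider: a shared node table guarded by a single RLock."""

        def __init__(self):
            self.lock = threading.RLock()
            self._nodes = {}
            self._next_id = 0

        def create_node(self, tags):
            with self.lock:
                node_id = "node-{}".format(self._next_id)
                self._next_id += 1
                self._nodes[node_id] = {"tags": tags}
                return node_id

        def node_tags(self, node_id):
            with self.lock:
                return self._nodes[node_id]["tags"]

        def non_terminated_nodes(self, tag_filters):
            # The lock is reentrant, so this locked method can call the
            # locked node_tags() on the same thread without deadlocking.
            with self.lock:
                return [
                    node_id for node_id in self._nodes
                    if all(self.node_tags(node_id).get(k) == v
                           for k, v in tag_filters.items())
                ]

        def terminate_node(self, node_id):
            with self.lock:
                self._nodes.pop(node_id, None)


    def launcher(provider, n):
        # Stand-in for a NodeLauncher thread: create, inspect, terminate.
        for _ in range(n):
            node_id = provider.create_node({"node-type": "worker"})
            provider.non_terminated_nodes({"node-type": "worker"})
            provider.terminate_node(node_id)


    if __name__ == "__main__":
        provider = FakeProvider()
        threads = [
            threading.Thread(target=launcher, args=(provider, 100))
            for _ in range(4)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        assert provider.non_terminated_nodes({}) == []
        print("all nodes created and terminated without races")

Without the lock, the same run can fail intermittently; for example, the dict iteration in non_terminated_nodes can race with a concurrent terminate_node and raise "dictionary changed size during iteration". With the lock, every read and write of the shared table is serialized, which is the guarantee this commit adds to FakeMultiNodeProvider.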