Mirror of https://github.com/vale981/ray, synced 2025-03-06 10:31:39 -05:00
[Autoscaler] Reload config (#10450)

commit 23bbe0f36a (parent 1dd55f4b07)

4 changed files with 97 additions and 60 deletions
@@ -58,20 +58,8 @@ class StandardAutoscaler:
                 process_runner=subprocess,
                 update_interval_s=AUTOSCALER_UPDATE_INTERVAL_S):
        self.config_path = config_path
        self.reload_config(errors_fatal=True)
        self.reset(errors_fatal=True)
        self.load_metrics = load_metrics
        self.provider = get_node_provider(self.config["provider"],
                                          self.config["cluster_name"])

        # Check whether we can enable the resource demand scheduler.
        if "available_node_types" in self.config:
            self.available_node_types = self.config["available_node_types"]
            self.resource_demand_scheduler = ResourceDemandScheduler(
                self.provider, self.available_node_types,
                self.config["max_workers"])
        else:
            self.available_node_types = None
            self.resource_demand_scheduler = None

        self.max_failures = max_failures
        self.max_launch_batch = max_launch_batch

@@ -123,7 +111,7 @@ class StandardAutoscaler:

    def update(self):
        try:
            self.reload_config(errors_fatal=False)
            self.reset(errors_fatal=False)
            self._update()
        except Exception as e:
            logger.exception("StandardAutoscaler: "

@@ -274,7 +262,7 @@ class StandardAutoscaler:
        else:
            return {}

    def reload_config(self, errors_fatal=False):
    def reset(self, errors_fatal=False):
        sync_continuously = False
        if hasattr(self, "config"):
            sync_continuously = self.config.get(

@@ -283,8 +271,6 @@ class StandardAutoscaler:
            with open(self.config_path) as f:
                new_config = yaml.safe_load(f.read())
            validate_config(new_config)
            new_launch_hash = hash_launch_conf(new_config["worker_nodes"],
                                               new_config["auth"])
            (new_runtime_hash,
             new_file_mounts_contents_hash) = hash_runtime_conf(
                 new_config["file_mounts"],

@@ -296,9 +282,21 @@ class StandardAutoscaler:
                generate_file_mounts_contents_hash=sync_continuously,
            )
            self.config = new_config
            self.launch_hash = new_launch_hash
            self.runtime_hash = new_runtime_hash
            self.file_mounts_contents_hash = new_file_mounts_contents_hash

            self.provider = get_node_provider(self.config["provider"],
                                              self.config["cluster_name"])
            # Check whether we can enable the resource demand scheduler.
            if "available_node_types" in self.config:
                self.available_node_types = self.config["available_node_types"]
                self.resource_demand_scheduler = ResourceDemandScheduler(
                    self.provider, self.available_node_types,
                    self.config["max_workers"])
            else:
                self.available_node_types = None
                self.resource_demand_scheduler = None

        except Exception as e:
            if errors_fatal:
                raise e
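
The substance of these hunks: reload_config() is renamed to reset(), update() now calls it on every cycle, and the provider plus the resource demand scheduler are rebuilt inside reset() rather than only in __init__, so edits to the cluster YAML take effect without restarting the autoscaler. Below is a minimal, self-contained sketch of that reload-on-update pattern; the class and attribute names are illustrative, not Ray's actual API.

import yaml


class ConfigReloader:
    """Illustrative sketch: re-read a YAML config on every update cycle."""

    def __init__(self, config_path):
        self.config_path = config_path
        self.config = None
        self.reset(errors_fatal=True)

    def reset(self, errors_fatal=False):
        try:
            with open(self.config_path) as f:
                new_config = yaml.safe_load(f.read())
            # Only swap in the new config once it parsed successfully.
            self.config = new_config
        except Exception:
            if errors_fatal:
                raise
            # During a periodic update, keep running with the previous config.

    def update(self):
        # Pick up any on-disk config changes before acting on them.
        self.reset(errors_fatal=False)
        return self.config

As in the diff, errors_fatal makes a broken config abort startup, while the same failure during a periodic update is tolerated and the previous config stays in effect.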
@@ -338,9 +336,18 @@ class StandardAutoscaler:
                max(self.config["min_workers"], ideal_num_workers))

    def launch_config_ok(self, node_id):
        launch_conf = self.provider.node_tags(node_id).get(
            TAG_RAY_LAUNCH_CONFIG)
        if self.launch_hash != launch_conf:
        node_tags = self.provider.node_tags(node_id)
        tag_launch_conf = node_tags.get(TAG_RAY_LAUNCH_CONFIG)
        node_type = node_tags.get(TAG_RAY_USER_NODE_TYPE)

        launch_config = copy.deepcopy(self.config["worker_nodes"])
        if node_type:
            launch_config.update(
                self.config["available_node_types"][node_type]["node_config"])
        calculated_launch_hash = hash_launch_conf(launch_config,
                                                  self.config["auth"])

        if calculated_launch_hash != tag_launch_conf:
            return False
        return True
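
launch_config_ok() now derives the expected launch hash per node: it reads the node's TAG_RAY_LAUNCH_CONFIG and TAG_RAY_USER_NODE_TYPE tags, overlays that type's node_config on the base worker_nodes config, and compares the recomputed hash against the tag. A rough stand-alone sketch of the idea follows; hash_launch_conf here is a stand-in (SHA-1 over sorted JSON), not Ray's real implementation.

import copy
import hashlib
import json


def hash_launch_conf(node_conf, auth):
    # Stand-in for Ray's hash_launch_conf: hash the launch config together
    # with the auth section so any change invalidates existing nodes.
    payload = json.dumps({"node": node_conf, "auth": auth}, sort_keys=True)
    return hashlib.sha1(payload.encode("utf-8")).hexdigest()


def launch_config_ok(config, node_type, tag_launch_conf):
    # Start from the base worker config and overlay the per-type node_config,
    # mirroring how the autoscaler builds the config it actually launches with.
    launch_config = copy.deepcopy(config["worker_nodes"])
    if node_type:
        launch_config.update(
            config["available_node_types"][node_type]["node_config"])
    return hash_launch_conf(launch_config, config["auth"]) == tag_launch_conf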
@@ -36,7 +36,12 @@ class NodeLauncher(threading.Thread):
        assert node_type, node_type
        worker_filter = {TAG_RAY_NODE_KIND: NODE_KIND_WORKER}
        before = self.provider.non_terminated_nodes(tag_filters=worker_filter)
        launch_hash = hash_launch_conf(config["worker_nodes"], config["auth"])

        launch_config = copy.deepcopy(config["worker_nodes"])
        if node_type:
            launch_config.update(
                config["available_node_types"][node_type]["node_config"])
        launch_hash = hash_launch_conf(launch_config, config["auth"])
        self.log("Launching {} nodes, type {}.".format(count, node_type))
        node_config = copy.deepcopy(config["worker_nodes"])
        node_tags = {

@@ -51,8 +56,7 @@ class NodeLauncher(threading.Thread):
        # TODO(ekl) this logic is duplicated in commands.py (keep in sync)
        if node_type:
            node_tags[TAG_RAY_USER_NODE_TYPE] = node_type
            node_config.update(
                config["available_node_types"][node_type]["node_config"])
            node_config.update(launch_config)
        self.provider.create_node(node_config, node_tags, count)
        after = self.provider.non_terminated_nodes(tag_filters=worker_filter)
        if set(after).issubset(before):
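
On the launch side, NodeLauncher builds the same merged launch_config (base worker_nodes plus the type's node_config), hashes it, and folds it into the node_config actually sent to the provider, so the tags written at creation match what launch_config_ok() later recomputes. A hypothetical sketch of assembling that request; the tag string values are placeholders, and the real constants come from Ray's autoscaler tags module.

import copy

# Placeholder tag keys for illustration only.
TAG_RAY_NODE_KIND = "ray-node-kind"
TAG_RAY_USER_NODE_TYPE = "ray-user-node-type"
TAG_RAY_LAUNCH_CONFIG = "ray-launch-config"
NODE_KIND_WORKER = "worker"


def build_launch_request(config, node_type, hash_launch_conf):
    # Merge the per-type overrides into the base worker config once, then
    # reuse that merged dict both for hashing and for the node config itself.
    launch_config = copy.deepcopy(config["worker_nodes"])
    if node_type:
        launch_config.update(
            config["available_node_types"][node_type]["node_config"])
    launch_hash = hash_launch_conf(launch_config, config["auth"])

    node_config = copy.deepcopy(config["worker_nodes"])
    node_tags = {
        TAG_RAY_NODE_KIND: NODE_KIND_WORKER,
        TAG_RAY_LAUNCH_CONFIG: launch_hash,
    }
    if node_type:
        node_tags[TAG_RAY_USER_NODE_TYPE] = node_type
        node_config.update(launch_config)
    return node_config, node_tags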
@@ -121,24 +121,30 @@ class MockProvider(NodeProvider):
        self.ready_to_create = threading.Event()
        self.ready_to_create.set()
        self.cache_stopped = cache_stopped
        # Many of these functions are called by node_launcher or updater in
        # different threads. This can be treated as a global lock for
        # everything.
        self.lock = threading.Lock()

    def non_terminated_nodes(self, tag_filters):
        if self.throw:
            raise Exception("oops")
        return [
            n.node_id for n in self.mock_nodes.values()
            if n.matches(tag_filters)
            and n.state not in ["stopped", "terminated"]
        ]
        with self.lock:
            if self.throw:
                raise Exception("oops")
            return [
                n.node_id for n in self.mock_nodes.values()
                if n.matches(tag_filters)
                and n.state not in ["stopped", "terminated"]
            ]

    def non_terminated_node_ips(self, tag_filters):
        if self.throw:
            raise Exception("oops")
        return [
            n.internal_ip for n in self.mock_nodes.values()
            if n.matches(tag_filters)
            and n.state not in ["stopped", "terminated"]
        ]
        with self.lock:
            if self.throw:
                raise Exception("oops")
            return [
                n.internal_ip for n in self.mock_nodes.values()
                if n.matches(tag_filters)
                and n.state not in ["stopped", "terminated"]
            ]

    def is_running(self, node_id):
        return self.mock_nodes[node_id].state == "running"

@@ -159,31 +165,34 @@ class MockProvider(NodeProvider):
        self.ready_to_create.wait()
        if self.fail_creates:
            return
        if self.cache_stopped:
            for node in self.mock_nodes.values():
                if node.state == "stopped" and count > 0:
                    count -= 1
                    node.state = "pending"
                    node.tags.update(tags)
        for _ in range(count):
            self.mock_nodes[self.next_id] = MockNode(
                self.next_id, tags.copy(), node_config,
                tags.get(TAG_RAY_USER_NODE_TYPE))
            self.next_id += 1
        with self.lock:
            if self.cache_stopped:
                for node in self.mock_nodes.values():
                    if node.state == "stopped" and count > 0:
                        count -= 1
                        node.state = "pending"
                        node.tags.update(tags)
            for _ in range(count):
                self.mock_nodes[self.next_id] = MockNode(
                    self.next_id, tags.copy(), node_config,
                    tags.get(TAG_RAY_USER_NODE_TYPE))
                self.next_id += 1

    def set_node_tags(self, node_id, tags):
        self.mock_nodes[node_id].tags.update(tags)

    def terminate_node(self, node_id):
        if self.cache_stopped:
            self.mock_nodes[node_id].state = "stopped"
        else:
            self.mock_nodes[node_id].state = "terminated"
        with self.lock:
            if self.cache_stopped:
                self.mock_nodes[node_id].state = "stopped"
            else:
                self.mock_nodes[node_id].state = "terminated"

    def finish_starting_nodes(self):
        for node in self.mock_nodes.values():
            if node.state == "pending":
                node.state = "running"
        with self.lock:
            for node in self.mock_nodes.values():
                if node.state == "pending":
                    node.state = "running"


SMALL_CLUSTER = {
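
The MockProvider changes add a single threading.Lock and wrap each read or mutation of mock_nodes in it, because NodeLauncher and the node updater call into the provider from separate threads. A minimal sketch of that one-big-lock pattern, with illustrative names rather than the test's actual classes:

import threading


class ThreadSafeRegistry:
    """Sketch of the coarse-grained locking used by the mock provider."""

    def __init__(self):
        # One lock serializes all access from launcher/updater threads.
        self.lock = threading.Lock()
        self.nodes = {}
        self.next_id = 0

    def create_nodes(self, count, state="pending"):
        with self.lock:
            for _ in range(count):
                self.nodes[self.next_id] = state
                self.next_id += 1

    def running_ids(self):
        with self.lock:
            return [i for i, s in self.nodes.items() if s == "running"]

    def finish_starting(self):
        with self.lock:
            for i, s in self.nodes.items():
                if s == "pending":
                    self.nodes[i] = "running"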
@@ -219,10 +219,6 @@ class AutoscalingTest(unittest.TestCase):
        return path

    def testGetOrCreateMultiNodeType(self):
        config = MULTI_WORKER_CLUSTER.copy()
        # Commenting out this line causes the test case to fail?!?!
        config["min_workers"] = 0
        config_path = self.write_config(config)
        config_path = self.write_config(MULTI_WORKER_CLUSTER)
        self.provider = MockProvider()
        runner = MockProcessRunner()

@@ -453,6 +449,27 @@ class AutoscalingTest(unittest.TestCase):
        runner.assert_not_has_call(self.provider.mock_nodes[2].internal_ip,
                                   "init_cmd")

    def testUpdateConfig(self):
        config = MULTI_WORKER_CLUSTER.copy()
        config_path = self.write_config(config)
        self.provider = MockProvider()
        runner = MockProcessRunner()
        autoscaler = StandardAutoscaler(
            config_path,
            LoadMetrics(),
            max_failures=0,
            process_runner=runner,
            update_interval_s=0)
        assert len(self.provider.non_terminated_nodes({})) == 0
        autoscaler.update()
        self.waitForNodes(2)
        config["min_workers"] = 0
        config["available_node_types"]["m4.large"]["node_config"][
            "field_changed"] = 1
        config_path = self.write_config(config)
        autoscaler.update()
        self.waitForNodes(0)


if __name__ == "__main__":
    import sys
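
The new testUpdateConfig drives the reload path end to end: scale up to two workers, then write a modified config (min_workers dropped to 0 and a changed node_config field) to the same path and check that the next update() scales back down. Outside the test, the equivalent operation is simply rewriting the YAML in place; a small sketch of that step, with the path handling and key assumed for illustration rather than taken from Ray:

import yaml


def lower_min_workers(config_path, min_workers=0):
    # Edit the cluster YAML in place; an autoscaler that reloads its config
    # on every update() will pick this up within one update interval.
    with open(config_path) as f:
        config = yaml.safe_load(f.read())
    config["min_workers"] = min_workers
    with open(config_path, "w") as f:
        yaml.safe_dump(config, f)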