[Autoscaler] Reload config (#10450)

commit 23bbe0f36a (parent 1dd55f4b07)
Author: Alex Wu, committed by GitHub
Date:   2020-09-01 14:37:04 -07:00
4 changed files with 97 additions and 60 deletions


@@ -58,20 +58,8 @@ class StandardAutoscaler:
                  process_runner=subprocess,
                  update_interval_s=AUTOSCALER_UPDATE_INTERVAL_S):
         self.config_path = config_path
-        self.reload_config(errors_fatal=True)
+        self.reset(errors_fatal=True)
         self.load_metrics = load_metrics
-        self.provider = get_node_provider(self.config["provider"],
-                                          self.config["cluster_name"])
-
-        # Check whether we can enable the resource demand scheduler.
-        if "available_node_types" in self.config:
-            self.available_node_types = self.config["available_node_types"]
-            self.resource_demand_scheduler = ResourceDemandScheduler(
-                self.provider, self.available_node_types,
-                self.config["max_workers"])
-        else:
-            self.available_node_types = None
-            self.resource_demand_scheduler = None
 
         self.max_failures = max_failures
         self.max_launch_batch = max_launch_batch
@@ -123,7 +111,7 @@ class StandardAutoscaler:
     def update(self):
         try:
-            self.reload_config(errors_fatal=False)
+            self.reset(errors_fatal=False)
             self._update()
         except Exception as e:
             logger.exception("StandardAutoscaler: "
@@ -274,7 +262,7 @@ class StandardAutoscaler:
         else:
             return {}
 
-    def reload_config(self, errors_fatal=False):
+    def reset(self, errors_fatal=False):
         sync_continuously = False
         if hasattr(self, "config"):
             sync_continuously = self.config.get(
@@ -283,8 +271,6 @@ class StandardAutoscaler:
             with open(self.config_path) as f:
                 new_config = yaml.safe_load(f.read())
             validate_config(new_config)
-            new_launch_hash = hash_launch_conf(new_config["worker_nodes"],
-                                               new_config["auth"])
             (new_runtime_hash,
              new_file_mounts_contents_hash) = hash_runtime_conf(
                  new_config["file_mounts"],
@@ -296,9 +282,21 @@ class StandardAutoscaler:
                 generate_file_mounts_contents_hash=sync_continuously,
             )
             self.config = new_config
-            self.launch_hash = new_launch_hash
             self.runtime_hash = new_runtime_hash
             self.file_mounts_contents_hash = new_file_mounts_contents_hash
+            self.provider = get_node_provider(self.config["provider"],
+                                              self.config["cluster_name"])
+
+            # Check whether we can enable the resource demand scheduler.
+            if "available_node_types" in self.config:
+                self.available_node_types = self.config["available_node_types"]
+                self.resource_demand_scheduler = ResourceDemandScheduler(
+                    self.provider, self.available_node_types,
+                    self.config["max_workers"])
+            else:
+                self.available_node_types = None
+                self.resource_demand_scheduler = None
         except Exception as e:
             if errors_fatal:
                 raise e
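
With this change, reset() re-reads and re-validates the YAML on every autoscaler iteration and rebuilds the provider and resource demand scheduler from the fresh config, so edits to the file take effect without restarting the process. A minimal standalone sketch of that reload step, assuming only PyYAML; load_and_fingerprint and its sha1-over-json fingerprint are illustrative stand-ins, not Ray's hash_launch_conf:

import hashlib
import json

import yaml


def load_and_fingerprint(config_path):
    # Re-read and parse the on-disk config; edits take effect on the
    # next call without restarting the process.
    with open(config_path) as f:
        new_config = yaml.safe_load(f.read())
    # Fingerprint only the launch-relevant sections, so unrelated edits
    # do not mark already-running nodes as outdated.
    launch_sections = {
        "worker_nodes": new_config.get("worker_nodes", {}),
        "auth": new_config.get("auth", {}),
    }
    fingerprint = hashlib.sha1(
        json.dumps(launch_sections, sort_keys=True).encode()).hexdigest()
    return new_config, fingerprint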
@@ -338,9 +336,18 @@ class StandardAutoscaler:
             max(self.config["min_workers"], ideal_num_workers))
 
     def launch_config_ok(self, node_id):
-        launch_conf = self.provider.node_tags(node_id).get(
-            TAG_RAY_LAUNCH_CONFIG)
-        if self.launch_hash != launch_conf:
+        node_tags = self.provider.node_tags(node_id)
+        tag_launch_conf = node_tags.get(TAG_RAY_LAUNCH_CONFIG)
+        node_type = node_tags.get(TAG_RAY_USER_NODE_TYPE)
+
+        launch_config = copy.deepcopy(self.config["worker_nodes"])
+        if node_type:
+            launch_config.update(
+                self.config["available_node_types"][node_type]["node_config"])
+        calculated_launch_hash = hash_launch_conf(launch_config,
+                                                  self.config["auth"])
+
+        if calculated_launch_hash != tag_launch_conf:
             return False
         return True
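
The key behavioral change above: launch_config_ok no longer compares against a single cluster-wide launch hash cached at startup; it recomputes the expected hash per node by overlaying the node type's node_config on the base worker config, so only nodes whose own effective launch config changed are flagged as outdated. A hedged sketch of that merge-then-hash check, where the sha1-over-json hash is a simplified stand-in for Ray's hash_launch_conf:

import copy
import hashlib
import json


def expected_launch_hash(config, node_type):
    # Overlay the node type's overrides on the base worker config,
    # mirroring what launch_config_ok and NodeLauncher now both do.
    launch_config = copy.deepcopy(config["worker_nodes"])
    if node_type:
        launch_config.update(
            config["available_node_types"][node_type]["node_config"])
    payload = {"launch_config": launch_config, "auth": config["auth"]}
    return hashlib.sha1(
        json.dumps(payload, sort_keys=True).encode()).hexdigest()


# Example: a node type with overrides hashes differently from the base
# config, so editing one type's node_config only invalidates that type.
config = {
    "worker_nodes": {"instance": "m4.large"},
    "auth": {"ssh_user": "ubuntu"},
    "available_node_types": {
        "gpu": {"node_config": {"instance": "p3.2xlarge"}},
    },
}
assert expected_launch_hash(config, "gpu") != expected_launch_hash(config, None)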


@@ -36,7 +36,12 @@ class NodeLauncher(threading.Thread):
         assert node_type, node_type
         worker_filter = {TAG_RAY_NODE_KIND: NODE_KIND_WORKER}
         before = self.provider.non_terminated_nodes(tag_filters=worker_filter)
-        launch_hash = hash_launch_conf(config["worker_nodes"], config["auth"])
+        launch_config = copy.deepcopy(config["worker_nodes"])
+        if node_type:
+            launch_config.update(
+                config["available_node_types"][node_type]["node_config"])
+        launch_hash = hash_launch_conf(launch_config, config["auth"])
         self.log("Launching {} nodes, type {}.".format(count, node_type))
         node_config = copy.deepcopy(config["worker_nodes"])
         node_tags = {
@@ -51,8 +56,7 @@ class NodeLauncher(threading.Thread):
         # TODO(ekl) this logic is duplicated in commands.py (keep in sync)
         if node_type:
             node_tags[TAG_RAY_USER_NODE_TYPE] = node_type
-            node_config.update(
-                config["available_node_types"][node_type]["node_config"])
+            node_config.update(launch_config)
         self.provider.create_node(node_config, node_tags, count)
         after = self.provider.non_terminated_nodes(tag_filters=worker_filter)
         if set(after).issubset(before):


@@ -121,24 +121,30 @@ class MockProvider(NodeProvider):
         self.ready_to_create = threading.Event()
         self.ready_to_create.set()
         self.cache_stopped = cache_stopped
+        # Many of these functions are called by node_launcher or updater in
+        # different threads. This can be treated as a global lock for
+        # everything.
+        self.lock = threading.Lock()
 
     def non_terminated_nodes(self, tag_filters):
-        if self.throw:
-            raise Exception("oops")
-        return [
-            n.node_id for n in self.mock_nodes.values()
-            if n.matches(tag_filters)
-            and n.state not in ["stopped", "terminated"]
-        ]
+        with self.lock:
+            if self.throw:
+                raise Exception("oops")
+            return [
+                n.node_id for n in self.mock_nodes.values()
+                if n.matches(tag_filters)
+                and n.state not in ["stopped", "terminated"]
+            ]
 
     def non_terminated_node_ips(self, tag_filters):
-        if self.throw:
-            raise Exception("oops")
-        return [
-            n.internal_ip for n in self.mock_nodes.values()
-            if n.matches(tag_filters)
-            and n.state not in ["stopped", "terminated"]
-        ]
+        with self.lock:
+            if self.throw:
+                raise Exception("oops")
+            return [
+                n.internal_ip for n in self.mock_nodes.values()
+                if n.matches(tag_filters)
+                and n.state not in ["stopped", "terminated"]
+            ]
 
     def is_running(self, node_id):
         return self.mock_nodes[node_id].state == "running"
@@ -159,31 +165,34 @@ class MockProvider(NodeProvider):
         self.ready_to_create.wait()
         if self.fail_creates:
             return
-        if self.cache_stopped:
-            for node in self.mock_nodes.values():
-                if node.state == "stopped" and count > 0:
-                    count -= 1
-                    node.state = "pending"
-                    node.tags.update(tags)
-        for _ in range(count):
-            self.mock_nodes[self.next_id] = MockNode(
-                self.next_id, tags.copy(), node_config,
-                tags.get(TAG_RAY_USER_NODE_TYPE))
-            self.next_id += 1
+        with self.lock:
+            if self.cache_stopped:
+                for node in self.mock_nodes.values():
+                    if node.state == "stopped" and count > 0:
+                        count -= 1
+                        node.state = "pending"
+                        node.tags.update(tags)
+            for _ in range(count):
+                self.mock_nodes[self.next_id] = MockNode(
+                    self.next_id, tags.copy(), node_config,
+                    tags.get(TAG_RAY_USER_NODE_TYPE))
+                self.next_id += 1
 
     def set_node_tags(self, node_id, tags):
         self.mock_nodes[node_id].tags.update(tags)
 
     def terminate_node(self, node_id):
-        if self.cache_stopped:
-            self.mock_nodes[node_id].state = "stopped"
-        else:
-            self.mock_nodes[node_id].state = "terminated"
+        with self.lock:
+            if self.cache_stopped:
+                self.mock_nodes[node_id].state = "stopped"
+            else:
+                self.mock_nodes[node_id].state = "terminated"
 
     def finish_starting_nodes(self):
-        for node in self.mock_nodes.values():
-            if node.state == "pending":
-                node.state = "running"
+        with self.lock:
+            for node in self.mock_nodes.values():
+                if node.state == "pending":
+                    node.state = "running"
 
 SMALL_CLUSTER = {
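
The MockProvider changes are pure thread-safety hardening: every method that reads or mutates mock_nodes now runs under a single coarse lock, since NodeLauncher and NodeUpdater threads hit the provider concurrently. A generic sketch of the same pattern (an illustrative class, not Ray's MockProvider):

import threading


class ThreadSafeRegistry:
    def __init__(self):
        self._nodes = {}
        self._next_id = 0
        # One coarse-grained lock: simple to reason about, and cheap
        # enough for a test double that sees little contention.
        self._lock = threading.Lock()

    def create(self, state="pending"):
        with self._lock:
            node_id = self._next_id
            self._nodes[node_id] = state
            self._next_id += 1
            return node_id

    def running_ids(self):
        with self._lock:
            # Snapshot under the lock so callers never iterate a dict
            # that another thread is resizing mid-flight.
            return [i for i, s in self._nodes.items() if s == "running"]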


@@ -219,10 +219,6 @@ class AutoscalingTest(unittest.TestCase):
         return path
 
     def testGetOrCreateMultiNodeType(self):
-        config = MULTI_WORKER_CLUSTER.copy()
-        # Commenting out this line causes the test case to fail?!?!
-        config["min_workers"] = 0
-        config_path = self.write_config(config)
+        config_path = self.write_config(MULTI_WORKER_CLUSTER)
         self.provider = MockProvider()
         runner = MockProcessRunner()
@@ -453,6 +449,27 @@ class AutoscalingTest(unittest.TestCase):
         runner.assert_not_has_call(self.provider.mock_nodes[2].internal_ip,
                                    "init_cmd")
 
+    def testUpdateConfig(self):
+        config = MULTI_WORKER_CLUSTER.copy()
+        config_path = self.write_config(config)
+        self.provider = MockProvider()
+        runner = MockProcessRunner()
+        autoscaler = StandardAutoscaler(
+            config_path,
+            LoadMetrics(),
+            max_failures=0,
+            process_runner=runner,
+            update_interval_s=0)
+        assert len(self.provider.non_terminated_nodes({})) == 0
+        autoscaler.update()
+        self.waitForNodes(2)
+        config["min_workers"] = 0
+        config["available_node_types"]["m4.large"]["node_config"][
+            "field_changed"] = 1
+        config_path = self.write_config(config)
+        autoscaler.update()
+        self.waitForNodes(0)
+
 if __name__ == "__main__":
     import sys
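
testUpdateConfig pins down the end-to-end behavior this commit adds: rewrite the config file at the same path, call update() on the same autoscaler instance, and the cluster converges to the new targets without the autoscaler being re-created. A self-contained toy (assuming PyYAML; all names here are illustrative, not Ray's) showing the same re-read-on-update mechanic:

import tempfile

import yaml


class TinyController:
    def __init__(self, config_path):
        self.config_path = config_path
        self.reset()

    def reset(self):
        # Re-parse the file every time, as StandardAutoscaler.reset does.
        with open(self.config_path) as f:
            self.config = yaml.safe_load(f.read())

    def update(self):
        self.reset()  # pick up on-disk edits before acting
        return self.config["min_workers"]


with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    yaml.safe_dump({"min_workers": 2}, f)
    path = f.name

controller = TinyController(path)
assert controller.update() == 2

with open(path, "w") as f:
    yaml.safe_dump({"min_workers": 0}, f)
assert controller.update() == 0  # new value seen without re-creation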