"""Implementation of the CART algorithm to train decision tree classifiers.""" import numpy as np import ray from sklearn import datasets, metrics import time import tempfile import os import json """Binary tree with decision tree semantics and ASCII visualization.""" class Node: """A decision tree node.""" def __init__(self, gini, num_samples, num_samples_per_class, predicted_class): self.gini = gini self.num_samples = num_samples self.num_samples_per_class = num_samples_per_class self.predicted_class = predicted_class self.feature_index = 0 self.threshold = 0 self.left = None self.right = None def debug(self, feature_names, class_names, show_details): """Print an ASCII visualization of the tree.""" lines, _, _, _ = self._debug_aux( feature_names, class_names, show_details, root=True) for line in lines: print(line) def _debug_aux(self, feature_names, class_names, show_details, root=False): # See https://stackoverflow.com/a/54074933/1143396 for similar code. is_leaf = not self.right if is_leaf: lines = [class_names[self.predicted_class]] else: lines = [ "{} < {:.2f}".format(feature_names[self.feature_index], self.threshold) ] if show_details: lines += [ "gini = {:.2f}".format(self.gini), "samples = {}".format(self.num_samples), str(self.num_samples_per_class), ] width = max(len(line) for line in lines) height = len(lines) if is_leaf: lines = [ "║ {:^{width}} ║".format(line, width=width) for line in lines ] lines.insert(0, "╔" + "═" * (width + 2) + "╗") lines.append("╚" + "═" * (width + 2) + "╝") else: lines = [ "│ {:^{width}} │".format(line, width=width) for line in lines ] lines.insert(0, "┌" + "─" * (width + 2) + "┐") lines.append("└" + "─" * (width + 2) + "┘") lines[-2] = "┤" + lines[-2][1:-1] + "├" width += 4 # for padding if is_leaf: middle = width // 2 lines[0] = lines[0][:middle] + "╧" + lines[0][middle + 1:] return lines, width, height, middle # If not a leaf, must have two children. left, n, p, x = self.left._debug_aux(feature_names, class_names, show_details) right, m, q, y = self.right._debug_aux(feature_names, class_names, show_details) top_lines = [n * " " + line + m * " " for line in lines[:-2]] # fmt: off middle_line = x * " " + "┌" + ( n - x - 1) * "─" + lines[-2] + y * "─" + "┐" + (m - y - 1) * " " bottom_line = x * " " + "│" + ( n - x - 1) * " " + lines[-1] + y * " " + "│" + (m - y - 1) * " " # fmt: on if p < q: left += [n * " "] * (q - p) elif q < p: right += [m * " "] * (p - q) zipped_lines = zip(left, right) lines = (top_lines + [middle_line, bottom_line] + [a + width * " " + b for a, b in zipped_lines]) middle = n + width // 2 if not root: lines[0] = lines[0][:middle] + "┴" + lines[0][middle + 1:] return lines, n + m + width, max(p, q) + 2 + len(top_lines), middle class DecisionTreeClassifier: def __init__(self, max_depth=None, tree_limit=5000, feature_limit=2000): self.max_depth = max_depth self.tree_limit = tree_limit self.feature_limit = feature_limit def fit(self, X, y): """Build decision tree classifier.""" self.n_classes_ = len( set(y)) # classes are assumed to go from 0 to n-1 self.n_features_ = X.shape[1] self.tree_ = self._grow_tree(X, y) def predict(self, X): """Predict class for X.""" return [self._predict(inputs) for inputs in X] def debug(self, feature_names, class_names, show_details=True): """Print ASCII visualization of decision tree.""" self.tree_.debug(feature_names, class_names, show_details) def _gini(self, y): """Compute Gini impurity of a non-empty node. Gini impurity is defined as Σ p(1-p) over all classes, with p the freq class within the node. 

class DecisionTreeClassifier:
    def __init__(self, max_depth=None, tree_limit=5000, feature_limit=2000):
        self.max_depth = max_depth
        self.tree_limit = tree_limit
        self.feature_limit = feature_limit

    def fit(self, X, y):
        """Build decision tree classifier."""
        self.n_classes_ = len(
            set(y))  # classes are assumed to go from 0 to n-1
        self.n_features_ = X.shape[1]
        self.tree_ = self._grow_tree(X, y)

    def predict(self, X):
        """Predict class for X."""
        return [self._predict(inputs) for inputs in X]

    def debug(self, feature_names, class_names, show_details=True):
        """Print ASCII visualization of decision tree."""
        self.tree_.debug(feature_names, class_names, show_details)

    def _gini(self, y):
        """Compute the Gini impurity of a non-empty node.

        Gini impurity is defined as Σ p(1-p) over all classes, with p the
        frequency of a class within the node. Since Σ p = 1, this is
        equivalent to 1 - Σ p^2.
        """
        m = y.size
        return 1.0 - sum(
            (np.sum(y == c) / m)**2 for c in range(self.n_classes_))

    def _best_split(self, X, y):
        return best_split(self, X, y)

    def _grow_tree(self, X, y, depth=0):
        future = grow_tree_remote.remote(self, X, y, depth)
        return ray.get(future)

    def _predict(self, inputs):
        """Predict class for a single sample."""
        node = self.tree_
        while node.left:
            if inputs[node.feature_index] < node.threshold:
                node = node.left
            else:
                node = node.right
        return node.predicted_class


def grow_tree_local(tree, X, y, depth):
    """Build a decision tree by recursively finding the best split."""
    # Population for each class in the current node. The predicted class is
    # the one with the largest population.
    num_samples_per_class = [np.sum(y == i) for i in range(tree.n_classes_)]
    predicted_class = np.argmax(num_samples_per_class)
    node = Node(
        gini=tree._gini(y),
        num_samples=y.size,
        num_samples_per_class=num_samples_per_class,
        predicted_class=predicted_class,
    )

    # Split recursively until maximum depth is reached (max_depth=None means
    # no depth limit).
    if tree.max_depth is None or depth < tree.max_depth:
        idx, thr = tree._best_split(X, y)
        if idx is not None:
            indices_left = X[:, idx] < thr
            X_left, y_left = X[indices_left], y[indices_left]
            X_right, y_right = X[~indices_left], y[~indices_left]
            node.feature_index = idx
            node.threshold = thr
            node.left = grow_tree_local(tree, X_left, y_left, depth + 1)
            node.right = grow_tree_local(tree, X_right, y_right, depth + 1)
    return node


@ray.remote
def grow_tree_remote(tree, X, y, depth=0):
    """Build a decision tree by recursively finding the best split."""
    # Population for each class in the current node. The predicted class is
    # the one with the largest population.
    num_samples_per_class = [np.sum(y == i) for i in range(tree.n_classes_)]
    predicted_class = np.argmax(num_samples_per_class)
    node = Node(
        gini=tree._gini(y),
        num_samples=y.size,
        num_samples_per_class=num_samples_per_class,
        predicted_class=predicted_class,
    )

    # Split recursively until maximum depth is reached (max_depth=None means
    # no depth limit).
    if tree.max_depth is None or depth < tree.max_depth:
        idx, thr = tree._best_split(X, y)
        if idx is not None:
            indices_left = X[:, idx] < thr
            X_left, y_left = X[indices_left], y[indices_left]
            X_right, y_right = X[~indices_left], y[~indices_left]
            node.feature_index = idx
            node.threshold = thr
            if (len(X_left) > tree.tree_limit
                    or len(X_right) > tree.tree_limit):
                # Both child tasks are launched before either result is
                # fetched, so the two subtrees grow in parallel.
                left_future = grow_tree_remote.remote(tree, X_left, y_left,
                                                      depth + 1)
                right_future = grow_tree_remote.remote(tree, X_right, y_right,
                                                       depth + 1)
                node.left = ray.get(left_future)
                node.right = ray.get(right_future)
            else:
                node.left = grow_tree_local(tree, X_left, y_left, depth + 1)
                node.right = grow_tree_local(tree, X_right, y_right,
                                             depth + 1)
    return node

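
# grow_tree_remote/grow_tree_local form a fork-join pattern: child partitions
# larger than tree_limit are grown in parallel as Ray tasks, smaller ones
# in-process, where task overhead would dominate the work. The same idea in
# isolation (a minimal sketch; the names below are illustrative and not part
# of this module):
#
#     @ray.remote
#     def build(data):
#         left, right = partition(data)
#         if len(left) > LIMIT or len(right) > LIMIT:
#             return combine(*ray.get(
#                 [build.remote(left), build.remote(right)]))
#         return combine(build_local(left), build_local(right))
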

def best_split_original(tree, X, y):
    """Find the best split for a node."""
    # Need at least two elements to split a node.
    m = y.size
    if m <= 1:
        return None, None

    # Count of each class in the current node.
    num_parent = [np.sum(y == c) for c in range(tree.n_classes_)]

    # Gini of current node.
    best_gini = 1.0 - sum((n / m)**2 for n in num_parent)
    best_idx, best_thr = None, None

    # Loop through all features.
    for idx in range(tree.n_features_):
        # Sort data along selected feature.
        thresholds, classes = zip(*sorted(zip(X[:, idx], y)))

        # We could actually split the node according to each feature/threshold
        # pair and count the resulting population for each class in the
        # children, but instead we compute them in an iterative fashion,
        # making this for loop linear rather than quadratic.
        num_left = [0] * tree.n_classes_
        num_right = num_parent.copy()
        for i in range(1, m):  # possible split positions
            c = classes[i - 1]
            num_left[c] += 1
            num_right[c] -= 1
            gini_left = 1.0 - sum(
                (num_left[x] / i)**2 for x in range(tree.n_classes_))
            gini_right = 1.0 - sum(
                (num_right[x] / (m - i))**2 for x in range(tree.n_classes_))

            # The Gini impurity of a split is the weighted average of the
            # Gini impurity of the children.
            gini = (i * gini_left + (m - i) * gini_right) / m

            # The following condition makes sure we don't try to split two
            # points with identical values for that feature, as it is
            # impossible (both have to end up on the same side of a split).
            if thresholds[i] == thresholds[i - 1]:
                continue

            if gini < best_gini:
                best_gini = gini
                best_idx = idx
                best_thr = (thresholds[i] + thresholds[i - 1]) / 2  # midpoint

    return best_idx, best_thr


def best_split_for_idx(tree, idx, X, y, num_parent, best_gini):
    """Find the best split for a node along a given feature index."""
    # Sort data along selected feature.
    thresholds, classes = zip(*sorted(zip(X[:, idx], y)))

    # We could actually split the node according to each feature/threshold
    # pair and count the resulting population for each class in the children,
    # but instead we compute them in an iterative fashion, making this for
    # loop linear rather than quadratic.
    m = y.size
    num_left = [0] * tree.n_classes_
    num_right = num_parent.copy()
    best_thr = float("NaN")
    for i in range(1, m):  # possible split positions
        c = classes[i - 1]
        num_left[c] += 1
        num_right[c] -= 1
        gini_left = 1.0 - sum(
            (num_left[x] / i)**2 for x in range(tree.n_classes_))
        gini_right = 1.0 - sum(
            (num_right[x] / (m - i))**2 for x in range(tree.n_classes_))

        # The Gini impurity of a split is the weighted average of the Gini
        # impurity of the children.
        gini = (i * gini_left + (m - i) * gini_right) / m

        # The following condition makes sure we don't try to split two
        # points with identical values for that feature, as it is impossible
        # (both have to end up on the same side of a split).
        if thresholds[i] == thresholds[i - 1]:
            continue

        if gini < best_gini:
            best_gini = gini
            best_thr = (thresholds[i] + thresholds[i - 1]) / 2  # midpoint

    return best_gini, best_thr


@ray.remote
def best_split_for_idx_remote(tree, idx, X, y, num_parent, best_gini):
    return best_split_for_idx(tree, idx, X, y, num_parent, best_gini)

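
# Worked micro-example of the linear scan in best_split_for_idx, with
# illustrative numbers: for y = (0, 0, 1, 1) sorted along the feature, at
# split position i = 2 we have num_left = [2, 0] and num_right = [0, 2], so
#
#     gini_left  = 1 - (2/2)**2 - (0/2)**2 = 0.0
#     gini_right = 1 - (0/2)**2 - (2/2)**2 = 0.0
#     gini       = (2 * 0.0 + 2 * 0.0) / 4 = 0.0
#
# i.e. a perfect split, which any nonzero parent Gini will accept.
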

def best_split(tree, X, y):
    """Find the best split for a node."""
    # Need at least two elements to split a node.
    m = y.size
    if m <= 1:
        return None, None

    # Count of each class in the current node.
    num_parent = [np.sum(y == c) for c in range(tree.n_classes_)]

    # Gini of current node.
    best_gini = 1.0 - sum((n / m)**2 for n in num_parent)

    # Evaluate each feature in parallel with Ray when the node is large
    # enough to amortize the task overhead, otherwise sequentially.
    if m > tree.feature_limit:
        split_futures = [
            best_split_for_idx_remote.remote(tree, i, X, y, num_parent,
                                             best_gini)
            for i in range(tree.n_features_)
        ]
        best_splits = [ray.get(future) for future in split_futures]
    else:
        best_splits = [
            best_split_for_idx(tree, i, X, y, num_parent, best_gini)
            for i in range(tree.n_features_)
        ]
    ginis = np.array([gini for (gini, _) in best_splits])
    best_idx = np.argmin(ginis)
    best_thr = best_splits[best_idx][1]
    # If no feature improved on the parent's Gini, every candidate threshold
    # is still NaN and the node should not be split.
    if np.isnan(best_thr):
        return None, None
    return best_idx, best_thr


@ray.remote
def run_in_cluster():
    dataset = datasets.fetch_covtype(data_home=tempfile.mkdtemp())
    X, y = dataset.data, dataset.target - 1
    training_size = 400000
    max_depth = 10

    clf = DecisionTreeClassifier(max_depth=max_depth)
    start = time.time()
    clf.fit(X[:training_size], y[:training_size])
    end = time.time()

    y_pred = clf.predict(X[training_size:])
    accuracy = metrics.accuracy_score(y[training_size:], y_pred)
    return end - start, accuracy


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--concurrency", type=int, default=1)
    args = parser.parse_args()

    ray.init(address=os.environ["RAY_ADDRESS"])

    futures = []
    for i in range(args.concurrency):
        print(f"concurrent run: {i}")
        futures.append(run_in_cluster.remote())
        # Stagger the concurrent runs.
        time.sleep(10)

    for i, f in enumerate(futures):
        treetime, accuracy = ray.get(f)
        print(f"Tree {i} building took {treetime} seconds")
        print(f"Test Accuracy: {accuracy}")

    with open(os.environ["TEST_OUTPUT_JSON"], "w") as f:
        f.write(json.dumps({"build_time": treetime, "success": 1}))
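
# Example invocation (a sketch; the script filename is illustrative). The
# __main__ block expects RAY_ADDRESS and TEST_OUTPUT_JSON in the environment:
#
#     RAY_ADDRESS=auto TEST_OUTPUT_JSON=/tmp/result.json \
#         python cart_benchmark.py --concurrency 2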