mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[joblib] Fix flaky joblib test. (#13046)
This commit is contained in:
parent
1e74187179
commit
d37e2c3a20
2 changed files with 7 additions and 13 deletions
BIN
python/ray/tests/mnist_784_100_samples.pkl
Normal file
BIN
python/ray/tests/mnist_784_100_samples.pkl
Normal file
Binary file not shown.
|
@ -1,12 +1,13 @@
|
|||
import joblib
|
||||
import sys
|
||||
import time
|
||||
import os
|
||||
|
||||
import pickle
|
||||
import numpy as np
|
||||
|
||||
from sklearn.datasets import load_digits, load_iris
|
||||
from sklearn.model_selection import RandomizedSearchCV
|
||||
from sklearn.datasets import fetch_openml
|
||||
from sklearn.ensemble import ExtraTreesClassifier
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.kernel_approximation import Nystroem
|
||||
|
@ -14,7 +15,6 @@ from sklearn.kernel_approximation import RBFSampler
|
|||
from sklearn.pipeline import make_pipeline
|
||||
from sklearn.svm import LinearSVC, SVC
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
from sklearn.utils import check_array
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.neural_network import MLPClassifier
|
||||
from sklearn.model_selection import cross_val_score
|
||||
|
@ -112,20 +112,14 @@ def test_sklearn_benchmarks(ray_start_cluster_2_nodes):
|
|||
}
|
||||
# Load dataset.
|
||||
print("Loading dataset...")
|
||||
data = fetch_openml("mnist_784")
|
||||
X = check_array(data["data"], dtype=np.float32, order="C")
|
||||
y = data["target"]
|
||||
|
||||
unnormalized_X_train, y_train = pickle.load(
|
||||
open(
|
||||
os.path.join(
|
||||
os.path.dirname(__file__), "mnist_784_100_samples.pkl"), "rb"))
|
||||
# Normalize features.
|
||||
X = X / 255
|
||||
X_train = unnormalized_X_train / 255
|
||||
|
||||
# Create train-test split.
|
||||
print("Creating train-test split...")
|
||||
n_train = 100
|
||||
X_train = X[:n_train]
|
||||
y_train = y[:n_train]
|
||||
register_ray()
|
||||
|
||||
train_time = {}
|
||||
random_seed = 0
|
||||
# Use two workers per classifier.
|
||||
|
|
Loading…
Add table
Reference in a new issue