mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
Enable test_joblib in CI (#7404)
This commit is contained in:
parent
d69fe54f6d
commit
2b6f00724a
1 changed files with 12 additions and 4 deletions
|
@ -1,8 +1,11 @@
|
|||
import numpy as np
|
||||
import joblib
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from sklearn.datasets import load_digits, load_iris
|
||||
from sklearn.model_selection import RandomizedSearchCV
|
||||
from time import time
|
||||
from sklearn.datasets import fetch_openml
|
||||
from sklearn.ensemble import ExtraTreesClassifier
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
|
@ -142,9 +145,9 @@ def test_sklearn_benchmarks(ray_start_cluster_2_nodes):
|
|||
|
||||
if "n_jobs" in estimator_params:
|
||||
estimator.set_params(n_jobs=num_jobs)
|
||||
time_start = time()
|
||||
time_start = time.time()
|
||||
estimator.fit(X_train, y_train)
|
||||
train_time[name] = time() - time_start
|
||||
train_time[name] = time.time() - time_start
|
||||
print("training", name, "took", train_time[name], "seconds")
|
||||
|
||||
|
||||
|
@ -157,3 +160,8 @@ def test_cross_validation(shutdown_only):
|
|||
assert len(accuracy) == 5
|
||||
for result in accuracy:
|
||||
assert result > 0.95
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
|
|
Loading…
Add table
Reference in a new issue