mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
Support of scikit-learn with ray joblib backend (#6925)
This commit is contained in:
parent
396d7fafc8
commit
a7ecda6017
5 changed files with 311 additions and 0 deletions
69
doc/source/joblib.rst
Normal file
69
doc/source/joblib.rst
Normal file
|
@ -0,0 +1,69 @@
|
|||
sklearn Ray Backend API (Experimental)
|
||||
=======================================
|
||||
|
||||
.. warning::
|
||||
|
||||
Support for running scikit-learn on Ray is an experimental feature,
|
||||
so it may be changed at any time without warning. If you encounter any
|
||||
bugs/shortcomings/incompatibilities, please file an `issue on GitHub`_.
|
||||
Contributions are always welcome!
|
||||
|
||||
.. _`issue on GitHub`: https://github.com/ray-project/ray/issues
|
||||
|
||||
Ray supports running distributed `scikit-learn`_ programs by
|
||||
implementing a Ray backend for `joblib`_ using `Ray Actors <actors.html>`__
|
||||
instead of local processes. This makes it easy to scale existing applications
|
||||
that use scikit-learn from a single node to a cluster.
|
||||
|
||||
.. _`joblib`: https://joblib.readthedocs.io
|
||||
.. _`scikit-learn`: https://scikit-learn.org
|
||||
|
||||
Quickstart
|
||||
----------
|
||||
|
||||
To get started, first `install Ray <installation.html>`__, then use
|
||||
``from ray.experimental.joblib import register_ray`` and run ``register_ray()``.
|
||||
This will register Ray as a joblib backend for scikit-learn to use.
|
||||
Then run your original scikit-learn code inside
|
||||
``with joblib.parallel_backend('ray')``. This will start a local Ray cluster.
|
||||
See the `Run on a Cluster`_ section below for instructions to run on
|
||||
a multi-node Ray cluster instead.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import numpy as np
|
||||
from sklearn.datasets import load_digits
|
||||
from sklearn.model_selection import RandomizedSearchCV
|
||||
from sklearn.svm import SVC
|
||||
digits = load_digits()
|
||||
param_space = {
|
||||
'C': np.logspace(-6, 6, 30),
|
||||
'gamma': np.logspace(-8, 8, 30),
|
||||
'tol': np.logspace(-4, -1, 30),
|
||||
'class_weight': [None, 'balanced'],
|
||||
}
|
||||
model = SVC(kernel='rbf')
|
||||
search = RandomizedSearchCV(model, param_space, cv=5, n_iter=300, verbose=10)
|
||||
|
||||
import joblib
|
||||
from ray.experimental.joblib import register_ray
|
||||
register_ray()
|
||||
with joblib.parallel_backend('ray'):
|
||||
search.fit(digits.data, digits.target)
|
||||
|
||||
Run on a Cluster
|
||||
----------------
|
||||
|
||||
This section assumes that you have a running Ray cluster. To start a Ray cluster,
|
||||
please refer to the `cluster setup <cluster-index.html>`__ instructions.
|
||||
|
||||
To connect a scikit-learn to a running Ray cluster, you have to specify the address of the
|
||||
head node by setting the ``RAY_ADDRESS`` environment variable.
|
||||
|
||||
You can also start Ray manually by calling ``ray.init()`` (with any of its supported
|
||||
configuration options) before calling ``with joblib.parallel_backend('ray')``.
|
||||
|
||||
.. warning::
|
||||
|
||||
If you do not set the ``RAY_ADDRESS`` environment variable and do not provide
|
||||
``address`` in ``ray.init(address=<address>)`` then scikit-learn will run on a SINGLE node!
|
17
python/ray/experimental/joblib/__init__.py
Normal file
17
python/ray/experimental/joblib/__init__.py
Normal file
|
@ -0,0 +1,17 @@
|
|||
from joblib.parallel import register_parallel_backend
|
||||
|
||||
|
||||
def register_ray():
|
||||
""" Register Ray Backend to be called with parallel_backend("ray"). """
|
||||
try:
|
||||
from ray.experimental.joblib.ray_backend import RayBackend
|
||||
register_parallel_backend("ray", RayBackend)
|
||||
except ImportError:
|
||||
msg = ("To use the ray backend you must install ray."
|
||||
"Try running 'pip install ray'."
|
||||
"See https://ray.readthedocs.io/en/latest/installation.html"
|
||||
"for more information.")
|
||||
raise ImportError(msg)
|
||||
|
||||
|
||||
__all__ = ["register_ray"]
|
58
python/ray/experimental/joblib/ray_backend.py
Normal file
58
python/ray/experimental/joblib/ray_backend.py
Normal file
|
@ -0,0 +1,58 @@
|
|||
from joblib._parallel_backends import MultiprocessingBackend
|
||||
from joblib.pool import PicklingPool
|
||||
import logging
|
||||
|
||||
from ray.experimental.multiprocessing.pool import Pool
|
||||
import ray
|
||||
|
||||
RAY_ADDRESS_ENV = "RAY_ADDRESS"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RayBackend(MultiprocessingBackend):
|
||||
"""Ray backend uses ray, a system for scalable distributed computing.
|
||||
More info about Ray is available here: https://ray.readthedocs.io.
|
||||
"""
|
||||
|
||||
def configure(self,
|
||||
n_jobs=1,
|
||||
parallel=None,
|
||||
prefer=None,
|
||||
require=None,
|
||||
**memmappingpool_args):
|
||||
"""Make Ray Pool the father class of PicklingPool. PicklingPool is a
|
||||
father class that inherits Pool from multiprocessing.pool. The next
|
||||
line is a patch, which changes the inheritance of Pool to be from
|
||||
ray.experimental.multiprocessing.pool.
|
||||
"""
|
||||
PicklingPool.__bases__ = (Pool, )
|
||||
"""Use all available resources when n_jobs == -1. Must set RAY_ADDRESS
|
||||
variable in the environment or run ray.init(address=..) to run on
|
||||
multiple nodes.
|
||||
"""
|
||||
if n_jobs == -1:
|
||||
if not ray.is_initialized():
|
||||
import os
|
||||
if RAY_ADDRESS_ENV in os.environ:
|
||||
ray_address = os.environ[RAY_ADDRESS_ENV]
|
||||
logger.info(
|
||||
"Connecting to ray cluster at address='{}'".format(
|
||||
ray_address))
|
||||
ray.init(address=ray_address)
|
||||
else:
|
||||
logger.info("Starting local ray cluster")
|
||||
ray.init()
|
||||
ray_cpus = int(ray.state.cluster_resources()["CPU"])
|
||||
n_jobs = ray_cpus
|
||||
|
||||
eff_n_jobs = super(RayBackend, self).configure(
|
||||
n_jobs, parallel, prefer, require, **memmappingpool_args)
|
||||
return eff_n_jobs
|
||||
|
||||
def effective_n_jobs(self, n_jobs):
|
||||
eff_n_jobs = super(RayBackend, self).effective_n_jobs(n_jobs)
|
||||
if n_jobs == -1:
|
||||
ray_cpus = int(ray.state.cluster_resources()["CPU"])
|
||||
eff_n_jobs = ray_cpus
|
||||
return eff_n_jobs
|
|
@ -259,6 +259,14 @@ py_test(
|
|||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_joblib",
|
||||
size = "medium",
|
||||
srcs = ["test_joblib.py"],
|
||||
tags = ["exclusive"],
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_multi_node_2",
|
||||
size = "medium",
|
||||
|
|
159
python/ray/tests/test_joblib.py
Normal file
159
python/ray/tests/test_joblib.py
Normal file
|
@ -0,0 +1,159 @@
|
|||
import numpy as np
|
||||
import joblib
|
||||
from sklearn.datasets import load_digits, load_iris
|
||||
from sklearn.model_selection import RandomizedSearchCV
|
||||
from time import time
|
||||
from sklearn.datasets import fetch_openml
|
||||
from sklearn.ensemble import ExtraTreesClassifier
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.kernel_approximation import Nystroem
|
||||
from sklearn.kernel_approximation import RBFSampler
|
||||
from sklearn.pipeline import make_pipeline
|
||||
from sklearn.svm import LinearSVC, SVC
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
from sklearn.utils import check_array
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.neural_network import MLPClassifier
|
||||
from sklearn.model_selection import cross_val_score
|
||||
|
||||
from ray.experimental.joblib import register_ray
|
||||
import ray
|
||||
|
||||
|
||||
def test_register_ray():
|
||||
register_ray()
|
||||
assert "ray" in joblib.parallel.BACKENDS
|
||||
assert not ray.is_initialized()
|
||||
|
||||
|
||||
def test_ray_backend(shutdown_only):
|
||||
register_ray()
|
||||
from ray.experimental.joblib.ray_backend import RayBackend
|
||||
with joblib.parallel_backend("ray"):
|
||||
assert type(joblib.parallel.get_active_backend()[0]) == RayBackend
|
||||
|
||||
|
||||
def test_svm_single_node(shutdown_only):
|
||||
digits = load_digits()
|
||||
param_space = {
|
||||
"C": np.logspace(-6, 6, 10),
|
||||
"gamma": np.logspace(-8, 8, 10),
|
||||
"tol": np.logspace(-4, -1, 3),
|
||||
"class_weight": [None, "balanced"],
|
||||
}
|
||||
|
||||
model = SVC(kernel="rbf")
|
||||
search = RandomizedSearchCV(
|
||||
model, param_space, cv=3, n_iter=50, verbose=10)
|
||||
register_ray()
|
||||
with joblib.parallel_backend("ray"):
|
||||
search.fit(digits.data, digits.target)
|
||||
assert ray.is_initialized()
|
||||
|
||||
|
||||
def test_svm_multiple_nodes(ray_start_cluster_2_nodes):
|
||||
digits = load_digits()
|
||||
param_space = {
|
||||
"C": np.logspace(-6, 6, 30),
|
||||
"gamma": np.logspace(-8, 8, 30),
|
||||
"tol": np.logspace(-4, -1, 30),
|
||||
"class_weight": [None, "balanced"],
|
||||
}
|
||||
|
||||
model = SVC(kernel="rbf")
|
||||
search = RandomizedSearchCV(
|
||||
model, param_space, cv=5, n_iter=100, verbose=10)
|
||||
register_ray()
|
||||
with joblib.parallel_backend("ray"):
|
||||
search.fit(digits.data, digits.target)
|
||||
assert ray.is_initialized()
|
||||
|
||||
|
||||
"""This test only makes sure the different sklearn classifiers are supported
|
||||
and do not fail. It can be improved to check for accuracy similar to
|
||||
'test_cross_validation' but the classifiers need to be improved (to improve
|
||||
the accuracy), which results in longer test time.
|
||||
"""
|
||||
|
||||
|
||||
def test_sklearn_benchmarks(ray_start_cluster_2_nodes):
|
||||
ESTIMATORS = {
|
||||
"CART": DecisionTreeClassifier(),
|
||||
"ExtraTrees": ExtraTreesClassifier(n_estimators=10),
|
||||
"RandomForest": RandomForestClassifier(),
|
||||
"Nystroem-SVM": make_pipeline(
|
||||
Nystroem(gamma=0.015, n_components=1000), LinearSVC(C=1)),
|
||||
"SampledRBF-SVM": make_pipeline(
|
||||
RBFSampler(gamma=0.015, n_components=1000), LinearSVC(C=1)),
|
||||
"LogisticRegression-SAG": LogisticRegression(
|
||||
solver="sag", tol=1e-1, C=1e4),
|
||||
"LogisticRegression-SAGA": LogisticRegression(
|
||||
solver="saga", tol=1e-1, C=1e4),
|
||||
"MultilayerPerceptron": MLPClassifier(
|
||||
hidden_layer_sizes=(32, 32),
|
||||
max_iter=100,
|
||||
alpha=1e-4,
|
||||
solver="sgd",
|
||||
learning_rate_init=0.2,
|
||||
momentum=0.9,
|
||||
verbose=1,
|
||||
tol=1e-2,
|
||||
random_state=1),
|
||||
"MLP-adam": MLPClassifier(
|
||||
hidden_layer_sizes=(32, 32),
|
||||
max_iter=100,
|
||||
alpha=1e-4,
|
||||
solver="adam",
|
||||
learning_rate_init=0.001,
|
||||
verbose=1,
|
||||
tol=1e-2,
|
||||
random_state=1)
|
||||
}
|
||||
# Load dataset.
|
||||
print("Loading dataset...")
|
||||
data = fetch_openml("mnist_784")
|
||||
X = check_array(data["data"], dtype=np.float32, order="C")
|
||||
y = data["target"]
|
||||
|
||||
# Normalize features.
|
||||
X = X / 255
|
||||
|
||||
# Create train-test split.
|
||||
print("Creating train-test split...")
|
||||
n_train = 6000
|
||||
X_train = X[:n_train]
|
||||
y_train = y[:n_train]
|
||||
register_ray()
|
||||
|
||||
train_time = {}
|
||||
random_seed = 0
|
||||
# Use two workers per classifier.
|
||||
num_jobs = 2
|
||||
with joblib.parallel_backend("ray"):
|
||||
for name in sorted(ESTIMATORS.keys()):
|
||||
print("Training %s ... " % name, end="")
|
||||
estimator = ESTIMATORS[name]
|
||||
estimator_params = estimator.get_params()
|
||||
estimator.set_params(
|
||||
**{
|
||||
p: random_seed
|
||||
for p in estimator_params if p.endswith("random_state")
|
||||
})
|
||||
|
||||
if "n_jobs" in estimator_params:
|
||||
estimator.set_params(n_jobs=num_jobs)
|
||||
time_start = time()
|
||||
estimator.fit(X_train, y_train)
|
||||
train_time[name] = time() - time_start
|
||||
print("training", name, "took", train_time[name], "seconds")
|
||||
|
||||
|
||||
def test_cross_validation(shutdown_only):
|
||||
register_ray()
|
||||
iris = load_iris()
|
||||
clf = SVC(kernel="linear", C=1, random_state=0)
|
||||
with joblib.parallel_backend("ray", n_jobs=5):
|
||||
accuracy = cross_val_score(clf, iris.data, iris.target, cv=5)
|
||||
assert len(accuracy) == 5
|
||||
for result in accuracy:
|
||||
assert result > 0.95
|
Loading…
Add table
Reference in a new issue