Add basic functionality for Cython functions and actors (#1193)

* Add basic functionality for Cython functions and actors

* Fix up per @pcmoritz comments

* Fixes per @richardliaw comments

* Fixes per @robertnishihara comments

* Forgot double quotes when updating masked_log

* Remove import typing for Python 2 compatibility
Daniel Suo 2017-11-09 20:49:06 -05:00 committed by Philipp Moritz
parent 11f8f8bd8c
commit 4f0da6f81c
16 changed files with 1047 additions and 13 deletions

.gitignore

@ -48,6 +48,9 @@
*.dylib
*.dll
# Cython-generated files
*.c
# Incremental linking files
*.ilk
@ -95,6 +98,7 @@ scripts/nodes.txt
# CMake
cmake-build-debug/
build
# Python setup files
*.egg-info


@ -91,6 +91,7 @@ install:
- ./.travis/install-dependencies.sh
- export PATH="$HOME/miniconda/bin:$PATH"
- ./.travis/install-ray.sh
- ./.travis/install-cython-examples.sh
- cd python/ray/core
- bash ../../../src/common/test/run_tests.sh
@ -120,6 +121,7 @@ script:
- python test/monitor_test.py
- python test/trial_runner_test.py
- python test/trial_scheduler_test.py
- python test/cython_test.py
- python -m pytest python/ray/rllib/test/test_catalog.py


@ -0,0 +1,37 @@
#!/usr/bin/env bash
# Cause the script to exit if a single command fails
set -e
ROOT_DIR=$(cd "$(dirname "${BASH_SOURCE:-$0}")"; pwd)
echo "PYTHON is $PYTHON"
cython_examples="$ROOT_DIR/../examples/cython"
if [[ "$PYTHON" == "2.7" ]]; then
pushd $cython_examples
pip install scipy
python setup.py install --user
popd
elif [[ "$PYTHON" == "3.5" ]]; then
export PATH="$HOME/miniconda/bin:$PATH"
pushd $cython_examples
pip install scipy
python setup.py install --user
popd
elif [[ "$LINT" == "1" ]]; then
export PATH="$HOME/miniconda/bin:$PATH"
pushd $cython_examples
python setup.py install --user
popd
else
echo "Unrecognized Python version."
exit 1
fi


@ -0,0 +1,36 @@
Cython
======
Getting Started
---------------
This document provides examples of using Cython-generated code in ``ray``. To get started, run the following from directory ``$RAY_HOME/examples/cython``:
.. code-block:: bash
pip install scipy # For BLAS example
python setup.py develop
python cython_main.py --help
You can import the ``cython_examples`` module from a Python script or interpreter.
Notes
-----
* You **must** include the following two lines at the top of any ``*.pyx`` file:
.. code-block:: python
#!python
# cython: embedsignature=True, binding=True
* You cannot decorate Cython functions within a ``*.pyx`` file (there are ways around this, but doing so creates a leaky abstraction between Cython and Python that would be very challenging to support generally). Instead, prefer the following in your Python code (a fuller sketch appears after these notes):
.. code-block:: python
some_cython_func = ray.remote(some_cython_module.some_cython_func)
* You cannot transfer memory buffers to a remote function (see ``example8``, which currently fails); your remote function must return a value.
* Have a look at ``cython_main.py``, ``cython_simple.pyx``, and ``setup.py`` for examples of how to call, define, and build Cython code, respectively. The Cython `documentation <http://cython.readthedocs.io/>`_ is also very helpful.
* Several limitations are inherited from Cython's own `unsupported <https://github.com/cython/cython/wiki/Unsupported>`_ Python features.
* We currently do not support compiling and distributing Cython code to ``ray`` clusters. In other words, Cython developers are responsible for compiling and distributing any Cython code to their cluster (much as would be the case for users who need Python packages like ``scipy``).
* For most simple use cases, developers need not worry about Python 2 or 3, but users who do need to care can have a look at the ``language_level`` Cython compiler directive (see `here <http://cython.readthedocs.io/en/latest/src/reference/compilation.html>`_).
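Putting these notes together, the following is a minimal sketch (not part of the shipped examples) of wrapping and calling a compiled Cython function remotely; it assumes the ``cython_examples`` package built above is importable:
.. code-block:: python
    import ray
    import cython_examples as cyth
    ray.init()
    # Wrap the compiled Cython function rather than decorating it in the .pyx file
    remote_fib = ray.remote(cyth.fib)
    print(ray.get(remote_fib.remote(10)))  # 55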


@ -57,6 +57,7 @@ Example Program
example-a3c.rst
example-lbfgs.rst
example-evolution-strategies.rst
example-cython.rst
using-ray-with-tensorflow.rst
.. toctree::


@ -0,0 +1,31 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from .cython_simple import simple_func, fib, fib_int, \
fib_cpdef, fib_cdef, simple_class
from .masked_log import masked_log
from .cython_blas import \
compute_self_corr_for_voxel_sel, \
compute_kernel_matrix, \
compute_single_self_corr_syrk, \
compute_single_self_corr_gemm, \
compute_corr_vectors, \
compute_single_matrix_multiplication
__all__ = [
"simple_func",
"fib",
"fib_int",
"fib_cpdef",
"fib_cdef",
"simple_class",
"masked_log",
"compute_self_corr_for_voxel_sel",
"compute_kernel_matrix",
"compute_single_self_corr_syrk",
"compute_single_self_corr_gemm",
"compute_corr_vectors",
"compute_single_matrix_multiplication"
]


@ -0,0 +1,563 @@
#!python
# cython: embedsignature=True, binding=True
# Copyright 2016 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Authors: Yida Wang
# (Intel Labs), 2016
cimport scipy.linalg.cython_blas as blas
def compute_self_corr_for_voxel_sel(py_trans_a, py_trans_b, py_m, py_n, py_k,
py_alpha, py_a, py_lda, int py_start_voxel,
py_b, py_ldb, py_beta, py_c, py_ldc,
int py_start_epoch):
""" use blas API sgemm wrapped by scipy to compute correlation
This method is limited to process self-correlation.
The blas APIs process matrices in column-major,
but our matrices are in row-major,
so we play the transpose trick here, i.e. A*B=(B^T*A^T)^T.
The resulting matrix in shape [num_assigned_voxels, num_voxels]
is stored in an alternate way to make sure that
the correlation vectors of the same voxel are stored contiguously
Parameters
----------
py_trans_a: str
do transpose or not for the first matrix A
py_trans_b: str
do transpose or not for the second matrix B
py_m: int
the row of the resulting matrix C
in our case, is num_voxels
py_n: int
the column of the resulting matrix C
in our case, is num_assigned_voxels
py_k: int
the collapsed dimension of the multiplying matrices
i.e. the column of the first matrix after transpose if necessary
the row of the second matrix after transpose if necessary
py_alpha: float
the weight applied to the first matrix A
py_a: 2D array in shape [epoch_length, num_voxels]
It is the activity data of an epoch, part 1 of the data to be
correlated with. Note that py_a can point to the same location of py_b.
py_lda: int
the stride of the first matrix A
py_start_voxel: int
the starting voxel of assigned voxels
used to locate the second matrix B
py_b: 2D array in shape [epoch_length, num_voxels]
It is the activity data of an epoch, part 2 of the data to be
correlated with. Note that py_a can point to the same location of py_b.
py_ldb: int
the stride of the second matrix B
py_beta: float
the weight applied to the resulting matrix C
py_c: 3D array in shape [num_selected_voxels, num_epochs, num_voxels]
place to store the resulting correlation values
py_ldc: int
the stride of the resulting matrix
in our case, num_voxels*num_epochs
py_start_epoch: int
the epoch over which the correlation is computed
Returns
-------
py_c: 3D array in shape [num_selected_voxels, num_epochs, num_voxels]
write the resulting correlation values in an alternate way
for the processing epoch
"""
cdef bytes by_trans_a=py_trans_a.encode()
cdef bytes by_trans_b=py_trans_b.encode()
cdef char* trans_a = by_trans_a
cdef char* trans_b = by_trans_b
cdef int M, N, K, lda, ldb, ldc
M = py_m
N = py_n
K = py_k
lda = py_lda
ldb = py_ldb
ldc = py_ldc
cdef float alpha, beta
alpha = py_alpha
beta = py_beta
cdef float[:, ::1] A
A = py_a
cdef float[:, ::1] B
B = py_b
cdef float[:, :, ::1] C
C = py_c
blas.sgemm(trans_a, trans_b, &M, &N, &K, &alpha, &A[0, 0], &lda,
&B[0, py_start_voxel], &ldb, &beta, &C[0, py_start_epoch, 0], &ldc)
def compute_kernel_matrix(py_uplo, py_trans, py_n, py_k, py_alpha, py_a,
int py_start_voxel, py_lda,
py_beta, py_c, py_ldc):
""" use blas API syrk wrapped by scipy to compute kernel matrix of SVM
The blas APIs process matrices in column-major, but our matrices are
in row-major, so we play the transpose trick here, i.e. A*B=(B^T*A^T)^T
In SVM with linear kernel, the distance of two samples
is essentially the dot product of them.
Therefore, the kernel matrix can be obtained by matrix multiplication.
Since the kernel matrix is symmetric, ssyrk is used,
the other half of the matrix is assigned later.
In our case, the dimension of samples is much larger than
the number of samples, so we proportionally shrink the values of
the kernel matrix for getting more robust alpha values in SVM iteration.
Parameters
----------
py_uplo: str
getting the upper or lower triangle of the matrix
py_trans: str
do transpose or not for the input matrix A
py_n: int
the row and column of the resulting matrix C
in our case, is num_epochs
py_k: int
the collapsed dimension of the multiplying matrices
i.e. the column of the first matrix after transpose if necessary
the row of the second matrix after transpose if necessary
in our case, is num_voxels
py_alpha: float
the weight applied to the input matrix A
py_a: 3D array in shape [num_assigned_voxels, num_epochs, num_voxels]
in our case the normalized correlation values of a voxel
py_start_voxel: int
the processed voxel
used to locate the input matrix A
py_lda: int
the stride of the input matrix A
py_beta: float
the weight applied to the resulting matrix C
py_c: 2D array in shape [num_epochs, num_epochs]
place to store the resulting kernel matrix
py_ldc: int
the stride of the resulting matrix
Returns
-------
py_c: 2D array in shape [num_epochs, num_epochs]
write the resulting kernel_matrix
for the processing voxel
"""
cdef bytes by_uplo=py_uplo.encode()
cdef bytes by_trans=py_trans.encode()
cdef char* uplo = by_uplo
cdef char* trans = by_trans
cdef int N, K, lda, ldc
N = py_n
K = py_k
lda = py_lda
ldc = py_ldc
cdef float alpha, beta
alpha = py_alpha
beta = py_beta
cdef float[:, :, ::1] A
A = py_a
cdef float[:, ::1] C
C = py_c
blas.ssyrk(uplo, trans, &N, &K, &alpha, &A[py_start_voxel, 0, 0], &lda,
&beta, &C[0, 0], &ldc)
# complete the other half of the kernel matrix
if py_uplo == 'L':
for j in range(py_c.shape[0]):
for k in range(j):
py_c[j, k] = py_c[k, j]
else:
for j in range(py_c.shape[0]):
for k in range(j):
py_c[k, j] = py_c[j, k]
def compute_single_self_corr_syrk(py_uplo, py_trans, py_n, py_k,
py_alpha, py_a, py_lda,
py_beta, py_c, py_ldc,
int py_start_sample):
""" use blas API syrk wrapped by scipy to compute correlation matrix
This is to compute the correlation between selected voxels for
final training and classification. Since the resulting correlation
matrix is symmetric, syrk is used. However, it appears that in most
cases, syrk performs much worse than gemm (the next function).
Here we assume that the resulting matrix is stored in a compact way,
i.e. py_ldc == py_n.
Parameters
----------
py_uplo: str
getting the upper or lower triangle of the matrix
py_trans: str
do transpose or not for the input matrix A
py_n: int
the row and column of the resulting matrix C
in our case, is num_selected_voxels
py_k: int
the collapsed dimension of the multiplying matrices
i.e. the column of the first matrix after transpose if necessary
the row of the second matrix after transpose if necessary
in our case, is num_TRs
py_alpha: float
the weight applied to the input matrix A
py_a: 2D array in shape [num_TRs, num_selected_voxels]
in our case the normalized activity values
py_lda: int
the stride of the input matrix A
py_beta: float
the weight applied to the resulting matrix C
py_c: 3D array
in shape [num_samples, num_selected_voxels, num_selected_voxels]
place to store the resulting kernel matrix
py_ldc: int
the stride of the resulting matrix
py_start_sample: int
the processed sample
used to locate the resulting matrix C
Returns
-------
py_c: 3D array
in shape [num_samples, num_selected_voxels, num_selected_voxels]
write the resulting correlation matrices
for the processed sample
"""
cdef bytes by_uplo=py_uplo.encode()
cdef bytes by_trans=py_trans.encode()
cdef char* uplo = by_uplo
cdef char* trans = by_trans
cdef int N, K, lda, ldc
N = py_n
K = py_k
lda = py_lda
ldc = py_ldc
cdef float alpha, beta
alpha = py_alpha
beta = py_beta
cdef float[:, ::1] A
A = py_a
cdef float[:, :, ::1] C
C = py_c
blas.ssyrk(uplo, trans, &N, &K, &alpha, &A[0, 0], &lda,
&beta, &C[py_start_sample, 0, 0], &ldc)
# complete the other half of the kernel matrix
if py_uplo == 'L':
for j in range(py_c.shape[1]):
for k in range(j):
py_c[py_start_sample, j, k] = py_c[py_start_sample, k, j]
else:
for j in range(py_c.shape[1]):
for k in range(j):
py_c[py_start_sample, k, j] = py_c[py_start_sample, j, k]
def compute_single_self_corr_gemm(py_trans_a, py_trans_b, py_m, py_n,
py_k, py_alpha, py_a, py_lda,
py_ldb, py_beta, py_c, py_ldc,
int py_start_sample):
""" use blas API gemm wrapped by scipy to compute correlation matrix
This is to compute the correlation between selected voxels for
final training and classification. Although the resulting correlation
matrix is symmetric, in most cases, gemm performs better than syrk.
Here we assume that the resulting matrix is stored in a compact way,
i.e. py_ldc == py_n.
Parameters
----------
py_trans_a: str
do transpose or not for the first matrix A
py_trans_b: str
do transpose or not for the second matrix B
py_m: int
the row of the resulting matrix C
in our case, is num_selected_voxels
py_n: int
the column of the resulting matrix C
in our case, is num_selected_voxels
py_k: int
the collapsed dimension of the multiplying matrices
i.e. the column of the first matrix after transpose if necessary
the row of the second matrix after transpose if necessary
in our case, is num_TRs
py_alpha: float
the weight applied to the input matrix A
py_a: 2D array in shape [num_TRs, num_selected_voxels]
in our case the normalized activity values
both multipliers are specified here as the same one
py_lda: int
the stride of the input matrix A
py_ldb: int
the stride of the input matrix B
in our case, the same as py_lda
py_beta: float
the weight applied to the resulting matrix C
py_c: 3D array
in shape [num_samples, num_selected_voxels, num_selected_voxels]
place to store the resulting kernel matrix
py_ldc: int
the stride of the resulting matrix
py_start_sample: int
the processed sample
used to locate the resulting matrix C
Returns
-------
py_c: 3D array
in shape [num_samples, num_selected_voxels, num_selected_voxels]
write the resulting correlation matrices
for the processed sample
"""
cdef bytes by_trans_a=py_trans_a.encode()
cdef bytes by_trans_b=py_trans_b.encode()
cdef char* trans_a = by_trans_a
cdef char* trans_b = by_trans_b
cdef int M, N, K, lda, ldb, ldc
M = py_m
N = py_n
K = py_k
lda = py_lda
ldb = py_ldb
ldc = py_ldc
cdef float alpha, beta
alpha = py_alpha
beta = py_beta
cdef float[:, ::1] A
A = py_a
cdef float[:, :, ::1] C
C = py_c
blas.sgemm(trans_a, trans_b, &M, &N, &K, &alpha, &A[0, 0], &lda,
&A[0, 0], &ldb, &beta, &C[py_start_sample, 0, 0], &ldc)
def compute_corr_vectors(py_trans_a, py_trans_b, py_m, py_n,
py_k, py_alpha, py_a, py_lda,
py_b, py_ldb, py_beta, py_c, py_ldc,
int py_start_voxel,
int py_start_sample):
""" use blas API gemm wrapped by scipy to construct a correlation vector
The correlation vector is essentially correlation matrices computed
from two activity matrices. It will be placed in the corresponding place
of the resulting correlation data set.
The blas APIs process matrices in column-major,
but our matrices are in row-major, so we play the transpose trick here,
i.e. A*B=(B^T*A^T)^T
py_trans_a: str
do transpose or not for the first matrix A
py_trans_b: str
do transpose or not for the second matrix B
py_m: int
the row of the resulting matrix C
py_n: int
the column of the resulting matrix C
py_k: int
the collapsed dimension of the multiplying matrices
i.e. the column of the first matrix after transpose if necessary
the row of the second matrix after transpose if necessary
py_alpha: float
the weight applied to the input matrix A
py_a: 2D array
py_lda: int
the stride of the input matrix A
py_b: 2D array
py_ldb: int
the stride of the input matrix B
py_beta: float
the weight applied to the resulting matrix C
py_c: 2D array
in shape [py_m, py_n] of column-major
in fact it is
in shape [py_n, py_m] of row-major
py_ldc: int
the stride of the resulting matrix
py_start_voxel: int
the starting voxel of assigned voxels
used to locate the second matrix B
py_start_sample: int
the processed sample
used to locate the resulting matrix C
Returns
-------
py_c: 2D array
in shape [py_m, py_n] of column-major
write the resulting matrix to the place indicated by py_start_sample
"""
cdef bytes by_trans_a=py_trans_a.encode()
cdef bytes by_trans_b=py_trans_b.encode()
cdef char* trans_a = by_trans_a
cdef char* trans_b = by_trans_b
cdef int M, N, K, lda, ldb, ldc
M = py_m
N = py_n
K = py_k
lda = py_lda
ldb = py_ldb
ldc = py_ldc
cdef float alpha, beta
alpha = py_alpha
beta = py_beta
cdef float[:, ::1] A
A = py_a
cdef float[:, ::1] B
B = py_b
cdef float[:, :, ::1] C
C = py_c
blas.sgemm(trans_a, trans_b, &M, &N, &K, &alpha, &A[0, 0], &lda,
&B[0, py_start_voxel], &ldb, &beta, &C[py_start_sample, 0, 0], &ldc)
def compute_single_matrix_multiplication(py_trans_a, py_trans_b, py_m, py_n,
py_k, py_alpha, py_a, py_lda,
py_b, py_ldb, py_beta, py_c, py_ldc):
""" use blas API gemm wrapped by scipy to do matrix multiplication
This is to compute the matrix multiplication.
The blas APIs process matrices in column-major,
but our matrices are in row-major, so we play the transpose trick here,
i.e. A*B=(B^T*A^T)^T
Parameters
----------
py_trans_a: str
do transpose or not for the first matrix A
py_trans_b: str
do transpose or not for the second matrix B
py_m: int
the row of the resulting matrix C
py_n: int
the column of the resulting matrix C
py_k: int
the collapsed dimension of the multiplying matrices
i.e. the column of the first matrix after transpose if necessary
the row of the second matrix after transpose if necessary
py_alpha: float
the weight applied to the input matrix A
py_a: 2D array
py_lda: int
the stride of the input matrix A
py_b: 2D array
py_ldb: int
the stride of the input matrix B
py_beta: float
the weight applied to the resulting matrix C
py_c: 2D array
in shape [py_m, py_n] of column-major
in fact it is
in shape [py_n, py_m] of row-major
py_ldc: int
the stride of the resulting matrix
Returns
-------
py_c: 2D array
in shape [py_m, py_n] of column-major
write the resulting matrix
"""
cdef bytes by_trans_a=py_trans_a.encode()
cdef bytes by_trans_b=py_trans_b.encode()
cdef char* trans_a = by_trans_a
cdef char* trans_b = by_trans_b
cdef int M, N, K, lda, ldb, ldc
M = py_m
N = py_n
K = py_k
lda = py_lda
ldb = py_ldb
ldc = py_ldc
cdef float alpha, beta
alpha = py_alpha
beta = py_beta
cdef float[:, ::1] A
A = py_a
cdef float[:, ::1] B
B = py_b
cdef float[:, ::1] C
C = py_c
blas.sgemm(trans_a, trans_b, &M, &N, &K, &alpha, &A[0, 0], &lda,
&B[0, 0], &ldb, &beta, &C[0, 0], &ldc)
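The docstrings above rely on the row-major/column-major transpose trick (A*B = (B^T*A^T)^T). As an illustrative sketch only (not part of the commit, and assuming the ``cython_examples`` extension was built with scipy available), the call below computes X @ Y for row-major float32 NumPy arrays by passing Y as the first operand and X as the second; each leading dimension is simply the number of columns of the corresponding row-major array:
import numpy as np
from cython_examples import compute_single_matrix_multiplication
X = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)    # shape (m=2, k=3)
Y = np.array([[1, 0], [0, 1], [1, 1]], dtype=np.float32)  # shape (k=3, n=2)
C = np.zeros((2, 2), dtype=np.float32)                    # will hold X @ Y
# sgemm sees the row-major buffers as transposed, so computing Y^T * X^T
# in column-major order writes (X @ Y)^T, i.e. X @ Y in row-major order.
compute_single_matrix_multiplication("N", "N",
                                     2,           # py_m: columns of Y
                                     2,           # py_n: rows of X
                                     3,           # py_k: shared dimension
                                     1.0, Y, 2,   # py_a, py_lda (columns of Y)
                                     X, 3,        # py_b, py_ldb (columns of X)
                                     0.0, C, 2)   # py_c, py_ldc (columns of C)
print(C)  # [[ 4.  5.] [10. 11.]]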


@ -0,0 +1,43 @@
#!python
# cython: embedsignature=True, binding=True
def simple_func(x, y, z):
return x + y + z
# Cython code directly callable from Python
def fib(n):
if n < 2:
return n
return fib(n-2) + fib(n-1)
# Typed Cython code
def fib_int(int n):
if n < 2:
return n
return fib_int(n-2) + fib_int(n-1)
# Cython-Python code
cpdef fib_cpdef(int n):
if n < 2:
return n
return fib_cpdef(n-2) + fib_cpdef(n-1)
# C code
def fib_cdef(int n):
return fib_in_c(n)
cdef int fib_in_c(int n):
if n < 2:
return n
return fib_in_c(n-2) + fib_in_c(n-1)
# Simple class
class simple_class(object):
def __init__(self):
self.value = 0
def increment(self):
self.value += 1
return self.value


@ -0,0 +1,48 @@
#!python
# cython: embedsignature=True, binding=True
# Copyright 2016 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from libc.math cimport log
import numpy as np
cimport numpy as np
def masked_log(x):
"""Compute natural logarithm while accepting nonpositive input
For nonpositive elements, return -inf.
Modified slightly from the original BrainIAK code to support
Python 2.
Parameters
----------
x: ndarray[T]
Returns
-------
ndarray[Union[T, np.float64]]
"""
y = np.empty(x.shape, dtype=np.float64)
lim = x.shape[0]
for i in range(lim):
if x[i] <= 0:
y[i] = float("-inf")
else:
y[i] = log(x[i])
return y


@ -0,0 +1,120 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ray
import click
import inspect
import numpy as np
import cython_examples as cyth
def run_func(func, *args, **kwargs):
"""Helper function for running examples"""
ray.init()
func = ray.remote(func)
# NOTE: kwargs not allowed for now
result = ray.get(func.remote(*args))
# Inspect the stack to get calling example
caller = inspect.stack()[1][3]
print("%s: %s" % (caller, str(result)))
return result
@click.group(context_settings=dict(help_option_names=["-h", "--help"]))
def cli():
"""Working with Cython actors and functions in Ray"""
@cli.command()
def example1():
"""Cython def function"""
run_func(cyth.simple_func, 1, 2, 3)
@cli.command()
def example2():
"""Cython def function, recursive"""
run_func(cyth.fib, 10)
@cli.command()
def example3():
"""Cython def function, built-in typed parameter"""
# NOTE: Cython will attempt to cast argument to correct type
# NOTE: Floats will be cast to int, but a string, for example, will error
run_func(cyth.fib_int, 10)
@cli.command()
def example4():
"""Cython cpdef function"""
run_func(cyth.fib_cpdef, 10)
@cli.command()
def example5():
"""Cython wrapped cdef function"""
# NOTE: cdef functions are not exposed to Python
run_func(cyth.fib_cdef, 10)
@cli.command()
def example6():
"""Cython simple class"""
ray.init()
cls = ray.remote(cyth.simple_class)
a1 = cls.remote()
a2 = cls.remote()
result1 = ray.get(a1.increment.remote())
result2 = ray.get(a2.increment.remote())
print(result1, result2)
@cli.command()
def example7():
"""Cython with function from BrainIAK (masked log)"""
run_func(cyth.masked_log, np.array([-1.0, 0.0, 1.0, 2.0]))
@cli.command()
def example8():
"""Cython with blas. NOTE: requires scipy"""
# See cython_blas.pyx for argument documentation
mat = np.array([[[2.0, 2.0], [2.0, 2.0]], [[2.0, 2.0], [2.0, 2.0]]],
dtype=np.float32)
result = np.zeros((2, 2), np.float32, order="C")
run_func(cyth.compute_kernel_matrix,
"L",
"T",
2,
2,
1.0,
mat,
0,
2,
1.0,
result,
2
)
if __name__ == "__main__":
cli()

examples/cython/setup.py

@ -0,0 +1,35 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from setuptools import setup
from Cython.Build import cythonize
import numpy
pkg_dir = "cython_examples"
modules = ["cython_simple.pyx", "masked_log.pyx"]
install_requires = ["cython", "numpy"]
include_dirs = [numpy.get_include()]
# TODO: Need scipy to run BrainIAK example, but don't want to add additional
# dependencies
try:
import scipy # noqa
modules.append("cython_blas.pyx")
install_requires.append("scipy")
except ImportError as e: # noqa
pass
modules = [os.path.join(pkg_dir, module) for module in modules]
setup(
name=pkg_dir,
version="0.0.1",
description="Cython examples for Ray",
packages=[pkg_dir],
ext_modules=cythonize(modules),
install_requires=install_requires,
include_dirs=include_dirs
)


@ -15,7 +15,7 @@ import ray.local_scheduler
import ray.signature as signature
import ray.worker
from ray.utils import (binary_to_hex, FunctionProperties, random_string,
- release_gpus_in_use, select_local_scheduler)
+ release_gpus_in_use, select_local_scheduler, is_cython)
def random_actor_id():
@ -261,7 +261,8 @@ def fetch_and_register_actor(actor_class_key, worker):
worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class)
actor_methods = inspect.getmembers(
unpickled_class, predicate=(lambda x: (inspect.isfunction(x) or
- inspect.ismethod(x))))
+ inspect.ismethod(x) or
+ is_cython(x))))
for actor_method_name, actor_method in actor_methods:
function_id = compute_actor_method_function_id(
class_name, actor_method_name).id()
@ -682,7 +683,8 @@ def actor_handle_from_class(Class, class_id, num_cpus, num_gpus,
# Get the actor methods of the given class.
actor_methods = inspect.getmembers(
Class, predicate=(lambda x: (inspect.isfunction(x) or
- inspect.ismethod(x))))
+ inspect.ismethod(x) or
+ is_cython(x))))
# Extract the signatures of each of the methods. This will be used
# to catch some errors if the methods are called with inappropriate
# arguments.


@ -5,6 +5,8 @@ from __future__ import print_function
from collections import namedtuple
import funcsigs
from ray.utils import is_cython
FunctionSignature = namedtuple("FunctionSignature", ["arg_names",
"arg_defaults",
"arg_is_positionals",
@ -27,6 +29,42 @@ Attributes:
"""
def get_signature_params(func):
"""Get signature parameters
Support Cython functions by grabbing relevant attributes from the Cython
function and attaching to a no-op function. This is somewhat brittle, since
funcsigs may change, but given that funcsigs is written to a PEP, we hope
it is relatively stable. Future versions of Python may allow overloading
the inspect 'isfunction' and 'ismethod' functions or create an ABC for Python
functions. Until then, it appears that Cython won't do anything about
compatibility with the inspect module.
Args:
func: The function whose signature should be checked.
Raises:
TypeError: A type error if the signature is not supported
"""
# The first condition for Cython functions, the latter for Cython instance
# methods
if is_cython(func):
attrs = ["__code__", "__annotations__",
"__defaults__", "__kwdefaults__"]
if all([hasattr(func, attr) for attr in attrs]):
original_func = func
def func(): return
for attr in attrs:
setattr(func, attr, getattr(original_func, attr))
else:
raise TypeError("{0!r} is not a Python function we can process"
.format(func))
return list(funcsigs.signature(func).parameters.items())
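As a purely illustrative sketch (not part of this diff), the attribute-copying above can be exercised directly; assuming the Cython function exposes the four attributes, as get_signature_params expects, funcsigs reads the signature from the no-op proxy:
import funcsigs
import cython_examples as cyth  # built from examples/cython in this PR
def _proxy(): return
for attr in ["__code__", "__annotations__", "__defaults__", "__kwdefaults__"]:
    setattr(_proxy, attr, getattr(cyth.simple_func, attr))
print(list(funcsigs.signature(_proxy).parameters))  # ['x', 'y', 'z']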
def check_signature_supported(func, warn=False):
"""Check if we support the signature of this function.
@ -43,8 +81,7 @@ def check_signature_supported(func, warn=False):
Exception: An exception is raised if the signature is not supported.
"""
function_name = func.__name__
- sig_params = [(k, v) for k, v
- in funcsigs.signature(func).parameters.items()]
+ sig_params = get_signature_params(func)
has_vararg_param = False
has_kwargs_param = False
@ -88,8 +125,7 @@ def extract_signature(func, ignore_first=False):
A function signature object, which includes the names of the keyword
arguments as well as their default values.
"""
- sig_params = [(k, v) for k, v
- in funcsigs.signature(func).parameters.items()]
+ sig_params = get_signature_params(func)
if ignore_first:
if len(sig_params) == 0:


@ -12,6 +12,21 @@ import sys
import ray.local_scheduler
def is_cython(obj):
"""Check if an object is a Cython function or method"""
# TODO(suo): We could split these into two functions, one for Cython
# functions and another for Cython methods.
# TODO(suo): There doesn't appear to be a Cython function 'type' we can
# check against via isinstance. Please correct me if I'm wrong.
def check_cython(x):
return type(x).__name__ == "cython_function_or_method"
# Check if function or method, respectively
return check_cython(obj) or \
(hasattr(obj, "__func__") and check_cython(obj.__func__))
def random_string():
"""Generate a random string to use as an ID.


@ -28,7 +28,8 @@ import ray.services as services
import ray.signature as signature
import ray.local_scheduler
import ray.plasma
- from ray.utils import FunctionProperties, random_string, binary_to_hex
+ from ray.utils import (FunctionProperties, random_string, binary_to_hex,
+ is_cython)
SCRIPT_MODE = 0
WORKER_MODE = 1
@ -2281,6 +2282,7 @@ def export_remote_function(function_id, func_name, func, func_invoker,
func_name_global_valid = func.__name__ in func.__globals__
func_name_global_value = func.__globals__.get(func.__name__)
# Allow the function to reference itself as a global variable
if not is_cython(func):
func.__globals__[func.__name__] = func_invoker
try:
pickled_func = pickle.dumps(func)
@ -2330,9 +2332,11 @@ def compute_function_id(func_name, func):
function_id_hash.update(func_name.encode("ascii"))
# If we are running a script or are in IPython, include the source code in
# the hash. If we are in a regular Python interpreter we skip this part
- # because the source code is not accessible.
+ # because the source code is not accessible. If the function is a built-in
+ # (e.g., Cython), the source code is not accessible.
import __main__ as main
if hasattr(main, "__file__") or in_ipython():
if (hasattr(main, "__file__") or in_ipython()) \
and inspect.isfunction(func):
function_id_hash.update(inspect.getsource(func).encode("ascii"))
# Compute the function ID.
function_id = function_id_hash.digest()
@ -2364,7 +2368,7 @@ def remote(*args, **kwargs):
num_custom_resource, max_calls,
checkpoint_interval, func_id=None):
def remote_decorator(func_or_class):
- if inspect.isfunction(func_or_class):
+ if inspect.isfunction(func_or_class) or is_cython(func_or_class):
function_properties = FunctionProperties(
num_return_vals=num_return_vals,
num_cpus=num_cpus,
@ -2420,7 +2424,7 @@ def remote(*args, **kwargs):
func_invoker.is_remote = True
func_name = "{}.{}".format(func.__module__, func.__name__)
func_invoker.func_name = func_name
- if sys.version_info >= (3, 0):
+ if sys.version_info >= (3, 0) or is_cython(func):
func_invoker.__doc__ = func.__doc__
else:
func_invoker.func_doc = func.func_doc

test/cython_test.py

@ -0,0 +1,57 @@
from __future__ import absolute_import
from __future__ import print_function
import math
import numpy as np
import unittest
import ray
import cython_examples as cyth
def get_ray_result(cython_func, *args):
func = ray.remote(cython_func)
return ray.get(func.remote(*args))
class CythonTest(unittest.TestCase):
def setUp(self):
ray.init()
def tearDown(self):
ray.worker.cleanup()
def assertEqualHelper(self, cython_func, expected, *args):
self.assertEqual(get_ray_result(cython_func, *args), expected)
def test_simple_func(self):
self.assertEqualHelper(cyth.simple_func, 6, 1, 2, 3)
self.assertEqualHelper(cyth.fib, 55, 10)
self.assertEqualHelper(cyth.fib_int, 55, 10)
self.assertEqualHelper(cyth.fib_cpdef, 55, 10)
self.assertEqualHelper(cyth.fib_cdef, 55, 10)
def test_simple_class(self):
cls = ray.remote(cyth.simple_class)
a1 = cls.remote()
a2 = cls.remote()
result1 = ray.get(a1.increment.remote())
result2 = ray.get(a2.increment.remote())
result3 = ray.get(a2.increment.remote())
self.assertEqual(result1, 1)
self.assertEqual(result2, 1)
self.assertEqual(result3, 2)
def test_numpy(self):
array = np.array([-1.0, 0.0, 1.0, 2.0])
answer = [float("-inf") if x <= 0 else math.log(x) for x in array]
result = get_ray_result(cyth.masked_log, array)
np.testing.assert_array_equal(answer, result)
if __name__ == "__main__":
unittest.main(verbosity=2)