Wrapping MKL Functions

classic Classic list List threaded Threaded
1 message Options
Reply | Threaded
Open this post in threaded view

Wrapping MKL Functions

I'm cross-posting to scipy-user since I know there are people familiar
with both cython and the MKL here and I'm hoping someone can easily see
what I'm doing wrong.

I've tried my hand at wrapping the MKL axpby function, for both float and
complex types. For float types I was successful:

In [10]: from numpy.random import randn
         a = 2
         x = randn(1e6)
         b = 3
         y = randn(1e6)

In [11]: np.allclose(a*x+b*y, axpby(a, x, b, y))

In [42]: %timeit a*x+b*y
10 loops, best of 3: 20.5 ms per loop

In [43]: %timeit axpby(a, x, b, y)
100 loops, best of 3: 8.86 ms per loop

But for complex types I got the following error:

ValueError                                Traceback (most recent call last)
<ipython-input-4-a0896becb5ce> in <module>()
----> 1 axpby(a,x,b,y)

a0ed.pyd in a0ed.axpby (a0ed.c:2674)()

ValueError: No value specified for struct attribute 'real'

I'm assuming that's because I'm passing in a pointer to a np.complex64_t
type which has a different layout to the MKL_Complex16 struct, but I'm not
quite sure how to work around that so I'd be keen to find out if anyone

Also any pointers on best practice with respect to my function below would
be greatly appreciated. I went down the route of having a (higher
performance?) fused type cpdef function which AFAICT requires all inputs to
be the same type. I then have a wapper cdef function which will (up-)cast the
inputs if required.

The mkl typedefs are:

/* MKL Complex type for single precision */
#ifndef MKL_Complex8
struct _MKL_Complex8 {
    float real;
    float imag;
} MKL_Complex8;

/* MKL Complex type for double precision */
#ifndef MKL_Complex16
struct _MKL_Complex16 {
    double real;
    double imag;
} MKL_Complex16;


%%cython --force -I=C:\dev\bin\Intel\ComposerXE-2011\mkl\include -
l=C:\dev\bin\Intel\ComposerXE-2011\mkl\lib\ia32\mkl_rt -
cimport cython

from cpython cimport bool

import numpy as np
cimport numpy as np

ctypedef np.int8_t int8_t
ctypedef np.int32_t int32_t
ctypedef np.int64_t int64_t
ctypedef np.float32_t float32_t
ctypedef np.float64_t float64_t
ctypedef np.complex_t complex32_t
ctypedef np.complex64_t complex64_t

cdef extern from "mkl.h" nogil:
    ctypedef struct MKL_Complex8:
        float32_t real
        float32_t imag
    ctypedef  struct MKL_Complex16:
        float64_t real
        float64_t imag

ctypedef fused mkl_float:

cdef extern from * nogil:
    ctypedef int32_t const_mkl_int "const int32_t"
    ctypedef float32_t const_float32 "const float32_t"
    ctypedef float64_t const_float64 "const float64_t"
    ctypedef MKL_Complex8* const_complex32_ptr "const MKL_Complex8*"
    ctypedef MKL_Complex16* const_complex64_ptr "const MKL_Complex16*"

cdef extern from "mkl.h" nogil:
    void saxpby(const_mkl_int *size,
                const_float32 *a,
                const_float32 *x,
                const_mkl_int *xstride,
                const_float32 *b,
                const_float32 *y,
                const_mkl_int *ystride)
    void daxpby(const_mkl_int *size,
                const_float64 *a,
                const_float64 *x,
                const_mkl_int *xstride,
                const_float64 *b,
                const_float64 *y,
                const_mkl_int *ystride)
    void caxpby(const_mkl_int *size,
                const_complex32_ptr a,
                const_complex32_ptr x,
                const_mkl_int *xstride,
                const_complex32_ptr b,
                const_complex32_ptr y,
                const_mkl_int *ystride)
    void zaxpby(const_mkl_int *size,
                const_complex64_ptr a,
                const_complex64_ptr x,
                const_mkl_int *xstride,
                const_complex64_ptr b,
                const_complex64_ptr y,
                const_mkl_int *ystride)


cpdef _axpby(mkl_float a, mkl_float[:] x, mkl_float b, mkl_float[:] y):
    cdef int32_t size = x.shape[0]
    cdef int32_t xstride = x.strides[0]/x.itemsize
    cdef int32_t ystride = y.strides[0]/y.itemsize
    if mkl_float is float32_t:
        saxpby(&size, &a, &x[0], &xstride, &b, &y[0], &ystride)
    elif mkl_float is float64_t:
        daxpby(&size, &a, &x[0], &xstride, &b, &y[0], &ystride)
    elif mkl_float is MKL_Complex8:
        caxpby(&size, &a, &x[0], &xstride, &b, &y[0], &ystride)
    elif mkl_float is MKL_Complex16:
        zaxpby(&size, &a, &x[0], &xstride, &b, &y[0], &ystride)

def axpby(a, x, b, y, bool overwrite_y=False):
    if (type(a) == np.complex64) or (type(b) == np.complex64) or \
       (x.dtype.type == np.complex64) or (y.dtype.type == np.complex64):
        x = np.asarray(x, dtype=np.complex64)
        y = np.array(y, dtype=np.complex64, copy=~overwrite_y)
        a = np.complex64(a)
        b = np.complex64(b)
        _axpby[MKL_Complex16](a, x, b, y)
    elif (type(a) == np.complex) or (type(b) == np.complex) or \
       (x.dtype.type == np.complex) or (y.dtype.type == np.complex):
        x = np.asarray(x, dtype=np.complex)
        y = np.array(y, dtype=np.complex, copy=~overwrite_y)
        a = np.complex(a)
        b = np.complex(b)
        _axpby[MKL_Complex8](a, x, b, y)
    elif (x.dtype.type == np.float64) or (y.dtype.type == np.float64):
        x = np.asarray(x, dtype=np.float64)
        y = np.array(y, dtype=np.float64, copy=~overwrite_y)
        a = np.float64(a)
        b = np.float64(b)
        _axpby[float64_t](a, x, b, y)
    elif (x.dtype.type == np.float32) or (y.dtype.type == np.float32):
        x = np.asarray(x, dtype=np.float32)
        y = np.array(y, dtype=np.float32, copy=~overwrite_y)
        a = np.float32(a)
        b = np.float32(b)
        _axpby[float32_t](a, x, b, y)
        raise Exception()
    return y

SciPy-User mailing list
[hidden email]