Speed up Cython implementation of dot product multiplication

Issue

I’m trying to learn Cython by trying to outperform NumPy at the dot product operation np.dot(a, b), but my implementation is about 4x slower.

Here is my Cython implementation in hello.pyx:

cimport numpy as cnp
cnp.import_array()

cpdef double dot_product(double[::1] vect1, double[::1] vect2):
    cdef int size = vect1.shape[0]
    cdef double result = 0
    cdef int i = 0
    while i < size:
        result += vect1[i] * vect2[i]
        i += 1
    return result

This is my .py test file:

import timeit

setup = '''
import numpy as np
import hello

n = 10000
a = np.array([float(i) for i in range(n)])
b = np.array([i/2 for i in a])
'''
lf_code = 'res_lf = hello.dot_product(a, b)'
np_code = 'res_np = np.dot(a,b)'
n = 100
lf_time = timeit.timeit(lf_code, setup=setup, number=n) * 100
np_time = timeit.timeit(np_code, setup=setup, number=n) * 100

print(f'Lightning fast time: {lf_time}.')
print(f'Numpy time: {np_time}.')
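Since timeit.timeit returns the total time in seconds for all `number` calls, multiplying by 100 with number=100 prints the per-call time in units of 100 µs (so 0.0288 is about 2.9 µs per call). A minimal sketch that reports the per-call time explicitly, taking the best of several runs with timeit.repeat. It assumes a hypothetical pure-Python dot_py as a stand-in for the function under test; swap in hello.dot_product or np.dot for a real measurement.

```python
import timeit

# Hypothetical pure-Python stand-in for the function under test;
# replace it with hello.dot_product or np.dot in a real benchmark.
def dot_py(a, b):
    result = 0.0
    for x, y in zip(a, b):
        result += x * y
    return result

def per_call_us(func, number=20, repeat=3):
    """Best per-call time of func() in microseconds."""
    # timeit.repeat returns one total (in seconds) per run of `number` calls
    totals = timeit.repeat(func, number=number, repeat=repeat)
    # The minimum is the run least disturbed by other activity
    return min(totals) / number * 1e6

n = 10_000
a = [float(i) for i in range(n)]
b = [x / 2 for x in a]

t = per_call_us(lambda: dot_py(a, b))
print(f"dot_py: {t:.1f} us per call")
```

Only the ratio between two such figures matters when comparing implementations, so the unit choice is a matter of readability.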

Console output:

Lightning fast time: 0.12186000000156127.
Numpy time: 0.028800000001183435.

Command to build hello.pyx:

python setup.py build_ext --inplace

setup.py file:

from setuptools import Extension, setup
from Cython.Build import cythonize
import numpy as np

# define an extension that will be cythonized and compiled
ext = Extension(name="hello", sources=["hello.pyx"], include_dirs=[np.get_include()])
setup(ext_modules=cythonize(ext))

Processor:
i7-7700T @ 2.90 GHz

Solution

The problem mainly comes from the lack of SIMD instructions (due to both the bound checking and the inefficient default compiler flags) compared to NumPy (which uses OpenBLAS by default on most platforms).

To fix that, first add the following line at the beginning of the hello.pyx file:

#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False

Then, you should use this new setup.py file:

from setuptools import Extension, setup
from Cython.Build import cythonize
import numpy as np

# define an extension that will be cythonized and compiled
ext = Extension(
    name="hello",
    sources=["hello.pyx"],
    include_dirs=[np.get_include()],
    extra_compile_args=['-O3', '-mavx', '-ffast-math'],
)
setup(ext_modules=cythonize(ext))

Note that these flags depend on the compiler. That being said, both Clang and GCC support them (and probably ICC too). -O3 tells Clang and GCC to use more aggressive optimizations such as automatic vectorization of the code. -mavx tells them to use the AVX instruction set (which is only available on relatively recent x86-64 processors, so the resulting module will not run on processors without AVX). -ffast-math tells them to assume that floating-point operations are associative (which is not the case) and that you only use finite numbers (no NaN nor infinities). If these assumptions are not fulfilled, the program can produce wrong results or behave unexpectedly at runtime, so be careful with such flags.
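A small pure-Python illustration of why the associativity assumption behind -ffast-math matters: regrouping additions changes rounding, and splitting a dot product into interleaved partial sums (the kind of reordering a vectorizing compiler is allowed to perform under -ffast-math) can change the result. The function names and input values are hypothetical, chosen only to make the effect visible.

```python
import math

# Regrouping three additions changes the rounded result
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)
print(left, right, left == right)

def dot_sequential(a, b):
    # Strict left-to-right order, as the loop in hello.pyx is written
    acc = 0.0
    for x, y in zip(a, b):
        acc += x * y
    return acc

def dot_two_lanes(a, b):
    # Reordered: even and odd indices go to separate partial sums,
    # which are combined only at the end
    lanes = [0.0, 0.0]
    for i, (x, y) in enumerate(zip(a, b)):
        lanes[i % 2] += x * y
    return lanes[0] + lanes[1]

a = [1e16, 1.0, -1e16, 1.0]
b = [1.0, 1.0, 1.0, 1.0]
print(dot_sequential(a, b), dot_two_lanes(a, b))
# math.fsum gives the correctly rounded sum of the products for comparison
print(math.fsum(x * y for x, y in zip(a, b)))
```

Here the reordered sum happens to match the exact value from math.fsum, but in general neither order is guaranteed to be more accurate; the point is only that the two orders can disagree.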

Note that OpenBLAS automatically selects the instruction set based on your machine and AFAIK it does not use -ffast-math but a safer (low-level) alternative.


Results:

Here are results on my machine:

Before optimization:
  Lightning fast time: 0.10018469997703505.
  Numpy time: 0.024747799989199848.

After (with GCC):
  Lightning fast time: 0.02865879996534204.
  Numpy time: 0.02456870001878997.

After (with Clang):
  Lightning fast time: 0.01965239998753532.
  Numpy time: 0.024799799984975834.

The code produced by Clang is faster than Numpy on my machine.
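The relative speed can be read directly off the numbers above; this sketch just divides each Cython time by the matching NumPy time (a ratio above 1 means the Cython version is slower). By this measure the original code is about 4x slower than NumPy, the GCC build is roughly 17% slower, and the Clang build takes roughly 79% of NumPy's time.

```python
# Timings copied from the results above: (Cython time, NumPy time)
results = {
    "before": (0.10018469997703505, 0.024747799989199848),
    "after (GCC)": (0.02865879996534204, 0.02456870001878997),
    "after (Clang)": (0.01965239998753532, 0.024799799984975834),
}

# Ratio > 1 means the Cython version is slower than NumPy
ratios = {label: cy / npy for label, (cy, npy) in results.items()}

for label, ratio in ratios.items():
    print(f"{label}: Cython/NumPy time ratio = {ratio:.2f}")
```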


Under the hood

An analysis of the assembly code executed by the processor on my machine shows that the code only uses slow scalar instructions, contains unnecessary bound checks and is mainly limited by the result += ... operation (because of a loop-carried dependency).

162e3:┌─→movsd  xmm0,QWORD PTR [rbx+rax*8]  # Load 1 item
162e8:│  mulsd  xmm0,QWORD PTR [rsi+rax*8]  # Load 1 item
162ed:│  addsd  xmm1,xmm0                   # Main bottleneck (accumulation)
162f1:│  cmp    rdi,rax
162f4:│↓ je     163f8                       # Bound checking conditional jump
162fa:│  cmp    rdx,rax
162fd:│↓ je     16308                       # Bound checking conditional jump
162ff:│  add    rax,0x1
16303:├──cmp    rcx,rax
16306:└──jne    162e3
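To map the listing back to the source, here is a pure-Python model (not what Cython actually generates) of the scalar loop compiled with bound checking enabled: each iteration performs one multiply, one add that must wait for the previous iteration's result, and two index checks, one per buffer. The name dot_checked and the error message are made up for illustration.

```python
def dot_checked(vect1, vect2):
    """Model of dot_product's loop with boundscheck=True."""
    size = len(vect1)
    result = 0.0
    i = 0
    while i < size:
        # The two cmp/je pairs: the index is checked against each buffer
        if i >= len(vect1) or i >= len(vect2):
            raise IndexError("index out of bounds")
        # movsd/mulsd/addsd: this add depends on the previous iteration's
        # result, so iterations cannot overlap (loop-carried dependency)
        result += vect1[i] * vect2[i]
        i += 1
    return result

print(dot_checked([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]))
```

With boundscheck=False these checks disappear, which also means nothing stops dot_product from reading past the end of vect2 if it is shorter than vect1; comparing the two lengths once before the loop keeps the function safe at negligible cost.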

Once optimized, the result is:

13720:┌─→vmovupd      ymm3,YMMWORD PTR [r13+rax*1+0x0]    # Load 4 items
13727:│  vmulpd       ymm0,ymm3,YMMWORD PTR [rcx+rax*1]   # Load 4 items
1372c:│  add          rax,0x20
13730:│  vaddpd       ymm1,ymm1,ymm0        # Still a bottleneck (but better)
13734:├──cmp          rdx,rax
13737:└──jne          13720

The result += ... operation is still the bottleneck in the optimized version, but this is much better since the loop works on 4 items at once. To remove the bottleneck, the loop must be partially unrolled (with several independent accumulators). However, GCC (which is the default compiler on my machine) is not able to do that properly, even when asked to using -funroll-loops, because of the loop-carried dependency. This is why OpenBLAS should be a bit faster than the code produced by GCC.
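Partial unrolling of a reduction amounts to keeping several independent running sums and combining them at the end. A pure-Python sketch of that technique (Python itself gains no speed from it; it only shows the data flow), where dot_multi_acc and its n_acc parameter are hypothetical names: with n_acc=4 it mirrors the 4 lanes of the single ymm1 accumulator in the GCC loop, and with n_acc=16 it mirrors a loop that advances 16 doubles per iteration across 4 ymm registers, as in the Clang listing.

```python
def dot_multi_acc(a, b, n_acc=4):
    """Dot product using n_acc independent partial sums."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    if n_acc < 1:
        raise ValueError("n_acc must be at least 1")
    n = len(a)
    acc = [0.0] * n_acc
    main = n - n % n_acc  # elements covered by full unrolled iterations
    for i in range(0, main, n_acc):
        # These n_acc additions do not depend on one another,
        # so a CPU could execute them in parallel
        for k in range(n_acc):
            acc[k] += a[i + k] * b[i + k]
    # Remainder loop for the last n % n_acc elements
    tail = 0.0
    for i in range(main, n):
        tail += a[i] * b[i]
    # Final horizontal reduction of the partial sums
    return sum(acc) + tail

n = 10_000
a = [float(i) for i in range(n)]
b = [x / 2 for x in a]
print(dot_multi_acc(a, b, n_acc=16))
```

Because the accumulation order changes, the result can differ from a strictly sequential loop in the last bits, which is why GCC and Clang only vectorize such floating-point reductions when allowed to reassociate (e.g. with -ffast-math). With the question's data every partial sum is exactly representable, so all orders give the same value.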

Fortunately, Clang is able to do that by default. Here is the code produced by Clang:

59e0:┌─→vmovupd      ymm4,YMMWORD PTR [rax+rdi*8]       # load 4 items
59e5:│  vmovupd      ymm5,YMMWORD PTR [rax+rdi*8+0x20]  # load 4 items
59eb:│  vmovupd      ymm6,YMMWORD PTR [rax+rdi*8+0x40]  # load 4 items
59f1:│  vmovupd      ymm7,YMMWORD PTR [rax+rdi*8+0x60]  # load 4 items
59f7:│  vmulpd       ymm4,ymm4,YMMWORD PTR [rbx+rdi*8]
59fc:│  vaddpd       ymm0,ymm4,ymm0
5a00:│  vmulpd       ymm4,ymm5,YMMWORD PTR [rbx+rdi*8+0x20]
5a06:│  vaddpd       ymm1,ymm4,ymm1
5a0a:│  vmulpd       ymm4,ymm6,YMMWORD PTR [rbx+rdi*8+0x40]
5a10:│  vmulpd       ymm5,ymm7,YMMWORD PTR [rbx+rdi*8+0x60]
5a16:│  vaddpd       ymm2,ymm4,ymm2
5a1a:│  vaddpd       ymm3,ymm5,ymm3
5a1e:│  add          rdi,0x10
5a22:├──cmp          rsi,rdi
5a25:└──jne          59e0

The code is not optimal (the loop should be unrolled at least 6 times to hide the latency of the vaddpd instruction), but it is still very good.

Answered By – Jérôme Richard

Answer Checked By – Timothy Miller (BugsFixing Admin)
