[SOLVED] numpy einsum/tensordot with shared non-contracted axis

Issue

Suppose I have two arrays:

import numpy as np
a = np.random.randn(32, 6, 6, 20, 64, 3, 3)
b = np.random.randn(20, 128, 64, 3, 3)

and I want to sum over the last three axes of each array while keeping the shared axis. The output shape should be (32,6,6,20,128). Notice that the axis of size 20 is present in both a and b. Let’s call this axis the "group" axis.

I have two methods for this task:
The first one is just a simple einsum:

def method1(a, b):
    return np.einsum('NHWgihw, goihw -> NHWgo', a, b, optimize=True)  # output shape:(32,6,6,20,128)

In the second method I loop through group dimension and use einsum/tensordot to compute the result for each group dimension, then stack the results:

def method2(a, b):
    result = []
    for g in range(b.shape[0]): # loop through each group dimension
        # result.append(np.tensordot(a[..., g, :, :, :], b[g, ...], axes=((-3,-2,-1),(-3,-2,-1))))
        result.append(np.einsum('NHWihw, oihw -> NHWo', a[..., g, :, :, :], b[g, ...], optimize=True))  # output shape:(32,6,6,128)
    return np.stack(result, axis=-2)  # output shape:(32,6,6,20,128)

I timed both methods in my Jupyter notebook.
The second method, with a loop, turned out to be faster than the first one.

My question is:

  1. Why is method1 so much slower? It doesn’t compute anything more.
  2. Is there a more efficient way without using loops? (I’m a bit reluctant to use loops because they are slow in Python.)

Thanks for any help!

Solution

As pointed out by @Murali in the comments, method1 is not very efficient because it does not manage to use BLAS calls, unlike method2, which does. In fact, np.einsum is not that bad in method1 given that it computes the result sequentially, whereas method2 mostly runs in parallel thanks to OpenBLAS (used by NumPy on most machines). That being said, method2 is suboptimal since it does not fully use the available cores (parts of the computation are done sequentially) and does not appear to use the cache efficiently. On my 6-core machine, it barely uses 50% of all the cores.
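Regarding question 2, one loop-free alternative that can still reach BLAS is to move the group axis to the front and call np.matmul, which treats the leading axis as a batch of independent matrix products. This is a minimal sketch under our own assumptions: the name method_matmul is hypothetical, it was not benchmarked here, and whether each batch entry actually dispatches to a BLAS routine depends on the NumPy build and on the memory layout.

```python
import numpy as np

def method_matmul(a, b):
    """Loop-free variant of method1 using a batched matrix product (sketch)."""
    sN, sH, sW, sg, si, sh, sw = a.shape
    so = b.shape[1]
    assert b.shape == (sg, so, si, sh, sw)
    k = si * sh * sw        # length of the contracted (i, h, w) axes
    nhw = sN * sH * sW      # all non-contracted, non-group axes of `a`

    # (N, H, W, g, i, h, w) -> (g, NHW, k); the transpose forces a copy here
    ra = np.ascontiguousarray(a.reshape(nhw, sg, k).transpose(1, 0, 2))
    # (g, o, i, h, w) -> (g, k, o): one (k x o) matrix per group
    rb = b.reshape(sg, so, k).transpose(0, 2, 1)

    out = np.matmul(ra, rb)  # (g, NHW, o): one matrix product per group
    # (g, NHW, o) -> (NHW, g, o) -> (N, H, W, g, o)
    return out.transpose(1, 0, 2).reshape(sN, sH, sW, sg, so)
```

With the arrays from the question, method_matmul(a, b) has shape (32, 6, 6, 20, 128), like method1 and method2. The price is the explicit copy of a into a group-major layout.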


Faster implementation

One way to speed up this computation is to write highly optimized parallel Numba code.

First of all, a semi-naive implementation uses plain for loops to compute the Einstein summation, and reshapes the input/output arrays so that Numba can better optimize the code (e.g. loop unrolling, use of SIMD instructions). Here is the result:

import numpy as np
import numba as nb

@nb.njit('float64[:,:,:,:,::1](float64[:,:,:,:,:,:,::1], float64[:,:,:,:,::1])')
def compute(a, b):
    sN, sH, sW, sg, si, sh, sw = a.shape
    so = b.shape[1]
    assert b.shape == (sg, so, si, sh, sw)

    ra = a.reshape(sN*sH*sW, sg, si*sh*sw)
    rb = b.reshape(sg, so, si*sh*sw)
    out = np.empty((sN*sH*sW, sg, so), dtype=np.float64)

    for NHW in range(sN*sH*sW):
        for g in range(sg):
            for o in range(so):
                s = 0.0

                # Reduction
                for ihw in range(si*sh*sw):
                    s += ra[NHW, g, ihw] * rb[g, o, ihw]

                out[NHW, g, o] = s

    return out.reshape((sN, sH, sW, sg, so))

Note that the input arrays are assumed to be contiguous. If this is not the case, please consider making a contiguous copy first (which is cheap compared to the computation).
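The flattening in compute relies on the (N, H, W) axes and the (i, h, w) axes being adjacent in memory, so a C-contiguous array can be reshaped into 3-D views without copying. The NumPy-only sketch below uses a hypothetical helper, flatten_inputs, that performs the copy suggested above only when needed, builds the same 3-D views as compute, and checks that contracting the flattened views matches the original einsum.

```python
import numpy as np

def flatten_inputs(a, b):
    """Return (ra, rb) with the layout used by `compute` (hypothetical helper)."""
    # Copies only if needed: returns the input unchanged if already C-contiguous
    a = np.ascontiguousarray(a)
    b = np.ascontiguousarray(b)
    sN, sH, sW, sg, si, sh, sw = a.shape
    so = b.shape[1]
    ra = a.reshape(sN * sH * sW, sg, si * sh * sw)  # (NHW, g, ihw)
    rb = b.reshape(sg, so, si * sh * sw)            # (g, o, ihw)
    return ra, rb

rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3, 3, 4, 5, 3, 3))
b = rng.standard_normal((4, 6, 5, 3, 3))
ra, rb = flatten_inputs(a, b)

# Same reduction as the triple loop in `compute`, written as an einsum on the views
flat = np.einsum('xgk,gok->xgo', ra, rb).reshape(2, 3, 3, 4, 6)
ref = np.einsum('NHWgihw,goihw->NHWgo', a, b)
print(np.allclose(flat, ref))  # True
```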

While the above code works, it is far from being efficient. Here are some improvements that can be performed:

  • run the outermost NHW loop in parallel;
  • use the Numba flag fastmath=True. This flag is unsafe if the input data contains special values like NaN or +inf/-inf. However, it helps the compiler generate much faster code using SIMD instructions (which is not possible otherwise, since IEEE-754 floating-point operations are not associative);
  • swap the NHW-based loop and the g-based loop: this results in better performance since it improves cache locality (the slice of rb used for a given group is more likely to fit in the last-level cache of mainstream CPUs, whereas it would likely be fetched from RAM otherwise);
  • use register blocking to better saturate the SIMD units of the processor and reduce the pressure on the memory hierarchy;
  • use tiling by splitting the o-based loop so that rb can almost entirely be read from lower-level caches (e.g. L1 or L2).
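The fastmath point hinges on floating-point addition not being associative: splitting a reduction across several SIMD lanes changes the order of the additions, and therefore the rounding, so the compiler may only do it when explicitly allowed to. A two-line Python check illustrates this (Python floats are IEEE-754 double-precision values on mainstream platforms):

```python
left = (0.1 + 0.2) + 0.3   # rounds the result of 0.1 + 0.2 first
right = 0.1 + (0.2 + 0.3)  # rounds the result of 0.2 + 0.3 first
print(left, right, left == right)  # 0.6000000000000001 0.6 False
```

Both results are valid roundings of the same mathematical sum, which is why reordering is harmless for this kind of computation in practice but still needs explicit permission from the programmer.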

All these improvements except the last one are implemented in the following code:

import numpy as np
import numba as nb

@nb.njit('float64[:,:,:,:,::1](float64[:,:,:,:,:,:,::1], float64[:,:,:,:,::1])', parallel=True, fastmath=True)
def method3(a, b):
    sN, sH, sW, sg, si, sh, sw = a.shape
    so = b.shape[1]
    assert b.shape == (sg, so, si, sh, sw)

    ra = a.reshape(sN*sH*sW, sg, si*sh*sw)
    rb = b.reshape(sg, so, si*sh*sw)
    out = np.zeros((sN*sH*sW, sg, so), dtype=np.float64)

    for g in range(sg):
        for k in nb.prange((sN*sH*sW)//2):
            NHW = k*2
            so_vect_max = (so // 4) * 4

            for o in range(0, so_vect_max, 4):
                s00 = s01 = s02 = s03 = s10 = s11 = s12 = s13 = 0.0

                # Row views help here: otherwise Numba does not optimize the following loop well
                ra_row0 = ra[NHW+0, g, :]
                ra_row1 = ra[NHW+1, g, :]
                rb_row0 = rb[g, o+0, :]
                rb_row1 = rb[g, o+1, :]
                rb_row2 = rb[g, o+2, :]
                rb_row3 = rb[g, o+3, :]

                # Highly-optimized reduction using register blocking
                for ihw in range(si*sh*sw):
                    ra_0 = ra_row0[ihw]
                    ra_1 = ra_row1[ihw]
                    rb_0 = rb_row0[ihw]
                    rb_1 = rb_row1[ihw]
                    rb_2 = rb_row2[ihw]
                    rb_3 = rb_row3[ihw]
                    s00 += ra_0 * rb_0; s01 += ra_0 * rb_1
                    s02 += ra_0 * rb_2; s03 += ra_0 * rb_3
                    s10 += ra_1 * rb_0; s11 += ra_1 * rb_1
                    s12 += ra_1 * rb_2; s13 += ra_1 * rb_3

                out[NHW+0, g, o+0] = s00; out[NHW+0, g, o+1] = s01
                out[NHW+0, g, o+2] = s02; out[NHW+0, g, o+3] = s03
                out[NHW+1, g, o+0] = s10; out[NHW+1, g, o+1] = s11
                out[NHW+1, g, o+2] = s12; out[NHW+1, g, o+3] = s13

            # Remaining part for `o`
            for o in range(so_vect_max, so):
                for ihw in range(si*sh*sw):
                    out[NHW, g, o] += ra[NHW, g, ihw] * rb[g, o, ihw]
                    out[NHW+1, g, o] += ra[NHW+1, g, ihw] * rb[g, o, ihw]

        # Remaining part for `k`
        if (sN*sH*sW) % 2 == 1:
            last = sN*sH*sW - 1  # distinct name so the prange index `k` is not reused
            for o in range(so):
                for ihw in range(si*sh*sw):
                    out[last, g, o] += ra[last, g, ihw] * rb[g, o, ihw]


    return out.reshape((sN, sH, sW, sg, so))

This code is much more complex and uglier, but also far more efficient. I did not implement the tiling optimization since it would make the code even less readable. However, it should result in significantly faster code on many-core processors (especially the ones with a small L2/L3 cache).
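To make the tiling idea more concrete, here is a NumPy sketch of the loop structure only. It is not the author's implementation and is not meant to be fast. For each group, the o axis is split into tiles of O_TILE rows of rb, and each small tile is reused for all NHW rows before moving on. O_TILE = 32 is an arbitrary, hypothetical value. With k = 64*3*3 = 576 float64 values per row, one tile takes 32*576*8 bytes = 144 KiB, so the value would need tuning against the actual L2 size.

```python
import numpy as np

O_TILE = 32  # hypothetical tile size along `o`; tune so one tile fits in L2

def method_tiled(a, b, o_tile=O_TILE):
    """Loop-structure sketch of tiling the `o` loop (not optimized code)."""
    sN, sH, sW, sg, si, sh, sw = a.shape
    so = b.shape[1]
    nhw, k = sN * sH * sW, si * sh * sw
    ra = np.ascontiguousarray(a).reshape(nhw, sg, k)
    rb = np.ascontiguousarray(b).reshape(sg, so, k)
    out = np.empty((nhw, sg, so), dtype=np.result_type(a, b))

    for g in range(sg):                      # group loop outermost, as in method3
        for o0 in range(0, so, o_tile):      # tile the `o` loop
            tile = rb[g, o0:o0 + o_tile, :]  # small block of rb, reused for every NHW row
            # In compiled code, this product would be the register-blocked kernel
            out[:, g, o0:o0 + o_tile] = ra[:, g, :] @ tile.T
    return out.reshape(sN, sH, sW, sg, so)
```

The last tile is simply shorter when so is not a multiple of o_tile, because NumPy slicing clips at the array bound.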


Performance results

Here are performance results on my i5-9600KF 6-core processor:

method1:              816 ms
method2:              104 ms
method3:               40 ms
Theoretical optimal:    9 ms   (optimistic lower bound)

method3 is about 2.6 times faster than method2 (40 ms vs 104 ms). There is still room for improvement, since the optimal time is about 4 times lower than that of method3.
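The optimistic lower bound can be sanity-checked with a back-of-the-envelope FLOP count. The sketch below assumes two 256-bit FMA units per core (16 double-precision FLOP per cycle) and an all-core clock of 4.3 GHz. Both figures are our assumptions about this CPU, not values given in the answer.

```python
# Shapes from the question
N, H, W, G, I, Hk, Wk, O = 32, 6, 6, 20, 64, 3, 3, 128

fma_count = N * H * W * G * O * (I * Hk * Wk)  # one multiply-add per product term
flops = 2 * fma_count                          # count the mul and the add separately

cores = 6
flops_per_cycle = 2 * 4 * 2  # assumed: 2 FMA units x 4 doubles (AVX2) x 2 FLOP
freq_hz = 4.3e9              # assumed all-core clock

peak = cores * flops_per_cycle * freq_hz
print(f"{flops / 1e9:.2f} GFLOP, lower bound {flops / peak * 1e3:.1f} ms")
# 3.40 GFLOP, lower bound 8.2 ms
```

Under these assumptions, the roughly 8 ms bound is consistent with the 9 ms quoted above, and method3 at 40 ms would reach about 85 GFLOP/s, roughly a fifth of the assumed peak.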

The main reason Numba does not generate fast code is that the underlying JIT fails to vectorize the loop efficiently. Implementing the tiling strategy should slightly improve the execution time, bringing it somewhat closer to the optimal one. The tiling strategy is critical for much bigger arrays, especially if so (the size of the o axis) is much bigger.

If you want a faster implementation, you will most likely need to write native C/C++ code using SIMD intrinsics directly (which are unfortunately not portable) or a SIMD library (e.g. XSIMD).

If you want an even faster implementation, then you need faster hardware (with more cores) or more specialized hardware. Server-grade GPUs (i.e. not the ones found in personal computers) should be able to speed up such a computation a lot, since your input is small, the computation is clearly compute-bound, and it makes massive use of FMA floating-point operations. A good first step is to try cupy.einsum.


Under the hood: low-level analysis

In order to understand why method1 is not faster, I checked the executed code. Here is the main loop:

1a0:┌─→; Part of the reduction (see below)
    │  movapd     xmm0,XMMWORD PTR [rdi-0x1000]
    │  
    │  ; Decrement the loop counter
    │  sub        r9,0x8 
    │  
    │  ; Prefetch items so to reduce the impact 
    │  ; of the latency of reading from the RAM.
    │  prefetcht0 BYTE PTR [r8]
    │  prefetcht0 BYTE PTR [rdi]
    │  
    │  ; Part of the reduction (see below)
    │  mulpd      xmm0,XMMWORD PTR [r8-0x1000]
    │  
    │  ; Increment iterator for the two arrays
    │  add        rdi,0x40 
    │  add        r8,0x40 
    │  
    │  ; Main computational part: 
    │  ; reduction using add+mul SSE2 instructions
    │  addpd      xmm1,xmm0                     <--- Slow
    │  movapd     xmm0,XMMWORD PTR [rdi-0x1030]
    │  mulpd      xmm0,XMMWORD PTR [r8-0x1030]
    │  addpd      xmm1,xmm0                     <--- Slow
    │  movapd     xmm0,XMMWORD PTR [rdi-0x1020]
    │  mulpd      xmm0,XMMWORD PTR [r8-0x1020]
    │  addpd      xmm0,xmm1                     <--- Slow
    │  movapd     xmm1,XMMWORD PTR [rdi-0x1010]
    │  mulpd      xmm1,XMMWORD PTR [r8-0x1010]
    │  addpd      xmm1,xmm0                     <--- Slow
    │  
    │  ; Is the loop over? 
    │  ; If not, jump to the beginning of the loop.
    ├──cmp        r9,0x7 
    └──jg         1a0

It turns out that NumPy uses the SSE2 instruction set (which is available on all x86-64 processors). However, my machine, like almost all relatively recent processors, supports the AVX instruction set, which can process twice as many items per instruction. My machine also supports fused multiply-add (FMA) instructions, which are twice as fast in this case since they perform the multiplication and the addition at once. Moreover, the loop is clearly bounded by the addpd instructions, which accumulate the result in mostly the same register. The processor cannot execute them efficiently: an addpd has a latency of a few cycles, and although up to two can be executed at the same time on modern x86-64 processors, this is not possible here since only one instruction at a time can perform the accumulation in xmm1.
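The dependency problem described above is usually addressed by splitting the reduction into several independent partial sums, so that consecutive additions do not wait on each other; the partial sums are combined only at the end. Here is a pure-Python sketch of that loop shape. The function name and the choice of 4 accumulators are illustrative; in compiled code each accumulator would live in its own SIMD register, while plain Python gains nothing from this.

```python
def dot_multi_acc(x, y):
    """Dot product with 4 independent accumulators (illustrates the loop shape)."""
    n = len(x)
    s0 = s1 = s2 = s3 = 0.0
    n4 = n - n % 4
    for i in range(0, n4, 4):
        # Four independent dependency chains instead of a single one
        s0 += x[i] * y[i]
        s1 += x[i + 1] * y[i + 1]
        s2 += x[i + 2] * y[i + 2]
        s3 += x[i + 3] * y[i + 3]
    for i in range(n4, n):  # leftover elements
        s0 += x[i] * y[i]
    return (s0 + s1) + (s2 + s3)  # combine the partial sums once at the end

print(dot_multi_acc([1.0, 2.0, 3.0, 4.0, 5.0], [1.0] * 5))  # 15.0
```

Because the additions happen in a different order, the result can differ from a sequential sum in the last bits, which is exactly the reordering that fastmath-style flags permit.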

Here is the executed code of the main computational part of method2 (dgemm call of OpenBLAS):

6a40:┌─→vbroadcastsd ymm0,QWORD PTR [rsi-0x60]
     │  vbroadcastsd ymm1,QWORD PTR [rsi-0x58]
     │  vbroadcastsd ymm2,QWORD PTR [rsi-0x50]
     │  vbroadcastsd ymm3,QWORD PTR [rsi-0x48]
     │  vfmadd231pd  ymm4,ymm0,YMMWORD PTR [rdi-0x80]
     │  vfmadd231pd  ymm5,ymm1,YMMWORD PTR [rdi-0x60]
     │  vbroadcastsd ymm0,QWORD PTR [rsi-0x40]
     │  vbroadcastsd ymm1,QWORD PTR [rsi-0x38]
     │  vfmadd231pd  ymm6,ymm2,YMMWORD PTR [rdi-0x40]
     │  vfmadd231pd  ymm7,ymm3,YMMWORD PTR [rdi-0x20]
     │  vbroadcastsd ymm2,QWORD PTR [rsi-0x30]
     │  vbroadcastsd ymm3,QWORD PTR [rsi-0x28]
     │  vfmadd231pd  ymm4,ymm0,YMMWORD PTR [rdi]
     │  vfmadd231pd  ymm5,ymm1,YMMWORD PTR [rdi+0x20]
     │  vfmadd231pd  ymm6,ymm2,YMMWORD PTR [rdi+0x40]
     │  vfmadd231pd  ymm7,ymm3,YMMWORD PTR [rdi+0x60]
     │  add          rsi,0x40
     │  add          rdi,0x100
     ├──dec          rax
     └──jne          6a40

This loop is far more optimized: it makes use of the AVX instruction set as well as the FMA one (i.e. the vfmadd231pd instructions). Furthermore, the loop is better unrolled and there is no latency/dependency issue like in the NumPy code, since four independent accumulators (ymm4 to ymm7) are used. However, while this loop is highly efficient, the cores are not efficiently used due to some sequential checks done in NumPy and a sequential copy performed in OpenBLAS. Moreover, I am not sure the loop makes efficient use of the cache in this case, since a lot of reads/writes are performed in RAM on my machine. Indeed, the RAM throughput is about 15 GiB/s (out of 35~40 GiB/s) due to many cache misses, while the throughput of method3 is 6 GiB/s (so more work is done in the cache) with a significantly faster execution.

Here is the executed code of the main computational part of method3:

.LBB0_5:
    vorpd   2880(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm2
    vmovupd %ymm2, 3040(%rsp)
    vorpd   2848(%rsp), %ymm8, %ymm1
    vpcmpeqd    %ymm2, %ymm2, %ymm2
    vgatherqpd  %ymm2, (%rsi,%ymm1,8), %ymm3
    vmovupd %ymm3, 3104(%rsp)
    vorpd   2912(%rsp), %ymm8, %ymm2
    vpcmpeqd    %ymm3, %ymm3, %ymm3
    vgatherqpd  %ymm3, (%rsi,%ymm2,8), %ymm4
    vmovupd %ymm4, 3136(%rsp)
    vorpd   2816(%rsp), %ymm8, %ymm3
    vpcmpeqd    %ymm4, %ymm4, %ymm4
    vgatherqpd  %ymm4, (%rsi,%ymm3,8), %ymm5
    vmovupd %ymm5, 3808(%rsp)
    vorpd   2784(%rsp), %ymm8, %ymm9
    vpcmpeqd    %ymm4, %ymm4, %ymm4
    vgatherqpd  %ymm4, (%rsi,%ymm9,8), %ymm5
    vmovupd %ymm5, 3840(%rsp)
    vorpd   2752(%rsp), %ymm8, %ymm10
    vpcmpeqd    %ymm4, %ymm4, %ymm4
    vgatherqpd  %ymm4, (%rsi,%ymm10,8), %ymm5
    vmovupd %ymm5, 3872(%rsp)
    vpaddq  2944(%rsp), %ymm8, %ymm4
    vorpd   2720(%rsp), %ymm8, %ymm11
    vpcmpeqd    %ymm13, %ymm13, %ymm13
    vgatherqpd  %ymm13, (%rsi,%ymm11,8), %ymm5
    vmovupd %ymm5, 3904(%rsp)
    vpcmpeqd    %ymm13, %ymm13, %ymm13
    vgatherqpd  %ymm13, (%rdx,%ymm0,8), %ymm5
    vmovupd %ymm5, 3552(%rsp)
    vpcmpeqd    %ymm0, %ymm0, %ymm0
    vgatherqpd  %ymm0, (%rdx,%ymm1,8), %ymm5
    vmovupd %ymm5, 3616(%rsp)
    vpcmpeqd    %ymm0, %ymm0, %ymm0
    vgatherqpd  %ymm0, (%rdx,%ymm2,8), %ymm1
    vmovupd %ymm1, 3648(%rsp)
    vpcmpeqd    %ymm0, %ymm0, %ymm0
    vgatherqpd  %ymm0, (%rdx,%ymm3,8), %ymm1
    vmovupd %ymm1, 3680(%rsp)
    vpcmpeqd    %ymm0, %ymm0, %ymm0
    vgatherqpd  %ymm0, (%rdx,%ymm9,8), %ymm1
    vmovupd %ymm1, 3712(%rsp)
    vpcmpeqd    %ymm0, %ymm0, %ymm0
    vgatherqpd  %ymm0, (%rdx,%ymm10,8), %ymm1
    vmovupd %ymm1, 3744(%rsp)
    vpcmpeqd    %ymm0, %ymm0, %ymm0
    vgatherqpd  %ymm0, (%rdx,%ymm11,8), %ymm1
    vmovupd %ymm1, 3776(%rsp)
    vpcmpeqd    %ymm0, %ymm0, %ymm0
    vgatherqpd  %ymm0, (%rsi,%ymm4,8), %ymm6
    vpcmpeqd    %ymm0, %ymm0, %ymm0
    vgatherqpd  %ymm0, (%rdx,%ymm4,8), %ymm3
    vpaddq  2688(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm7
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3360(%rsp)
    vpaddq  2656(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm13
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3392(%rsp)
    vpaddq  2624(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm15
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3424(%rsp)
    vpaddq  2592(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm9
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3456(%rsp)
    vpaddq  2560(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm14
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3488(%rsp)
    vpaddq  2528(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm11
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3520(%rsp)
    vpaddq  2496(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm10
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3584(%rsp)
    vpaddq  2464(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm2
    vpaddq  2432(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm12
    vpaddq  2400(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3168(%rsp)
    vpaddq  2368(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3200(%rsp)
    vpaddq  2336(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3232(%rsp)
    vpaddq  2304(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3264(%rsp)
    vpaddq  2272(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3296(%rsp)
    vpaddq  2240(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd %ymm4, 3328(%rsp)
    vpaddq  2208(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vpaddq  2176(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm5
    vmovupd %ymm5, 2976(%rsp)
    vpaddq  2144(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm5
    vmovupd %ymm5, 3008(%rsp)
    vpaddq  2112(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm5
    vmovupd %ymm5, 3072(%rsp)
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rsi,%ymm8,8), %ymm0
    vpcmpeqd    %ymm5, %ymm5, %ymm5
    vgatherqpd  %ymm5, (%rdx,%ymm8,8), %ymm1
    vmovupd 768(%rsp), %ymm5
    vfmadd231pd %ymm0, %ymm1, %ymm5
    vmovupd %ymm5, 768(%rsp)
    vmovupd 32(%rsp), %ymm5
    vfmadd231pd %ymm0, %ymm3, %ymm5
    vmovupd %ymm5, 32(%rsp)
    vmovupd 1024(%rsp), %ymm5
    vfmadd231pd %ymm0, %ymm2, %ymm5
    vmovupd %ymm5, 1024(%rsp)
    vmovupd 1280(%rsp), %ymm5
    vfmadd231pd %ymm0, %ymm4, %ymm5
    vmovupd %ymm5, 1280(%rsp)
    vmovupd 1344(%rsp), %ymm0
    vfmadd231pd %ymm1, %ymm6, %ymm0
    vmovupd %ymm0, 1344(%rsp)
    vmovupd 480(%rsp), %ymm0
    vfmadd231pd %ymm3, %ymm6, %ymm0
    vmovupd %ymm0, 480(%rsp)
    vmovupd 1600(%rsp), %ymm0
    vfmadd231pd %ymm2, %ymm6, %ymm0
    vmovupd %ymm0, 1600(%rsp)
    vmovupd 1856(%rsp), %ymm0
    vfmadd231pd %ymm4, %ymm6, %ymm0
    vmovupd %ymm0, 1856(%rsp)
    vpaddq  2080(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm2
    vpaddq  2048(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
    vmovupd 800(%rsp), %ymm0
    vmovupd 3552(%rsp), %ymm1
    vmovupd 3040(%rsp), %ymm3
    vfmadd231pd %ymm3, %ymm1, %ymm0
    vmovupd %ymm0, 800(%rsp)
    vmovupd 64(%rsp), %ymm0
    vmovupd 3360(%rsp), %ymm5
    vfmadd231pd %ymm3, %ymm5, %ymm0
    vmovupd %ymm0, 64(%rsp)
    vmovupd 1056(%rsp), %ymm0
    vfmadd231pd %ymm3, %ymm12, %ymm0
    vmovupd %ymm0, 1056(%rsp)
    vmovupd 288(%rsp), %ymm0
    vmovupd 2976(%rsp), %ymm6
    vfmadd231pd %ymm3, %ymm6, %ymm0
    vmovupd %ymm0, 288(%rsp)
    vmovupd 1376(%rsp), %ymm0
    vfmadd231pd %ymm1, %ymm7, %ymm0
    vmovupd %ymm0, 1376(%rsp)
    vmovupd 512(%rsp), %ymm0
    vfmadd231pd %ymm5, %ymm7, %ymm0
    vmovupd %ymm0, 512(%rsp)
    vmovupd 1632(%rsp), %ymm0
    vfmadd231pd %ymm12, %ymm7, %ymm0
    vmovupd %ymm0, 1632(%rsp)
    vmovupd 1888(%rsp), %ymm0
    vfmadd231pd %ymm6, %ymm7, %ymm0
    vmovupd %ymm0, 1888(%rsp)
    vmovupd 832(%rsp), %ymm0
    vmovupd 3616(%rsp), %ymm1
    vmovupd 3104(%rsp), %ymm6
    vfmadd231pd %ymm6, %ymm1, %ymm0
    vmovupd %ymm0, 832(%rsp)
    vmovupd 96(%rsp), %ymm0
    vmovupd 3392(%rsp), %ymm3
    vfmadd231pd %ymm6, %ymm3, %ymm0
    vmovupd %ymm0, 96(%rsp)
    vmovupd 1088(%rsp), %ymm0
    vmovupd 3168(%rsp), %ymm5
    vfmadd231pd %ymm6, %ymm5, %ymm0
    vmovupd %ymm0, 1088(%rsp)
    vmovupd 320(%rsp), %ymm0
    vmovupd 3008(%rsp), %ymm7
    vfmadd231pd %ymm6, %ymm7, %ymm0
    vmovupd %ymm0, 320(%rsp)
    vmovupd 1408(%rsp), %ymm0
    vfmadd231pd %ymm1, %ymm13, %ymm0
    vmovupd %ymm0, 1408(%rsp)
    vmovupd 544(%rsp), %ymm0
    vfmadd231pd %ymm3, %ymm13, %ymm0
    vmovupd %ymm0, 544(%rsp)
    vmovupd 1664(%rsp), %ymm0
    vfmadd231pd %ymm5, %ymm13, %ymm0
    vmovupd %ymm0, 1664(%rsp)
    vmovupd 1920(%rsp), %ymm0
    vfmadd231pd %ymm7, %ymm13, %ymm0
    vmovupd %ymm0, 1920(%rsp)
    vpaddq  2016(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm3
    vmovupd 864(%rsp), %ymm0
    vmovupd 3648(%rsp), %ymm1
    vmovupd 3136(%rsp), %ymm6
    vfmadd231pd %ymm6, %ymm1, %ymm0
    vmovupd %ymm0, 864(%rsp)
    vmovupd 128(%rsp), %ymm0
    vmovupd 3424(%rsp), %ymm5
    vfmadd231pd %ymm6, %ymm5, %ymm0
    vmovupd %ymm0, 128(%rsp)
    vmovupd 1120(%rsp), %ymm0
    vmovupd 3200(%rsp), %ymm7
    vfmadd231pd %ymm6, %ymm7, %ymm0
    vmovupd %ymm0, 1120(%rsp)
    vmovupd 352(%rsp), %ymm0
    vmovupd 3072(%rsp), %ymm12
    vfmadd231pd %ymm6, %ymm12, %ymm0
    vmovupd %ymm0, 352(%rsp)
    vmovupd 1440(%rsp), %ymm0
    vfmadd231pd %ymm1, %ymm15, %ymm0
    vmovupd %ymm0, 1440(%rsp)
    vmovupd 576(%rsp), %ymm0
    vfmadd231pd %ymm5, %ymm15, %ymm0
    vmovupd %ymm0, 576(%rsp)
    vmovupd 1696(%rsp), %ymm0
    vfmadd231pd %ymm7, %ymm15, %ymm0
    vmovupd %ymm0, 1696(%rsp)
    vmovupd 736(%rsp), %ymm0
    vfmadd231pd %ymm12, %ymm15, %ymm0
    vmovupd %ymm0, 736(%rsp)
    vmovupd 896(%rsp), %ymm0
    vmovupd 3808(%rsp), %ymm1
    vmovupd 3680(%rsp), %ymm5
    vfmadd231pd %ymm1, %ymm5, %ymm0
    vmovupd %ymm0, 896(%rsp)
    vmovupd 160(%rsp), %ymm0
    vmovupd 3456(%rsp), %ymm6
    vfmadd231pd %ymm1, %ymm6, %ymm0
    vmovupd %ymm0, 160(%rsp)
    vmovupd 1152(%rsp), %ymm0
    vmovupd 3232(%rsp), %ymm7
    vfmadd231pd %ymm1, %ymm7, %ymm0
    vmovupd %ymm0, 1152(%rsp)
    vmovupd 384(%rsp), %ymm0
    vfmadd231pd %ymm1, %ymm2, %ymm0
    vmovupd %ymm0, 384(%rsp)
    vmovupd 1472(%rsp), %ymm0
    vfmadd231pd %ymm5, %ymm9, %ymm0
    vmovupd %ymm0, 1472(%rsp)
    vmovupd 608(%rsp), %ymm0
    vfmadd231pd %ymm6, %ymm9, %ymm0
    vmovupd %ymm0, 608(%rsp)
    vmovupd 1728(%rsp), %ymm0
    vfmadd231pd %ymm7, %ymm9, %ymm0
    vmovupd %ymm0, 1728(%rsp)
    vmovupd -128(%rsp), %ymm0
    vfmadd231pd %ymm2, %ymm9, %ymm0
    vmovupd %ymm0, -128(%rsp)
    vmovupd 928(%rsp), %ymm0
    vmovupd 3840(%rsp), %ymm1
    vmovupd 3712(%rsp), %ymm2
    vfmadd231pd %ymm1, %ymm2, %ymm0
    vmovupd %ymm0, 928(%rsp)
    vmovupd 192(%rsp), %ymm0
    vmovupd 3488(%rsp), %ymm5
    vfmadd231pd %ymm1, %ymm5, %ymm0
    vmovupd %ymm0, 192(%rsp)
    vmovupd 1184(%rsp), %ymm0
    vmovupd 3264(%rsp), %ymm6
    vfmadd231pd %ymm1, %ymm6, %ymm0
    vmovupd %ymm0, 1184(%rsp)
    vmovupd 416(%rsp), %ymm0
    vfmadd231pd %ymm1, %ymm4, %ymm0
    vmovupd %ymm0, 416(%rsp)
    vmovupd 1504(%rsp), %ymm0
    vfmadd231pd %ymm2, %ymm14, %ymm0
    vmovupd %ymm0, 1504(%rsp)
    vmovupd 640(%rsp), %ymm0
    vfmadd231pd %ymm5, %ymm14, %ymm0
    vmovupd %ymm0, 640(%rsp)
    vmovupd 1760(%rsp), %ymm0
    vfmadd231pd %ymm6, %ymm14, %ymm0
    vmovupd %ymm0, 1760(%rsp)
    vmovupd -96(%rsp), %ymm0
    vfmadd231pd %ymm4, %ymm14, %ymm0
    vmovupd %ymm0, -96(%rsp)
    vpaddq  1984(%rsp), %ymm8, %ymm0
    vpcmpeqd    %ymm1, %ymm1, %ymm1
    vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm2
    vmovupd 960(%rsp), %ymm0
    vmovupd 3872(%rsp), %ymm1
    vmovupd 3744(%rsp), %ymm4
    vfmadd231pd %ymm1, %ymm4, %ymm0
    vmovupd %ymm0, 960(%rsp)
    vmovupd 224(%rsp), %ymm0
    vmovupd 3520(%rsp), %ymm5
    vfmadd231pd %ymm1, %ymm5, %ymm0
    vmovupd %ymm0, 224(%rsp)
    vmovupd 1216(%rsp), %ymm0
    vmovupd 3296(%rsp), %ymm6
    vfmadd231pd %ymm1, %ymm6, %ymm0
    vmovupd %ymm0, 1216(%rsp)
    vmovupd 448(%rsp), %ymm0
    vfmadd231pd %ymm1, %ymm3, %ymm0
    vmovupd %ymm0, 448(%rsp)
    vmovupd 1536(%rsp), %ymm0
    vfmadd231pd %ymm4, %ymm11, %ymm0
    vmovupd %ymm0, 1536(%rsp)
    vmovupd 672(%rsp), %ymm0
    vfmadd231pd %ymm5, %ymm11, %ymm0
    vmovupd %ymm0, 672(%rsp)
    vmovupd 1792(%rsp), %ymm0
    vfmadd231pd %ymm6, %ymm11, %ymm0
    vmovupd %ymm0, 1792(%rsp)
    vmovupd -64(%rsp), %ymm0
    vfmadd231pd %ymm3, %ymm11, %ymm0
    vmovupd %ymm0, -64(%rsp)
    vmovupd 992(%rsp), %ymm0
    vmovupd 3904(%rsp), %ymm1
    vmovupd 3776(%rsp), %ymm3
    vfmadd231pd %ymm1, %ymm3, %ymm0
    vmovupd %ymm0, 992(%rsp)
    vmovupd 256(%rsp), %ymm0
    vmovupd 3584(%rsp), %ymm4
    vfmadd231pd %ymm1, %ymm4, %ymm0
    vmovupd %ymm0, 256(%rsp)
    vmovupd 1248(%rsp), %ymm0
    vmovupd 3328(%rsp), %ymm5
    vfmadd231pd %ymm1, %ymm5, %ymm0
    vmovupd %ymm0, 1248(%rsp)
    vmovupd 1312(%rsp), %ymm0
    vfmadd231pd %ymm1, %ymm2, %ymm0
    vmovupd %ymm0, 1312(%rsp)
    vmovupd 1568(%rsp), %ymm0
    vfmadd231pd %ymm3, %ymm10, %ymm0
    vmovupd %ymm0, 1568(%rsp)
    vmovupd 704(%rsp), %ymm0
    vfmadd231pd %ymm4, %ymm10, %ymm0
    vmovupd %ymm0, 704(%rsp)
    vmovupd 1824(%rsp), %ymm0
    vfmadd231pd %ymm5, %ymm10, %ymm0
    vmovupd %ymm0, 1824(%rsp)
    vmovupd -32(%rsp), %ymm0
    vfmadd231pd %ymm2, %ymm10, %ymm0
    vmovupd %ymm0, -32(%rsp)
    vpaddq  1952(%rsp), %ymm8, %ymm8
    addq    $-4, %rcx
    jne .LBB0_5

The loop is huge and is clearly not vectorized properly: there are a lot of completely useless instructions, the accumulators are kept on the stack (note the vmovupd loads/stores around each vfmadd231pd), and loads from memory do not appear to be contiguous (see vgatherqpd). Numba does not generate good code since the underlying JIT (LLVM, used through llvmlite) fails to vectorize the code efficiently. In fact, I found out that similar C++ code is badly vectorized by Clang 13.0 on a simplified example (GCC and ICC also fail on more complex code), while a hand-written SIMD implementation works much better. It looks like a bug in the optimizer, or at least a missed optimization. This is why the Numba code is much slower than the optimal code. That being said, this implementation makes quite efficient use of the cache and is properly multithreaded.

I also found out that the BLAS code is faster on Linux than on Windows on my machine (with the default packages from PIP and the same NumPy version, 1.20.3). Thus, the gap between method2 and method3 is smaller on Linux, but the latter is still significantly faster.

Answered By – Jérôme Richard

Answer Checked By – Clifford M. (BugsFixing Volunteer)
