# [SOLVED] numpy einsum/tensordot with shared non-contracted axis

## Issue

Suppose I have two arrays:

``````python
import numpy as np
a = np.random.randn(32, 6, 6, 20, 64, 3, 3)
b = np.random.randn(20, 128, 64, 3, 3)
``````

and I want to sum over the last 3 axes of both arrays while keeping the shared axis. The output shape should be `(32,6,6,20,128)`. Notice that the axis of size 20 is shared by both `a` and `b`. Let’s call this axis the "group" axis.

I have two methods for this task:
The first one is just a simple `einsum`:

``````python
import numpy as np

def method1(a, b):
    # Contract i, h, w; keep N, H, W, the group axis g and o
    return np.einsum('NHWgihw, goihw -> NHWgo', a, b, optimize=True)  # output shape: (32,6,6,20,128)
``````

In the second method, I loop over the group axis and use `einsum`/`tensordot` to compute the result for each group, then stack the results:

``````python
def method2(a, b):
    result = []
    for g in range(b.shape[0]):  # loop through each group
        # result.append(np.tensordot(a[..., g, :, :, :], b[g, ...], axes=((-3,-2,-1),(-3,-2,-1))))
        result.append(np.einsum('NHWihw, oihw -> NHWo', a[..., g, :, :, :], b[g, ...], optimize=True))  # output shape: (32,6,6,128)
    return np.stack(result, axis=-2)  # output shape: (32,6,6,20,128)
``````

I timed both methods in my Jupyter notebook, and the second method, with a loop, is faster than the first one.
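One way to reproduce such a comparison outside Jupyter is to check that both functions return the same result and then time them with the standard `timeit` module. This is a sketch. `benchmark` is a hypothetical helper, and the small shapes below are arbitrary. With the full shapes from above, `method1` can take a long time per call.

``````python
import timeit

import numpy as np


def method1(a, b):
    return np.einsum('NHWgihw, goihw -> NHWgo', a, b, optimize=True)


def method2(a, b):
    result = []
    for g in range(b.shape[0]):
        result.append(np.einsum('NHWihw, oihw -> NHWo', a[..., g, :, :, :], b[g, ...], optimize=True))
    return np.stack(result, axis=-2)


def benchmark(funcs, a, b, repeat=3):
    """Hypothetical helper: check that all funcs agree, then time them.

    Returns a dict mapping each function name to its best time in seconds.
    """
    reference = funcs[0](a, b)
    for f in funcs[1:]:
        assert np.allclose(f(a, b), reference), f.__name__
    timings = {}
    for f in funcs:
        times = timeit.repeat(lambda: f(a, b), number=1, repeat=repeat)
        timings[f.__name__] = min(times)
    return timings


rng = np.random.default_rng(0)
a = rng.standard_normal((4, 3, 3, 5, 8, 3, 3))
b = rng.standard_normal((5, 16, 8, 3, 3))
for name, t in benchmark([method1, method2], a, b).items():
    print(f"{name}: {t * 1e3:.2f} ms")
``````

Taking the minimum over several runs reduces the noise from other processes. It does not remove it entirely.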

My question is:

1. Why is method1 so much slower? It doesn’t compute anything more.
2. Is there a more efficient way without using loops? (I’m a bit reluctant to use loops because they are slow in Python.)

Thanks for any help!

## Solution

As pointed out by @Murali in the comments, `method1` is not very efficient because it does not manage to use BLAS calls, unlike `method2`, which does. In fact, `np.einsum` performs reasonably well in `method1` given that it computes the result sequentially, while `method2` mostly runs in parallel thanks to OpenBLAS (used by Numpy on most machines). That being said, `method2` is sub-optimal: it does not fully use the available cores (parts of the computation are done sequentially) and it does not appear to use the cache efficiently. On my 6-core machine, it barely uses 50% of all the cores.
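For each group, the contraction is a plain matrix product once the `N,H,W` axes and the `i,h,w` axes are merged. The whole computation can therefore also be written as one batched `np.matmul` over the group axis, with no Python loop.

This sketch is not part of the answer above. `method_matmul` is a hypothetical name. Whether NumPy hands each batch to a multithreaded BLAS `gemm` depends on the NumPy build and on the memory layout, and the transposes may create temporary copies. Treat it as something to time, not as a guaranteed speed-up.

``````python
import numpy as np


def method_matmul(a, b):
    """Hypothetical loop-free variant: one batched matrix product over the group axis."""
    sN, sH, sW, sg, si, sh, sw = a.shape
    so = b.shape[1]
    assert b.shape == (sg, so, si, sh, sw)
    K = si * sh * sw
    # Merge the NHW and ihw axes: ra is (NHW, g, K), rb is (g, o, K)
    ra = a.reshape(sN * sH * sW, sg, K)
    rb = b.reshape(sg, so, K)
    # Batch over g: (g, NHW, K) @ (g, K, o) -> (g, NHW, o)
    out = np.matmul(ra.transpose(1, 0, 2), rb.transpose(0, 2, 1))
    # Back to (NHW, g, o), then split NHW into N, H, W again
    return out.transpose(1, 0, 2).reshape(sN, sH, sW, sg, so)


rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3, 3, 4, 5, 3, 3))
b = rng.standard_normal((4, 7, 5, 3, 3))
ref = np.einsum('NHWgihw,goihw->NHWgo', a, b)
print(np.allclose(method_matmul(a, b), ref))  # True
``````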

## Faster implementation

One way to speed up this computation is to write highly optimized parallel Numba code.

A first, semi-naive implementation uses plain for loops to compute the Einstein summation. It also reshapes the input/output arrays so that Numba can better optimize the code (e.g. loop unrolling, use of SIMD instructions). Here is the result:

``````python
import numba as nb
import numpy as np

@nb.njit('float64[:,:,:,:,::1](float64[:,:,:,:,:,:,::1], float64[:,:,:,:,::1])')
def compute(a, b):
    sN, sH, sW, sg, si, sh, sw = a.shape
    so = b.shape[1]
    assert b.shape == (sg, so, si, sh, sw)

    # Merge N,H,W and i,h,w so the kernel works on 3D arrays
    ra = a.reshape(sN*sH*sW, sg, si*sh*sw)
    rb = b.reshape(sg, so, si*sh*sw)
    out = np.empty((sN*sH*sW, sg, so), dtype=np.float64)

    for NHW in range(sN*sH*sW):
        for g in range(sg):
            for o in range(so):
                s = 0.0

                # Reduction
                for ihw in range(si*sh*sw):
                    s += ra[NHW, g, ihw] * rb[g, o, ihw]

                out[NHW, g, o] = s

    return out.reshape((sN, sH, sW, sg, so))
``````

Note that the input arrays are assumed to be C-contiguous (the `::1` in the Numba signature requires it). If this is not the case, consider making a copy first, which is cheap compared to the computation.
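The reshape trick used above relies on contiguity. For a C-contiguous array, merging the `N,H,W` axes and the `i,h,w` axes is a free view. A non-contiguous input has to be copied first, for example with `np.ascontiguousarray`.

The NumPy-only sketch below (no Numba needed) checks both facts. It also checks that the contraction on the merged 3D arrays matches the original 7D `einsum`. The shapes are arbitrary small examples.

``````python
import numpy as np

rng = np.random.default_rng(0)
sN, sH, sW, sg, si, sh, sw, so = 2, 3, 3, 4, 5, 3, 3, 6
a = rng.standard_normal((sN, sH, sW, sg, si, sh, sw))
b = rng.standard_normal((sg, so, si, sh, sw))

# For C-contiguous inputs, merging adjacent axes is a zero-copy view
ra = a.reshape(sN * sH * sW, sg, si * sh * sw)
rb = b.reshape(sg, so, si * sh * sw)
print(np.shares_memory(ra, a), np.shares_memory(rb, b))  # True True

# The 7D contraction is the same as a 3D one over the merged axes
out3 = np.einsum('ngk,gok->ngo', ra, rb).reshape(sN, sH, sW, sg, so)
out7 = np.einsum('NHWgihw,goihw->NHWgo', a, b)
print(np.allclose(out3, out7))  # True

# A non-contiguous view (here H and W swapped) must be copied first
at = a.transpose(0, 2, 1, 3, 4, 5, 6)
print(at.flags['C_CONTIGUOUS'])  # False
at_c = np.ascontiguousarray(at)
print(at_c.flags['C_CONTIGUOUS'], np.shares_memory(at_c, at))  # True False
``````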

While the above code works, it is far from efficient. Here are some improvements that can be made:

• run the outermost `NHW` loop in parallel;
• use the Numba flag `fastmath=True`. This flag is unsafe if the input data contains special values like NaN or +inf/-inf. However, it helps the compiler generate much faster code using SIMD instructions (this is not possible otherwise since IEEE-754 floating-point operations are not associative);
• swap the `NHW`-based loop and the `g`-based loop, which improves performance through better cache locality (`rb[g]`, the part of `rb` reused by the inner loops, is more likely to fit in the last-level cache of mainstream CPUs, whereas otherwise it would likely be fetched from RAM);
• use register blocking so as to better saturate the SIMD units of the processor and reduce the pressure on the memory hierarchy;
• use tiling by splitting the `o`-based loop so that `rb` can almost fully be read from lower-level caches (e.g. L1 or L2).
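The associativity point in the `fastmath` bullet can be checked directly in Python. Regrouping the same additions changes the rounded result. Without permission to reorder, a compiler therefore has to keep the sequential order of the reduction and cannot spread it over several SIMD lanes.

The lane-based sum below only mimics what a vectorized reduction does: 4 interleaved partial sums, like the 4 doubles of an AVX register. It is an illustration, not what Numba literally emits.

``````python
import numpy as np

# Regrouping three additions changes the rounded result
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)
print(left, right, left == right)  # 0.6000000000000001 0.6 False

# A sequential reduction, as required by strict IEEE-754 semantics
x = np.random.default_rng(0).standard_normal(10_000)
sequential = 0.0
for v in x:
    sequential += v

# The same reduction split into 4 interleaved partial sums, mimicking
# what a vectorized loop does once reordering is allowed
lanes = [0.0, 0.0, 0.0, 0.0]
for i, v in enumerate(x):
    lanes[i % 4] += v
vectorized = (lanes[0] + lanes[1]) + (lanes[2] + lanes[3])

# Usually not bit-identical, but equal up to rounding error
print(abs(sequential - vectorized))
``````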

All these improvements except the last one are implemented in the following code:

``````python
import numba as nb
import numpy as np

@nb.njit('float64[:,:,:,:,::1](float64[:,:,:,:,:,:,::1], float64[:,:,:,:,::1])', parallel=True, fastmath=True)
def method3(a, b):
    sN, sH, sW, sg, si, sh, sw = a.shape
    so = b.shape[1]
    assert b.shape == (sg, so, si, sh, sw)

    ra = a.reshape(sN*sH*sW, sg, si*sh*sw)
    rb = b.reshape(sg, so, si*sh*sw)
    out = np.zeros((sN*sH*sW, sg, so), dtype=np.float64)

    for g in range(sg):
        for k in nb.prange((sN*sH*sW)//2):
            NHW = k*2
            so_vect_max = (so // 4) * 4

            for o in range(0, so_vect_max, 4):
                s00 = s01 = s02 = s03 = s10 = s11 = s12 = s13 = 0.0

                # Useful since Numba does not optimize well the following loop otherwise
                ra_row0 = ra[NHW+0, g, :]
                ra_row1 = ra[NHW+1, g, :]
                rb_row0 = rb[g, o+0, :]
                rb_row1 = rb[g, o+1, :]
                rb_row2 = rb[g, o+2, :]
                rb_row3 = rb[g, o+3, :]

                # Highly-optimized reduction using register blocking
                for ihw in range(si*sh*sw):
                    ra_0 = ra_row0[ihw]
                    ra_1 = ra_row1[ihw]
                    rb_0 = rb_row0[ihw]
                    rb_1 = rb_row1[ihw]
                    rb_2 = rb_row2[ihw]
                    rb_3 = rb_row3[ihw]
                    s00 += ra_0 * rb_0; s01 += ra_0 * rb_1
                    s02 += ra_0 * rb_2; s03 += ra_0 * rb_3
                    s10 += ra_1 * rb_0; s11 += ra_1 * rb_1
                    s12 += ra_1 * rb_2; s13 += ra_1 * rb_3

                out[NHW+0, g, o+0] = s00; out[NHW+0, g, o+1] = s01
                out[NHW+0, g, o+2] = s02; out[NHW+0, g, o+3] = s03
                out[NHW+1, g, o+0] = s10; out[NHW+1, g, o+1] = s11
                out[NHW+1, g, o+2] = s12; out[NHW+1, g, o+3] = s13

            # Remaining part for `o`
            for o in range(so_vect_max, so):
                for ihw in range(si*sh*sw):
                    out[NHW, g, o] += ra[NHW, g, ihw] * rb[g, o, ihw]
                    out[NHW+1, g, o] += ra[NHW+1, g, ihw] * rb[g, o, ihw]

        # Remaining part for `k`
        if (sN*sH*sW) % 2 == 1:
            k = sN*sH*sW - 1
            for o in range(so):
                for ihw in range(si*sh*sw):
                    out[k, g, o] += ra[k, g, ihw] * rb[g, o, ihw]

    return out.reshape((sN, sH, sW, sg, so))
``````

This code is much more complex and uglier, but also far more efficient. I did not implement the tiling optimization since it would make the code even less readable. However, it should result in significantly faster code on many-core processors (especially those with a small L2/L3 cache).
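The loop structure of the tiling idea can be sketched at NumPy level. For each group, the `o` axis is processed in blocks of `tile_o` rows of `rb`, so only a small slice of `rb` is reused across the whole `NHW` range.

This is an illustration under assumptions:
• `tiled_contraction` and `tile_o` are hypothetical names, and 32 is an arbitrary default, not a tuned size;
• the product for each tile is delegated to `@` (BLAS), which already does its own blocking;
• in a Numba kernel, the tile body would instead be the register-blocked loop of `method3`.

``````python
import numpy as np


def tiled_contraction(a, b, tile_o=32):
    """Hypothetical sketch: same result as method1, with the `o` loop tiled."""
    sN, sH, sW, sg, si, sh, sw = a.shape
    so = b.shape[1]
    assert b.shape == (sg, so, si, sh, sw)
    K = si * sh * sw
    ra = np.ascontiguousarray(a).reshape(sN * sH * sW, sg, K)
    rb = np.ascontiguousarray(b).reshape(sg, so, K)
    out = np.empty((sN * sH * sW, sg, so), dtype=np.result_type(a, b))

    for g in range(sg):
        a_g = ra[:, g, :]  # (NHW, K)
        for o0 in range(0, so, tile_o):
            # Small block of rb (at most tile_o rows of length K), reused for
            # every NHW row, so it has a chance to stay in L1/L2 meanwhile
            b_tile = rb[g, o0:o0 + tile_o, :]
            out[:, g, o0:o0 + tile_o] = a_g @ b_tile.T
    return out.reshape(sN, sH, sW, sg, so)


rng = np.random.default_rng(1)
a = rng.standard_normal((2, 3, 3, 4, 5, 3, 3))
b = rng.standard_normal((4, 13, 5, 3, 3))
ref = np.einsum('NHWgihw,goihw->NHWgo', a, b)
print(np.allclose(tiled_contraction(a, b, tile_o=4), ref))  # True
``````

Here `so = 13` is deliberately not a multiple of `tile_o`, so the last, shorter tile is exercised too.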

## Performance results

Here are performance results on my i5-9600KF 6-core processor:

``````
method1:              816 ms
method2:              104 ms
method3:               40 ms
Theoretical optimal:    9 ms   (optimistic lower bound)
``````

`method3` is about 2.6 times faster than `method2`. There is still room for improvement, since the theoretical optimal time is about 4 times lower than that of `method3`.
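The answer does not say how the theoretical optimum was computed. One common way to get such a bound is to count the multiply-adds and divide by the peak FMA throughput. The sketch below rests on these assumptions:
• 6 cores;
• 2 FMA units per core;
• 4 doubles per 256-bit AVX2 register;
• a sustained clock of 4.3 GHz (a guess, not a measurement).

``````python
# Output size and reduction length, from the shapes in the question
sN, sH, sW, sg, so = 32, 6, 6, 20, 128
si, sh, sw = 64, 3, 3

outputs = sN * sH * sW * sg * so  # number of output elements
reduction = si * sh * sw          # multiply-adds per output element
fmas = outputs * reduction
print(f"{fmas:,} FMAs ({2 * fmas / 1e9:.2f} GFLOP)")

# Assumed hardware parameters (not measured)
cores = 6
fma_units_per_core = 2
doubles_per_register = 4   # 256-bit AVX2 register / 64-bit double
clock_hz = 4.3e9           # assumed sustained all-core frequency

peak_fmas_per_s = cores * fma_units_per_core * doubles_per_register * clock_hz
print(f"lower bound: {fmas / peak_fmas_per_s * 1e3:.1f} ms")
``````

With these numbers, the bound is about 8.2 ms. An assumed clock of 3.7 GHz gives about 9.6 ms, so the estimate is quite sensitive to the frequency the cores actually sustain under AVX load.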

The main reason Numba does not generate fast code is that the underlying JIT fails to vectorize the loop efficiently. Implementing the tiling strategy should improve the execution time somewhat, bringing it closer to the optimal one. Tiling becomes critical for much bigger arrays, especially if `so` is much bigger.

For a faster implementation, you will most likely need to write native C/C++ code that uses SIMD intrinsics directly (which are unfortunately not portable) or a SIMD library (e.g. XSIMD).

For an even faster implementation, you need faster hardware (with more cores) or more dedicated hardware. Server-grade GPUs should be able to speed up such a computation a lot, since your input is small, clearly compute-bound and makes massive use of floating-point FMA operations. This excludes the GPUs of personal computers, which usually have a much lower float64 throughput. A good first step is to try `cupy.einsum`.
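Since `cupy.einsum` mirrors `numpy.einsum`, the contraction can be written once and run on whichever module is available. This sketch assumes CuPy is installed with a working CUDA device. It falls back to NumPy when CuPy cannot be imported, and `xp` and `grouped_contraction` are illustrative names. As written, the host-to-device transfers are part of each call. For a fair timing, the inputs should already live on the GPU.

``````python
import numpy as np

try:
    import cupy as xp  # GPU path (requires CuPy and a CUDA device)
except ImportError:
    xp = np  # CPU fallback so the sketch still runs


def grouped_contraction(a, b):
    """Same contraction as method1, on whichever array module `xp` is."""
    a_dev = xp.asarray(a)
    b_dev = xp.asarray(b)
    out = xp.einsum('nhwgixy,goixy->nhwgo', a_dev, b_dev)
    # Bring the result back to host memory as a NumPy array
    return out.get() if xp is not np else out


rng = np.random.default_rng(2)
a = rng.standard_normal((2, 2, 2, 3, 4, 3, 3))
b = rng.standard_normal((3, 5, 4, 3, 3))
res = grouped_contraction(a, b)
print(xp.__name__, res.shape)
``````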

## Under the hood: low-level analysis

In order to understand why `method1` is not faster, I checked the executed code. Here is the main loop:

``````
1a0:┌─→; Part of the reduction (see below)
│  movapd     xmm0,XMMWORD PTR [rdi-0x1000]
│
│  ; Decrement the number of loop cycle
│  sub        r9,0x8
│
│  ; Prefetch items so to reduce the impact
│  ; of the latency of reading from the RAM.
│  prefetcht0 BYTE PTR [r8]
│  prefetcht0 BYTE PTR [rdi]
│
│  ; Part of the reduction (see below)
│  mulpd      xmm0,XMMWORD PTR [r8-0x1000]
│
│  ; Increment iterator for the two arrays
│
│  ; Main computational part:
│  ; reduction using add+mul SSE2 instructions
│  addpd      xmm1,xmm0                     <--- Slow
│  movapd     xmm0,XMMWORD PTR [rdi-0x1030]
│  mulpd      xmm0,XMMWORD PTR [r8-0x1030]
│  addpd      xmm1,xmm0                     <--- Slow
│  movapd     xmm0,XMMWORD PTR [rdi-0x1020]
│  mulpd      xmm0,XMMWORD PTR [r8-0x1020]
│  addpd      xmm0,xmm1                     <--- Slow
│  movapd     xmm1,XMMWORD PTR [rdi-0x1010]
│  mulpd      xmm1,XMMWORD PTR [r8-0x1010]
│  addpd      xmm1,xmm0                     <--- Slow
│
│  ; Is the loop over?
│  ; If not, jump to the beginning of the loop.
├──cmp        r9,0x7
└──jg         1a0
``````

It turns out that Numpy uses the SSE2 instruction set (which is available on all x86-64 processors). However, my machine, like almost all relatively recent processors, supports the AVX instruction set, which can compute twice as many items per instruction. My machine also supports fused multiply-add (FMA) instructions, which are twice as fast in this case. Moreover, the loop is clearly bound by the `addpd` instructions, which accumulate the result in mostly the same register. The processor cannot execute them efficiently. An `addpd` has a latency of a few cycles, and while up to two can be executed at the same time on modern x86-64 processors, that is not possible here because only one instruction at a time can perform the accumulation in `xmm1`.
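Whether a given machine supports these instruction sets can be checked from its CPU flags. This is a sketch for x86 Linux only. It reads `/proc/cpuinfo` and returns an empty set elsewhere, including on ARM, where the kernel lists "Features" instead of "flags". `cpu_flags` is a hypothetical helper name. It only reports what the CPU supports, not which code path NumPy or OpenBLAS actually runs.

``````python
from pathlib import Path


def cpu_flags(path="/proc/cpuinfo"):
    """Return the set of CPU feature flags listed by the Linux kernel."""
    try:
        text = Path(path).read_text()
    except OSError:
        return set()  # not Linux, or /proc unavailable
    for line in text.splitlines():
        # x86 kernels list features on lines starting with "flags"
        if line.startswith("flags"):
            return set(line.split(":", 1)[1].split())
    return set()


flags = cpu_flags()
for feature in ("sse2", "avx", "avx2", "fma"):
    print(f"{feature:5s}: {'yes' if feature in flags else 'no'}")
``````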

Here is the executed code of the main computational part of `method2` (`dgemm` call of OpenBLAS):

``````
6a40:┌─→vbroadcastsd ymm0,QWORD PTR [rsi-0x60]
│  vbroadcastsd ymm1,QWORD PTR [rsi-0x58]
│  vbroadcastsd ymm2,QWORD PTR [rsi-0x50]
│  vbroadcastsd ymm3,QWORD PTR [rsi-0x48]
│  vfmadd231pd  ymm4,ymm0,YMMWORD PTR [rdi-0x80]
│  vfmadd231pd  ymm5,ymm1,YMMWORD PTR [rdi-0x60]
│  vbroadcastsd ymm0,QWORD PTR [rsi-0x40]
│  vbroadcastsd ymm1,QWORD PTR [rsi-0x38]
│  vfmadd231pd  ymm6,ymm2,YMMWORD PTR [rdi-0x40]
│  vfmadd231pd  ymm7,ymm3,YMMWORD PTR [rdi-0x20]
│  vbroadcastsd ymm2,QWORD PTR [rsi-0x30]
│  vbroadcastsd ymm3,QWORD PTR [rsi-0x28]
│  vfmadd231pd  ymm4,ymm0,YMMWORD PTR [rdi]
│  vfmadd231pd  ymm5,ymm1,YMMWORD PTR [rdi+0x20]
│  vfmadd231pd  ymm6,ymm2,YMMWORD PTR [rdi+0x40]
│  vfmadd231pd  ymm7,ymm3,YMMWORD PTR [rdi+0x60]
├──dec          rax
└──jne          6a40
``````

This loop is far more optimized: it makes use of the AVX instruction set as well as the FMA one (i.e. `vfmadd231pd` instructions). Furthermore, the loop is better unrolled and there is no latency/dependency issue like in the Numpy code. However, while this loop is highly efficient, the cores are not used efficiently due to some sequential checks done in Numpy and a sequential copy performed in OpenBLAS. Moreover, I am not sure the loop makes efficient use of the cache in this case, since a lot of reads/writes go to RAM on my machine. Indeed, the RAM throughput is about 15 GiB/s (out of the 35~40 GiB/s available) due to many cache misses. In comparison, the RAM throughput of `method3` is 6 GiB/s (so more work is done in the cache), with a significantly faster execution.
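To put these throughput figures in perspective, the arithmetic below compares the size of the arrays with the RAM traffic implied by the reported numbers. It assumes float64 everywhere and assumes the quoted throughputs (15 GiB/s for `method2`, 6 GiB/s for `method3`) were sustained over the whole timed run (104 ms and 40 ms). The results are rough orders of magnitude, not measurements.

``````python
GiB = 1024 ** 3
MiB = 1024 ** 2
itemsize = 8  # float64

shapes = {
    "a": (32, 6, 6, 20, 64, 3, 3),
    "b": (20, 128, 64, 3, 3),
    "out": (32, 6, 6, 20, 128),
}


def nbytes(shape):
    n = 1
    for dim in shape:
        n *= dim
    return n * itemsize


sizes = {name: nbytes(shape) for name, shape in shapes.items()}
for name, size in sizes.items():
    print(f"{name:>4}: {size / MiB:7.2f} MiB")

# Per-group slice of b, i.e. rb[g] in the Numba code (o * i * h * w doubles)
print(f"b[g]: {nbytes(shapes['b'][1:]) / 1024:7.0f} KiB")

# Minimum traffic: read a and b once, write out once (ignoring write-allocate)
minimum = sum(sizes.values()) / GiB
# (throughput in GiB/s, time in s), taken from the measurements above
for name, (throughput, seconds) in {"method2": (15, 0.104), "method3": (6, 0.040)}.items():
    moved = throughput * seconds
    print(f"{name}: ~{moved:.2f} GiB moved, ~{moved / minimum:.0f}x the minimum")
``````

The minimum traffic is about 135 MiB in total. Under these assumptions, `method2` moves roughly 12 times that amount and `method3` roughly 2 times, which is consistent with the cache-miss explanation. A single `rb[g]` slice is 576 KiB, small enough to stay in a shared last-level cache, as the loop-swap bullet assumes.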

Here is the executed code of the main computational part of `method3`:

``````
.LBB0_5:
vorpd   2880(%rsp), %ymm8, %ymm0
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm2
vmovupd %ymm2, 3040(%rsp)
vorpd   2848(%rsp), %ymm8, %ymm1
vpcmpeqd    %ymm2, %ymm2, %ymm2
vgatherqpd  %ymm2, (%rsi,%ymm1,8), %ymm3
vmovupd %ymm3, 3104(%rsp)
vorpd   2912(%rsp), %ymm8, %ymm2
vpcmpeqd    %ymm3, %ymm3, %ymm3
vgatherqpd  %ymm3, (%rsi,%ymm2,8), %ymm4
vmovupd %ymm4, 3136(%rsp)
vorpd   2816(%rsp), %ymm8, %ymm3
vpcmpeqd    %ymm4, %ymm4, %ymm4
vgatherqpd  %ymm4, (%rsi,%ymm3,8), %ymm5
vmovupd %ymm5, 3808(%rsp)
vorpd   2784(%rsp), %ymm8, %ymm9
vpcmpeqd    %ymm4, %ymm4, %ymm4
vgatherqpd  %ymm4, (%rsi,%ymm9,8), %ymm5
vmovupd %ymm5, 3840(%rsp)
vorpd   2752(%rsp), %ymm8, %ymm10
vpcmpeqd    %ymm4, %ymm4, %ymm4
vgatherqpd  %ymm4, (%rsi,%ymm10,8), %ymm5
vmovupd %ymm5, 3872(%rsp)
vpaddq  2944(%rsp), %ymm8, %ymm4
vorpd   2720(%rsp), %ymm8, %ymm11
vpcmpeqd    %ymm13, %ymm13, %ymm13
vgatherqpd  %ymm13, (%rsi,%ymm11,8), %ymm5
vmovupd %ymm5, 3904(%rsp)
vpcmpeqd    %ymm13, %ymm13, %ymm13
vgatherqpd  %ymm13, (%rdx,%ymm0,8), %ymm5
vmovupd %ymm5, 3552(%rsp)
vpcmpeqd    %ymm0, %ymm0, %ymm0
vgatherqpd  %ymm0, (%rdx,%ymm1,8), %ymm5
vmovupd %ymm5, 3616(%rsp)
vpcmpeqd    %ymm0, %ymm0, %ymm0
vgatherqpd  %ymm0, (%rdx,%ymm2,8), %ymm1
vmovupd %ymm1, 3648(%rsp)
vpcmpeqd    %ymm0, %ymm0, %ymm0
vgatherqpd  %ymm0, (%rdx,%ymm3,8), %ymm1
vmovupd %ymm1, 3680(%rsp)
vpcmpeqd    %ymm0, %ymm0, %ymm0
vgatherqpd  %ymm0, (%rdx,%ymm9,8), %ymm1
vmovupd %ymm1, 3712(%rsp)
vpcmpeqd    %ymm0, %ymm0, %ymm0
vgatherqpd  %ymm0, (%rdx,%ymm10,8), %ymm1
vmovupd %ymm1, 3744(%rsp)
vpcmpeqd    %ymm0, %ymm0, %ymm0
vgatherqpd  %ymm0, (%rdx,%ymm11,8), %ymm1
vmovupd %ymm1, 3776(%rsp)
vpcmpeqd    %ymm0, %ymm0, %ymm0
vgatherqpd  %ymm0, (%rsi,%ymm4,8), %ymm6
vpcmpeqd    %ymm0, %ymm0, %ymm0
vgatherqpd  %ymm0, (%rdx,%ymm4,8), %ymm3
vpaddq  2688(%rsp), %ymm8, %ymm0
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm7
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3360(%rsp)
vpaddq  2656(%rsp), %ymm8, %ymm0
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm13
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3392(%rsp)
vpaddq  2624(%rsp), %ymm8, %ymm0
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm15
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3424(%rsp)
vpaddq  2592(%rsp), %ymm8, %ymm0
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm9
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3456(%rsp)
vpaddq  2560(%rsp), %ymm8, %ymm0
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm14
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3488(%rsp)
vpaddq  2528(%rsp), %ymm8, %ymm0
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm11
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3520(%rsp)
vpaddq  2496(%rsp), %ymm8, %ymm0
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rsi,%ymm0,8), %ymm10
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3584(%rsp)
vpaddq  2464(%rsp), %ymm8, %ymm0
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm2
vpaddq  2432(%rsp), %ymm8, %ymm0
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm12
vpaddq  2400(%rsp), %ymm8, %ymm0
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3168(%rsp)
vpaddq  2368(%rsp), %ymm8, %ymm0
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3200(%rsp)
vpaddq  2336(%rsp), %ymm8, %ymm0
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3232(%rsp)
vpaddq  2304(%rsp), %ymm8, %ymm0
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3264(%rsp)
vpaddq  2272(%rsp), %ymm8, %ymm0
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3296(%rsp)
vpaddq  2240(%rsp), %ymm8, %ymm0
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3328(%rsp)
vpaddq  2208(%rsp), %ymm8, %ymm0
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
vpaddq  2176(%rsp), %ymm8, %ymm0
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm5
vmovupd %ymm5, 2976(%rsp)
vpaddq  2144(%rsp), %ymm8, %ymm0
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm5
vmovupd %ymm5, 3008(%rsp)
vpaddq  2112(%rsp), %ymm8, %ymm0
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm5
vmovupd %ymm5, 3072(%rsp)
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rsi,%ymm8,8), %ymm0
vpcmpeqd    %ymm5, %ymm5, %ymm5
vgatherqpd  %ymm5, (%rdx,%ymm8,8), %ymm1
vmovupd 768(%rsp), %ymm5
vfmadd231pd %ymm0, %ymm1, %ymm5
vmovupd %ymm5, 768(%rsp)
vmovupd 32(%rsp), %ymm5
vfmadd231pd %ymm0, %ymm3, %ymm5
vmovupd %ymm5, 32(%rsp)
vmovupd 1024(%rsp), %ymm5
vfmadd231pd %ymm0, %ymm2, %ymm5
vmovupd %ymm5, 1024(%rsp)
vmovupd 1280(%rsp), %ymm5
vfmadd231pd %ymm0, %ymm4, %ymm5
vmovupd %ymm5, 1280(%rsp)
vmovupd 1344(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm6, %ymm0
vmovupd %ymm0, 1344(%rsp)
vmovupd 480(%rsp), %ymm0
vfmadd231pd %ymm3, %ymm6, %ymm0
vmovupd %ymm0, 480(%rsp)
vmovupd 1600(%rsp), %ymm0
vfmadd231pd %ymm2, %ymm6, %ymm0
vmovupd %ymm0, 1600(%rsp)
vmovupd 1856(%rsp), %ymm0
vfmadd231pd %ymm4, %ymm6, %ymm0
vmovupd %ymm0, 1856(%rsp)
vpaddq  2080(%rsp), %ymm8, %ymm0
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm2
vpaddq  2048(%rsp), %ymm8, %ymm0
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd 800(%rsp), %ymm0
vmovupd 3552(%rsp), %ymm1
vmovupd 3040(%rsp), %ymm3
vfmadd231pd %ymm3, %ymm1, %ymm0
vmovupd %ymm0, 800(%rsp)
vmovupd 64(%rsp), %ymm0
vmovupd 3360(%rsp), %ymm5
vfmadd231pd %ymm3, %ymm5, %ymm0
vmovupd %ymm0, 64(%rsp)
vmovupd 1056(%rsp), %ymm0
vfmadd231pd %ymm3, %ymm12, %ymm0
vmovupd %ymm0, 1056(%rsp)
vmovupd 288(%rsp), %ymm0
vmovupd 2976(%rsp), %ymm6
vfmadd231pd %ymm3, %ymm6, %ymm0
vmovupd %ymm0, 288(%rsp)
vmovupd 1376(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm7, %ymm0
vmovupd %ymm0, 1376(%rsp)
vmovupd 512(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm7, %ymm0
vmovupd %ymm0, 512(%rsp)
vmovupd 1632(%rsp), %ymm0
vfmadd231pd %ymm12, %ymm7, %ymm0
vmovupd %ymm0, 1632(%rsp)
vmovupd 1888(%rsp), %ymm0
vfmadd231pd %ymm6, %ymm7, %ymm0
vmovupd %ymm0, 1888(%rsp)
vmovupd 832(%rsp), %ymm0
vmovupd 3616(%rsp), %ymm1
vmovupd 3104(%rsp), %ymm6
vfmadd231pd %ymm6, %ymm1, %ymm0
vmovupd %ymm0, 832(%rsp)
vmovupd 96(%rsp), %ymm0
vmovupd 3392(%rsp), %ymm3
vfmadd231pd %ymm6, %ymm3, %ymm0
vmovupd %ymm0, 96(%rsp)
vmovupd 1088(%rsp), %ymm0
vmovupd 3168(%rsp), %ymm5
vfmadd231pd %ymm6, %ymm5, %ymm0
vmovupd %ymm0, 1088(%rsp)
vmovupd 320(%rsp), %ymm0
vmovupd 3008(%rsp), %ymm7
vfmadd231pd %ymm6, %ymm7, %ymm0
vmovupd %ymm0, 320(%rsp)
vmovupd 1408(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm13, %ymm0
vmovupd %ymm0, 1408(%rsp)
vmovupd 544(%rsp), %ymm0
vfmadd231pd %ymm3, %ymm13, %ymm0
vmovupd %ymm0, 544(%rsp)
vmovupd 1664(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm13, %ymm0
vmovupd %ymm0, 1664(%rsp)
vmovupd 1920(%rsp), %ymm0
vfmadd231pd %ymm7, %ymm13, %ymm0
vmovupd %ymm0, 1920(%rsp)
vpaddq  2016(%rsp), %ymm8, %ymm0
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm3
vmovupd 864(%rsp), %ymm0
vmovupd 3648(%rsp), %ymm1
vmovupd 3136(%rsp), %ymm6
vfmadd231pd %ymm6, %ymm1, %ymm0
vmovupd %ymm0, 864(%rsp)
vmovupd 128(%rsp), %ymm0
vmovupd 3424(%rsp), %ymm5
vfmadd231pd %ymm6, %ymm5, %ymm0
vmovupd %ymm0, 128(%rsp)
vmovupd 1120(%rsp), %ymm0
vmovupd 3200(%rsp), %ymm7
vfmadd231pd %ymm6, %ymm7, %ymm0
vmovupd %ymm0, 1120(%rsp)
vmovupd 352(%rsp), %ymm0
vmovupd 3072(%rsp), %ymm12
vfmadd231pd %ymm6, %ymm12, %ymm0
vmovupd %ymm0, 352(%rsp)
vmovupd 1440(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm15, %ymm0
vmovupd %ymm0, 1440(%rsp)
vmovupd 576(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm15, %ymm0
vmovupd %ymm0, 576(%rsp)
vmovupd 1696(%rsp), %ymm0
vfmadd231pd %ymm7, %ymm15, %ymm0
vmovupd %ymm0, 1696(%rsp)
vmovupd 736(%rsp), %ymm0
vfmadd231pd %ymm12, %ymm15, %ymm0
vmovupd %ymm0, 736(%rsp)
vmovupd 896(%rsp), %ymm0
vmovupd 3808(%rsp), %ymm1
vmovupd 3680(%rsp), %ymm5
vfmadd231pd %ymm1, %ymm5, %ymm0
vmovupd %ymm0, 896(%rsp)
vmovupd 160(%rsp), %ymm0
vmovupd 3456(%rsp), %ymm6
vfmadd231pd %ymm1, %ymm6, %ymm0
vmovupd %ymm0, 160(%rsp)
vmovupd 1152(%rsp), %ymm0
vmovupd 3232(%rsp), %ymm7
vfmadd231pd %ymm1, %ymm7, %ymm0
vmovupd %ymm0, 1152(%rsp)
vmovupd 384(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm2, %ymm0
vmovupd %ymm0, 384(%rsp)
vmovupd 1472(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm9, %ymm0
vmovupd %ymm0, 1472(%rsp)
vmovupd 608(%rsp), %ymm0
vfmadd231pd %ymm6, %ymm9, %ymm0
vmovupd %ymm0, 608(%rsp)
vmovupd 1728(%rsp), %ymm0
vfmadd231pd %ymm7, %ymm9, %ymm0
vmovupd %ymm0, 1728(%rsp)
vmovupd -128(%rsp), %ymm0
vfmadd231pd %ymm2, %ymm9, %ymm0
vmovupd %ymm0, -128(%rsp)
vmovupd 928(%rsp), %ymm0
vmovupd 3840(%rsp), %ymm1
vmovupd 3712(%rsp), %ymm2
vfmadd231pd %ymm1, %ymm2, %ymm0
vmovupd %ymm0, 928(%rsp)
vmovupd 192(%rsp), %ymm0
vmovupd 3488(%rsp), %ymm5
vfmadd231pd %ymm1, %ymm5, %ymm0
vmovupd %ymm0, 192(%rsp)
vmovupd 1184(%rsp), %ymm0
vmovupd 3264(%rsp), %ymm6
vfmadd231pd %ymm1, %ymm6, %ymm0
vmovupd %ymm0, 1184(%rsp)
vmovupd 416(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm4, %ymm0
vmovupd %ymm0, 416(%rsp)
vmovupd 1504(%rsp), %ymm0
vfmadd231pd %ymm2, %ymm14, %ymm0
vmovupd %ymm0, 1504(%rsp)
vmovupd 640(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm14, %ymm0
vmovupd %ymm0, 640(%rsp)
vmovupd 1760(%rsp), %ymm0
vfmadd231pd %ymm6, %ymm14, %ymm0
vmovupd %ymm0, 1760(%rsp)
vmovupd -96(%rsp), %ymm0
vfmadd231pd %ymm4, %ymm14, %ymm0
vmovupd %ymm0, -96(%rsp)
vpaddq  1984(%rsp), %ymm8, %ymm0
vpcmpeqd    %ymm1, %ymm1, %ymm1
vgatherqpd  %ymm1, (%rdx,%ymm0,8), %ymm2
vmovupd 960(%rsp), %ymm0
vmovupd 3872(%rsp), %ymm1
vmovupd 3744(%rsp), %ymm4
vfmadd231pd %ymm1, %ymm4, %ymm0
vmovupd %ymm0, 960(%rsp)
vmovupd 224(%rsp), %ymm0
vmovupd 3520(%rsp), %ymm5
vfmadd231pd %ymm1, %ymm5, %ymm0
vmovupd %ymm0, 224(%rsp)
vmovupd 1216(%rsp), %ymm0
vmovupd 3296(%rsp), %ymm6
vfmadd231pd %ymm1, %ymm6, %ymm0
vmovupd %ymm0, 1216(%rsp)
vmovupd 448(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm3, %ymm0
vmovupd %ymm0, 448(%rsp)
vmovupd 1536(%rsp), %ymm0
vfmadd231pd %ymm4, %ymm11, %ymm0
vmovupd %ymm0, 1536(%rsp)
vmovupd 672(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm11, %ymm0
vmovupd %ymm0, 672(%rsp)
vmovupd 1792(%rsp), %ymm0
vfmadd231pd %ymm6, %ymm11, %ymm0
vmovupd %ymm0, 1792(%rsp)
vmovupd -64(%rsp), %ymm0
vfmadd231pd %ymm3, %ymm11, %ymm0
vmovupd %ymm0, -64(%rsp)
vmovupd 992(%rsp), %ymm0
vmovupd 3904(%rsp), %ymm1
vmovupd 3776(%rsp), %ymm3
vfmadd231pd %ymm1, %ymm3, %ymm0
vmovupd %ymm0, 992(%rsp)
vmovupd 256(%rsp), %ymm0
vmovupd 3584(%rsp), %ymm4
vfmadd231pd %ymm1, %ymm4, %ymm0
vmovupd %ymm0, 256(%rsp)
vmovupd 1248(%rsp), %ymm0
vmovupd 3328(%rsp), %ymm5
vfmadd231pd %ymm1, %ymm5, %ymm0
vmovupd %ymm0, 1248(%rsp)
vmovupd 1312(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm2, %ymm0
vmovupd %ymm0, 1312(%rsp)
vmovupd 1568(%rsp), %ymm0
vfmadd231pd %ymm3, %ymm10, %ymm0
vmovupd %ymm0, 1568(%rsp)
vmovupd 704(%rsp), %ymm0
vfmadd231pd %ymm4, %ymm10, %ymm0
vmovupd %ymm0, 704(%rsp)
vmovupd 1824(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm10, %ymm0
vmovupd %ymm0, 1824(%rsp)
vmovupd -32(%rsp), %ymm0
vfmadd231pd %ymm2, %ymm10, %ymm0
vmovupd %ymm0, -32(%rsp)
vpaddq  1952(%rsp), %ymm8, %ymm8
``````

The loop is huge and clearly not vectorized properly. There are a lot of seemingly useless instructions (many of them just move data to and from the stack), and loads from memory appear not to be contiguous (see the `vgatherqpd` instructions). Numba does not generate good code since the underlying JIT (llvmlite) fails to vectorize the code efficiently. In fact, I found that similar C++ code is badly vectorized by Clang 13.0 on a simplified example (GCC and ICC also fail on more complex code), while a hand-written SIMD implementation works much better. It looks like a bug in the optimizer, or at least a missed optimization. This is why the Numba code is much slower than the optimal code. That being said, this implementation makes quite efficient use of the cache and is properly multithreaded.

I also found that the BLAS code is faster on Linux than on Windows on my machine (with the default packages from PIP and the same Numpy version, 1.20.3). Thus, the gap between `method2` and `method3` is smaller on Linux, but the latter is still significantly faster.