[SOLVED] How to efficiently scan 2 bit masks alternating each iteration

Issue

Given are two bitmasks that should be accessed alternately (0, 1, 0, 1, …). I am looking for a runtime-efficient solution, but have found no better way than the following example.

uint32_t mask[2] { ... };
uint8_t mask_index = 0;
uint32_t f = _tzcnt_u32(mask[mask_index]);
while (f < 32) {
    // element adding to result vector removed, since not relevant for question itself
    mask[0] >>= f + 1;
    mask[1] >>= f + 1;
    mask_index ^= 1;
    f = _tzcnt_u32(mask[mask_index]);
}
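
For reference, here is a minimal, self-contained sketch of the same loop that also records the absolute bit positions it visits. The names `tz32` and `scan_alternating` are hypothetical, and `tz32` is only a portable stand-in for `_tzcnt_u32` so that the snippet compiles without BMI support. One deliberate difference from the loop as posted: `f + 1` reaches 32 when `f == 31`, and shifting a `uint32_t` by 32 is undefined behaviour in C++, so the sketch shifts in two steps.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Portable stand-in for _tzcnt_u32 (assumption: same contract, returns 32 for x == 0).
inline uint32_t tz32(uint32_t x) {
    uint32_t n = 0;
    while (n < 32 && ((x >> n) & 1u) == 0) ++n;
    return n;
}

// Hypothetical helper: absolute bit positions found when alternating between
// m0 and m1, starting with the mask selected by start_index.
std::vector<uint32_t> scan_alternating(uint32_t m0, uint32_t m1, uint8_t start_index = 0) {
    assert(start_index < 2);
    uint32_t mask[2] = {m0, m1};
    uint8_t mask_index = start_index;
    uint32_t base = 0;  // number of bits already consumed from both masks
    std::vector<uint32_t> out;

    uint32_t f = tz32(mask[mask_index]);
    while (f < 32) {
        out.push_back(base + f);
        base += f + 1;
        // Two-step shift: the total is f + 1, but no single count reaches 32.
        mask[0] = (mask[0] >> f) >> 1;
        mask[1] = (mask[1] >> f) >> 1;
        mask_index ^= 1;
        f = tz32(mask[mask_index]);
    }
    return out;
}
```

With `m0 = 0x13` (bits 0, 1, 4) and `m1 = 0x28` (bits 3, 5), starting on `m0`, the scan takes bit 0 from `m0`, then the next set bit of `m1` above it (3), then 4 from `m0`, then 5 from `m1`.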

The ASM output for the loop as posted (MSVC, x64) looks quite bloated:

inc         r9  
add         r9,rcx  
mov         eax,esi  
mov         qword ptr [rdi+rax*8],r9  
inc         esi  
lea         rax,[rcx+1]  
shrx        r11d,r11d,eax  
mov         dword ptr [rbp],r11d  
shrx        r8d,r8d,eax  
mov         dword ptr [rbp+4],r8d  
xor         r10b,1  
movsx       rax,r10b  
tzcnt       ecx,dword ptr [rbp+rax*4]  
mov         ecx,ecx  
cmp         rcx,20h  
jb          main+240h (07FF632862FD0h)  
cmp         r9,20h  
jb          main+230h (07FF632862FC0h) 

Does anyone have any advice?

(This is a follow-up to Solve loop data dependency with SIMD – finding transitions between -1 and +1 in an int8_t array of sgn values using SIMD to create the bitmasks)

Update

I wonder if a potential solution could make use of SIMD by loading chunks of both bit streams into a register (AVX2 in my case) like this:

|m0[0]|m1[0]|m0[1]|m1[1]|m0[2]|m1[2]|m0[n+1]|m1[n+1]|

or

one register per stream, each holding chunks of that stream:

|m0[0]|m0[1]|m0[2]|m0[n+1]|

|m1[0]|m1[1]|m1[2]|m1[n+1]|

or split the stream into chunks of equal size and process as many lanes at once as fit into the register. Let’s assume we have 256*10 elements, which might end up in 10 iterations like this:
|m0[0]|m0[256]|m0[512]|…|
|m1[0]|m1[256]|m1[512]|…|
and handle the joins between chunks separately.

I am not sure whether this would allow more iterations per cycle, reduce the need for horizontal bit scans and shift/clear ops, and avoid branches.

Solution

This loop is quite hard to optimize. The main issue is that each iteration depends on the previous one, and even the instructions within an iteration depend on each other. This creates a long, nearly sequential chain of instructions, which the processor cannot execute efficiently. In addition, some instructions in this chain have a fairly high latency: tzcnt has a 3-cycle latency on Intel processors, and each L1 load/store adds a latency of at least 3 cycles.

One solution is to work directly with registers instead of an array with indirect accesses, so as to shorten the chain and in particular remove the instructions with the highest latency. This can be done by unrolling the loop twice and splitting the problem into two separate cases:

uint32_t m0 = mask[0];
uint32_t m1 = mask[1];
uint8_t mask_index = 0;

if(mask_index == 0) {
    uint32_t f = _tzcnt_u32(m0);

    while (f < 32) {
        m1 >>= f + 1;
        m0 >>= f + 1;
        f = _tzcnt_u32(m1);

        if(f >= 32)
            break;

        m0 >>= f + 1;
        m1 >>= f + 1;
        f = _tzcnt_u32(m0);
    }
}
else {
    // Mirror image of the branch above: start with m1, then alternate.
    uint32_t f = _tzcnt_u32(m1);

    while (f < 32) {
        m0 >>= f + 1;
        m1 >>= f + 1;
        f = _tzcnt_u32(m0);

        if(f >= 32)
            break;

        m0 >>= f + 1;
        m1 >>= f + 1;
        f = _tzcnt_u32(m1);
    }
}

// If mask is needed, m0 and m1 need to be stored back in mask.

This should be a bit faster, mainly because the critical path is shorter, but also because the two shifts can be executed in parallel. Here is the resulting assembly code:

$loop:
        inc     ecx
        shr     edx, cl
        shr     eax, cl
        tzcnt   ecx, edx

        cmp     ecx, 32
        jae     SHORT $end_loop

        inc     ecx
        shr     eax, cl
        shr     edx, cl
        tzcnt   ecx, eax

        cmp     ecx, 32
        jb      SHORT $loop

Note that modern x86 processors can fuse the instruction pairs cmp+jae and cmp+jb, and the branch predictor can assume that the loop continues, so only the last conditional jump is mispredicted. On Intel processors, the critical path consists of a 1-cycle inc, a 1-cycle shr and a 3-cycle tzcnt, resulting in 5 cycles per round (1 round = 1 iteration of the initial loop). On AMD Zen-like processors, it is 1+1+2=4 cycles, which is very good. Optimizing this further appears to be very challenging.
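
The register-based loop above can also be written as one function that is called with swapped arguments for the second case. This is only a sketch under stated assumptions: `scan_unrolled` and `count_trailing_zeros` are hypothetical names, `count_trailing_zeros` is a portable stand-in for `_tzcnt_u32`, the positions are collected in a vector for illustration, and the shift is split in two so that the count never reaches 32 (which would be undefined behaviour for a `uint32_t`).

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Portable stand-in for _tzcnt_u32 (assumption: returns 32 for x == 0).
inline uint32_t count_trailing_zeros(uint32_t x) {
    uint32_t n = 0;
    while (n < 32 && ((x >> n) & 1u) == 0) ++n;
    return n;
}

// Hypothetical helper: scans `first`, then `second`, then `first` again, ...
// Both masks live in local variables, so no indexed memory access is needed.
std::vector<uint32_t> scan_unrolled(uint32_t first, uint32_t second) {
    std::vector<uint32_t> out;
    uint32_t base = 0;  // number of bits already consumed from both masks
    uint32_t f = count_trailing_zeros(first);

    while (f < 32) {
        assert(base + f < 32);  // positions stay inside the 32-bit word
        out.push_back(base + f);
        base += f + 1;
        first = (first >> f) >> 1;
        second = (second >> f) >> 1;
        f = count_trailing_zeros(second);

        if (f >= 32)
            break;

        out.push_back(base + f);
        base += f + 1;
        first = (first >> f) >> 1;
        second = (second >> f) >> 1;
        f = count_trailing_zeros(first);
    }
    return out;
}
```

A caller would pick the argument order once, outside the loop, for example `mask_index == 0 ? scan_unrolled(m0, m1) : scan_unrolled(m1, m0)`. Whether the compiler then produces the same tight loop as the assembly shown above has not been checked here.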

One possible optimization is to use a lookup table to process the lower bits of m0 and m1 in bigger steps. However, a lookup table fetch has a 3-cycle latency, may cause expensive cache misses in practice, takes more memory and makes the code significantly more complex, since the number of trailing 0 bits can be quite large (e.g. 28 bits). Thus, I am not sure this is a good idea, although it is certainly worth trying.
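
To make the lookup-table idea concrete, here is one possible reading of it, not a measured or recommended implementation. The 4-bit step width, the 512-entry table layout and the names `NibbleStep`, `build_table` and `scan_with_table` are all assumptions made for illustration. Each table entry answers, for the low nibbles of both masks and the mask to start with, which of the 4 positions are reported and which mask comes first in the next nibble.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

// One entry per (start index, low nibble of m1, low nibble of m0).
struct NibbleStep {
    uint8_t hits;        // bit i set: position i of this nibble is reported
    uint8_t next_index;  // mask to scan first in the following nibble
};

// Fills the table by running the plain alternating scan on 4 bits only.
std::array<NibbleStep, 512> build_table() {
    std::array<NibbleStep, 512> table{};
    for (uint32_t key = 0; key < 512; ++key) {
        const uint32_t nib[2] = {key & 0xFu, (key >> 4) & 0xFu};
        uint32_t idx = key >> 8;
        uint8_t hits = 0;
        for (uint32_t pos = 0; pos < 4; ++pos) {
            if ((nib[idx] >> pos) & 1u) {
                hits |= static_cast<uint8_t>(1u << pos);
                idx ^= 1u;  // the next hit must come from the other mask
            }
        }
        table[key] = {hits, static_cast<uint8_t>(idx)};
    }
    return table;
}

// Hypothetical helper: same result as the alternating scan, but with a fixed
// trip count of 8 table lookups per pair of 32-bit masks.
std::vector<uint32_t> scan_with_table(uint32_t m0, uint32_t m1, uint32_t start_index = 0) {
    assert(start_index < 2);
    static const std::array<NibbleStep, 512> table = build_table();
    std::vector<uint32_t> out;
    uint32_t idx = start_index;
    for (uint32_t base = 0; base < 32; base += 4) {
        const uint32_t key = ((m0 >> base) & 0xFu)
                           | (((m1 >> base) & 0xFu) << 4)
                           | (idx << 8);
        const NibbleStep step = table[key];
        for (uint32_t pos = 0; pos < 4; ++pos)
            if ((step.hits >> pos) & 1u)
                out.push_back(base + pos);
        idx = step.next_index;
    }
    return out;
}
```

This trades the data-dependent loop for a fixed number of dependent table loads, so the concerns raised above (load latency, cache misses, extra memory) apply directly, and it would need benchmarking against the unrolled version before drawing any conclusion.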

Answered By – Jérôme Richard

Answer Checked By – David Marino (BugsFixing Volunteer)
