
How can I create an `__m128i` with the n most significant bits set (across the entire vector)? I need it to mask the portions of a buffer that are relevant to a computation. If possible, the solution should have no branches, but that seems hard to achieve.

How can I do this?

user46317

3 Answers


I'm adding this as a second answer and leaving the first answer for historical interest. It looks like you can do something more efficient with `_mm_slli_epi64`:

#include <emmintrin.h>
#include <stdio.h>
#include <stdlib.h>  // for atoi

__m128i bit_mask(int n)
{
    __m128i v0 = _mm_set_epi64x(-1, -(n > 64)); // AND mask: clears the low qword when n <= 64
    __m128i v1 = _mm_set_epi64x(-(n > 64), 0);  // OR mask: fills the high qword when n > 64
    __m128i v2 = _mm_slli_epi64(_mm_set1_epi64x(-1), (128 - n) & 63); // all-ones shifted left in both qwords
    v2 = _mm_and_si128(v2, v0);
    v2 = _mm_or_si128(v2, v1);
    return v2;
}

int main(int argc, char *argv[])
{
    int n = 36;

    if (argc > 1) n = atoi(argv[1]);

    printf("bit_mask(%3d) = %02vx\n", n, bit_mask(n));

    return 0;
}

Test:

$ gcc -Wall -msse2 sse_bit_mask.c
$ for n in 1 2 3 63 64 65 127 128 ; do ./a.out $n ; done
bit_mask(  1) = 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 80
bit_mask(  2) = 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 c0
bit_mask(  3) = 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 e0
bit_mask( 63) = 00 00 00 00 00 00 00 00 fe ff ff ff ff ff ff ff
bit_mask( 64) = 00 00 00 00 00 00 00 00 ff ff ff ff ff ff ff ff
bit_mask( 65) = 00 00 00 00 00 00 00 80 ff ff ff ff ff ff ff ff
bit_mask(127) = fe ff ff ff ff ff ff ff ff ff ff ff ff ff ff ff
bit_mask(128) = ff ff ff ff ff ff ff ff ff ff ff ff ff ff ff ff
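
Note: the `%vx` printf conversion used above is a non-standard extension (it is available in Apple's builds of gcc and clang, as discussed in the comments below). On other compilers, a small helper that stores the vector and prints its bytes gives equivalent output; a minimal sketch (the helper name `print_m128i` is made up here):

#include <emmintrin.h>
#include <stdio.h>

// Portable substitute for the "%vx" conversion: prints the 16 bytes of an
// __m128i from least to most significant, matching the output above.
static void print_m128i(const char *label, int n, __m128i v)
{
    unsigned char bytes[16];
    _mm_storeu_si128((__m128i *)bytes, v);
    printf("%s(%3d) =", label, n);
    for (int i = 0; i < 16; i++)
        printf(" %02x", bytes[i]);
    printf("\n");
}

// Usage: print_m128i("bit_mask", n, bit_mask(n));
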
Paul R
  • This doesn't all look right - the results for `65` and `127` are not correct; am I missing something? – FrankH. Apr 02 '14 at 10:14
  • Sorry about that - I realise now that I needed two masks, not one. It seems to be working correctly now. – Paul R Apr 02 '14 at 10:31
  • @PaulR, what is the `%vx` option in printf? This prints SSE registers? Which compiler? http://coliru.stacked-crooked.com/a/b8fd938a10e98b17 – Z boson Apr 02 '14 at 11:02
  • It's supported as an extension in some compilers, e.g. Apple's builds of gcc and clang on OS X (which is what I mainly use for development work). – Paul R Apr 02 '14 at 11:51
  • @PaulR so I guess this won't work on Linux... I came up with a solution using a broadcast to set all bits to one, combined with a bitwise shift. – user46317 Apr 02 '14 at 14:44
  • I don't know whether Apple's additions to gcc, clang et al have found their way into the Linux distributions yet, but it might be worth a try. Please post your solution - it's good to have as many different ideas and implementations as possible. – Paul R Apr 02 '14 at 15:10
  • @user46317: Paul's `__m128i bit_mask(int n)` function works on all x86 compilers. The only thing in this answer that doesn't work on GNU/Linux is the `%vx` conversion for printf. – Peter Cordes Apr 25 '22 at 01:55

You can use one of the methods from this question to generate a mask with the most significant n bytes set to all ones. You would then just need to fix up any remaining bits when n is not a multiple of 8.

I suggest trying something like this:

- init vector A = all (8-bit) elements set to the residual mask of n % 8 bits
- init vector B = mask of n / 8 bytes using one of the above-mentioned methods
- init vector C = mask of (n + 7) / 8 bytes using one of the above-mentioned methods
- result = (A | B) & C

So for example if n = 36:

A = f0 f0 f0 f0 f0 f0 f0 f0 f0 f0 f0 f0 f0 f0 f0 f0
B = ff ff ff ff 00 00 00 00 00 00 00 00 00 00 00 00
C = ff ff ff ff ff 00 00 00 00 00 00 00 00 00 00 00
==> ff ff ff ff f0 00 00 00 00 00 00 00 00 00 00 00

This would be branchless, as required, but it's probably on the order of ~10 instructions. There may be a more efficient method but I would need to give this some more thought.
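
For reference, here is a minimal sketch of this recipe (not the exact code from the linked question; it assumes the byte masks B and C are built by loading a 16-byte window from a constant 32-byte table, one of the techniques covered there, with byte 15 taken as the most significant byte):

#include <emmintrin.h>
#include <stdint.h>

// 16 zero bytes followed by 16 0xFF bytes: loading a 16-byte window starting at
// offset k yields a vector whose k most significant bytes are 0xFF.
static const uint8_t mask_table[32] = {
    0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x00,
    0xFF,0xFF,0xFF,0xFF, 0xFF,0xFF,0xFF,0xFF, 0xFF,0xFF,0xFF,0xFF, 0xFF,0xFF,0xFF,0xFF
};

static __m128i byte_mask(int k)                 // mask of the k most significant bytes
{
    return _mm_loadu_si128((const __m128i *)(mask_table + k));
}

__m128i bit_mask_abc(int n)                     // top n bits set, 1 <= n <= 128
{
    __m128i A = _mm_set1_epi8((char)(0xFF00u >> (n % 8)));  // residual n%8 bits in every byte
    __m128i B = byte_mask(n / 8);                            // n/8 whole bytes
    __m128i C = byte_mask((n + 7) / 8);                      // (n+7)/8 whole bytes
    return _mm_and_si128(_mm_or_si128(A, B), C);             // (A | B) & C
}

For n = 36 this reproduces the example above: the top 36 bits set.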

Paul R

The next two solutions are an alternative to Paul R's answer. These solutions are of interest when the masks are needed inside a performance-critical loop.


SSE2

__m128i bit_mask_v2(unsigned int n){                      /* Create an __m128i vector with the n most significant bits set to 1  */
    __m128i ones_hi   = _mm_set_epi64x(-1,0);             /* Binary vector of bits 1...1 and 0...0                               */
    __m128i ones_lo   = _mm_set_epi64x(0,-1);             /* Binary vector of bits 0...0 and 1...1                               */
    __m128i cnst64    = _mm_set1_epi64x(64);
    __m128i cnst128   = _mm_set1_epi64x(128);

    __m128i shift     = _mm_cvtsi32_si128(n);             /* Move n to SSE register                                              */
    __m128i shift_hi  = _mm_subs_epu16(cnst64,shift);     /* Subtract with saturation                                            */
    __m128i shift_lo  = _mm_subs_epu16(cnst128,shift);   
    __m128i hi        = _mm_sll_epi64(ones_hi,shift_hi);  /* Shift hi left by max(64-n,0); for n>64 the count is 0 and hi stays all-ones    */
    __m128i lo        = _mm_sll_epi64(ones_lo,shift_lo);  /* Shift lo left by max(128-n,0); counts >= 64 zero the qword, so lo is 0 for n<64 */
               return   _mm_or_si128(lo,hi);              /* Merge hi and lo                                                     */
}


SSSE3

The SSSE3 case is more interesting. The pshufb instruction is used as a small lookup table. It took me some time to figure out the right combination of the (saturated) arithmetic and the constants.

__m128i bit_mask_SSSE3(unsigned int n){                   /* Create an __m128i vector with the n most significant bits set to 1   */
    __m128i sat_const = _mm_set_epi8(247,239,231,223,   215,207,199,191,   183,175,167,159,   151,143,135,127);  /* Constant used in combination with saturating addition */
    __m128i sub_const = _mm_set1_epi8(248);
    __m128i pshub_lut = _mm_set_epi8(0,0,0,0,   0,0,0,0,   
                          0b11111111, 0b11111110, 0b11111100, 0b11111000,
                          0b11110000, 0b11100000, 0b11000000, 0b10000000);

    __m128i shift_bc  = _mm_set1_epi8(n);                         /* Broadcast n to the 16 8-bit elements.                                */
    __m128i shft_byte = _mm_adds_epu8(shift_bc,sat_const);        /* The constants sat_const and sub_const are selected such that         */
    __m128i shuf_indx = _mm_sub_epi8(shft_byte,sub_const);        /* _mm_shuffle_epi8 can be used as a tiny lookup table                  */
                return  _mm_shuffle_epi8(pshub_lut,shuf_indx);    /* which finds the right bit pattern at the right position.             */
}
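
For what it's worth, my reading of how the constants work: for byte i (byte 0 the least significant, byte 15 the most significant), the output byte should hold the top n + 8*i - 120 bits. The saturating add of 127 + 8*i followed by the wrapping subtraction of 248 turns that count into a pshufb index: counts of 8 or more saturate to index 7 (a full 0xFF byte), and counts of 0 or less wrap to an index with bit 7 set, which pshufb turns into 0. A scalar model of one output byte, for illustration only (mask_byte_model is a made-up name, not part of the answer's code):

#include <stdint.h>

// Scalar model of one byte of bit_mask_SSSE3: i = 0 (LSB) .. 15 (MSB).
static uint8_t mask_byte_model(unsigned int n, int i)
{
    static const uint8_t lut[16] = { 0x80,0xC0,0xE0,0xF0, 0xF8,0xFC,0xFE,0xFF,
                                     0,0,0,0, 0,0,0,0 };        // same table as pshub_lut
    unsigned int a = (uint8_t)n + 127u + 8u * (unsigned int)i;  // _mm_adds_epu8 ...
    if (a > 255u) a = 255u;                                     // ... with unsigned saturation
    uint8_t idx = (uint8_t)(a - 248u);                          // _mm_sub_epi8 (wraps mod 256)
    return (idx & 0x80) ? 0 : lut[idx & 15];                    // pshufb: index with bit 7 set -> 0
}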


Functionality

For 1<=n<=128, which is the range specified by the OP, the functions bit_mask_Paul_R(n) (Paul R's answer), bit_mask_v2(n) and bit_mask_SSSE3(n) produce the same results:

bit_mask_Paul_R(  0) = FFFFFFFFFFFFFFFF 0000000000000000
bit_mask_Paul_R(  1) = 8000000000000000 0000000000000000
bit_mask_Paul_R(  2) = C000000000000000 0000000000000000
bit_mask_Paul_R(  3) = E000000000000000 0000000000000000
.....
bit_mask_Paul_R(126) = FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFFC
bit_mask_Paul_R(127) = FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFFE
bit_mask_Paul_R(128) = FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFFF


bit_mask_v2(  0) = 0000000000000000 0000000000000000
bit_mask_v2(  1) = 8000000000000000 0000000000000000
bit_mask_v2(  2) = C000000000000000 0000000000000000
bit_mask_v2(  3) = E000000000000000 0000000000000000
.....
bit_mask_v2(126) = FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFFC
bit_mask_v2(127) = FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFFE
bit_mask_v2(128) = FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFFF


bit_mask_SSSE3(  0) = 0000000000000000 0000000000000000
bit_mask_SSSE3(  1) = 8000000000000000 0000000000000000
bit_mask_SSSE3(  2) = C000000000000000 0000000000000000
bit_mask_SSSE3(  3) = E000000000000000 0000000000000000
.....
bit_mask_SSSE3(126) = FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFFC
bit_mask_SSSE3(127) = FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFFE
bit_mask_SSSE3(128) = FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFFF

For n=0 the most reasonable result is the zero vector, which is produced by bit_mask_v2(n) and bit_mask_SSSE3(n).
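
The listings above can be reproduced with a small harness along these lines (a sketch, not the exact test program used here; it assumes bit_mask_v2 from above is in scope):

#include <stdio.h>
#include <stdint.h>
#include <emmintrin.h>

int main(void)
{
    uint64_t x[2];
    for (unsigned int n = 0; n <= 128; n++) {
        _mm_storeu_si128((__m128i *)x, bit_mask_v2(n));    // or bit_mask_SSSE3(n)
        printf("bit_mask_v2(%3u) = %016llX %016llX\n",
               n, (unsigned long long)x[1], (unsigned long long)x[0]);
    }
    return 0;
}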


Performance

To get a rough impression of the performance of the different functions, the following piece of code is used:

uint64_t x[2];
__m128i sum = _mm_setzero_si128();
for (unsigned int i = 0; i < 1000000000; i = i + 1){
    sum=_mm_add_epi64(sum,bit_mask_Paul_R(i));   // or use next line instead 
//    sum=_mm_add_epi64(sum,bit_mask_v2(i));
//    sum=_mm_add_epi64(sum,bit_mask_SSSE3(i));
}
_mm_storeu_si128((__m128i*)x,sum);
printf("sum = %016lX %016lX\n", x[1],x[0]);

The performance of the code depends slightly on the type of instruction encoding. The GCC options opts1 = -O3 -m64 -Wall -march=nehalem lead to non-VEX-encoded SSE instructions, while opts2 = -O3 -m64 -Wall -march=sandybridge compiles to VEX-encoded AVX-128 instructions.

The results with gcc 5.4 are:

Cycles per iteration on Intel Skylake, estimated with: perf stat -d ./a.out
                     opts1       opts2
bit_mask_Paul_R       6.0         7.0
bit_mask_v2           3.8         3.3
bit_mask_SSSE3        3.0         3.0

In practice the performance will depend on the CPU type and the surrounding code. The performance of bit_mask_SSSE3 is limited by port 5 pressure: three instructions per iteration (one movd and two pshufbs) are handled by port 5.

With AVX2, more efficient code is possible; see here.

wim