Фильтрация вектора с помощью SIMD-инструкций (AVX-2 и AVX-512)

👋 Привет, dev.to!

Проблема

По причинам, которые выходят за рамки этой статьи, мне недавно потребовалось измерить, насколько быстро одно ядро CPU может перебрать несортированный список беззнаковых 32-битных целых чисел (u32) и вывести отсортированный список индексов, в которых значение находится в заданном интервале.

Другими словами, допустим, мы получаем на вход годы

[1992, 2018, 1934, 2002, 2022, 1998, 1972, 1996].

Учитывая интервал [1982...2000], мы определим, что значения в этом интервале следующие

[1992, _, _, _, _, 1998, _, 1996]

Поэтому нашим результатом должны быть их соответствующие индексы

[0, 5, 7].

На идиоматическом языке rust такая функция могла бы выглядеть следующим образом

use std::ops::RangeInclusive;

pub fn filter_vec_scalar(
    input: &[u32],
    range: RangeInclusive<u32>,
    output: &mut Vec<u32>)
{
    output.clear();
    output.reserve(input.len());
    output.extend(
        input
            .iter()
            .enumerate()
            .filter(|&(_, el)| range.contains(el))
            .map(|(id, _)| id as u32),
    );
}
Войти в полноэкранный режим Выйти из полноэкранного режима

Этот код кажется довольно тривиальным. Конечно, rustc сможет
векторизовать его самостоятельно, и это будет очень короткий пост в блоге, верно? На самом деле нет.

Сгенерированный ассемблерный код — не самое интересное чтение (https://godbolt.org/z/x3sK4fvhP), поэтому я не буду показывать его здесь… но поверьте мне. Это просто скучный скалярный код.

На экземпляре EC2 c5n.large (Intel Xeon Platinum 3 ГГц (Skylake 8124)), с сохранением 50% элементов, этот код имеет пропускную способность 170 миллионов u32s/сек.

К концу этой статьи наш код будет в 20 раз быстрее.

В этой статье будет показано, как можно использовать SIMD-инструкции для оптимизации этого кода
и подчеркнет различия между AVX2 и AVX512.

Без ветвления

Вам не терпится перейти к SIMD-коду… Я вас не виню. Мы дойдем до этого. Однако в качестве промежуточного шага давайте посмотрим, можем ли мы улучшить этот скалярный код.

Как мы видели в моей предыдущей статье в блоге, неправильное предсказание ветвей может действительно убить производительность процессора. Если наши данные довольно сбалансированы и случайны, все равно может быть полезно подправить наш код, чтобы удалить ветви.

Хитрость проста. Мы всегда будем копировать наш идентификатор в хвост выходного вектора, но увеличивать хвостовой указатель будем только в том случае, если мы действительно хотим сохранить этот идентификатор.
Теперь наш код выглядит следующим образом.

pub fn filter_vec_nobranch(
    input: &[u32],
    range: RangeInclusive<u32>,
    output: &mut Vec<u32>) {
    output.clear();
    output.resize(input.len(), 0u32);
    let mut output_len = 0;
    for (id, &el) in input.iter().enumerate() {
        output[output_len] = id as u32;
        output_len += if range.contains(&el) { 1 } else { 0 };
    }
    output.truncate(output_len);
}
Вход в полноэкранный режим Выход из полноэкранного режима

Обратите внимание, что я не поддался на призыв сирены использовать здесь небезопасный код.
Вероятно, это все равно не принесло бы пользы.

Наш новый код теперь работает со скоростью 300 миллионов u32s/сек.

AVX2

Как вы, вероятно, знаете, процессоры обычно предлагают специальные инструкции, которые работают с
несколькими частями данных одновременно. Эти инструкции выпускались в течение многих лет
в виде «наборов инструкций».

Набор инструкций работает с определенным регистром с заданной шириной бита. Например,
SSE2 работает с 128-битовыми регистрами. Некоторые инструкции будут обрабатывать такой регистр как целые числа 8×16 бит,
некоторые инструкции будут обрабатывать этот регистр как 4×32-битные целые числа и т.д.

Как программист, мы обычно не пишем ассемблер в явном виде. Вместо этого мы вызываем встроенные функции, называемые intrinsics. Компилятор затем заботится о том, чтобы выдать нужные инструкции ассемблера от нашего имени.
В дальнейшем я буду иногда использовать имя инструкции, а иногда имя intrinsics. Не волнуйтесь, я просто пытаюсь вас запутать.

AVX2 является разумной целью для нашей маленькой проблемы: в настоящее время AVX/AVX2 широко поддерживается в 64-битных x86 CPU.
Большинство x86 Cloud VM будут поддерживать этот набор инструкций, и вы можете ожидать, что каждый ноутбук AMD или Intel будет работать с ним.
Он оперирует 256-битными регистрами.

Давайте скроем ужасные подробности инструкций за парой
хорошо названных функций, и изложим идею нашей реализации SIMD.

// We store 8x32-bits integers in one 256 bits register.
const NUM_LANES: usize = 8;

pub fn filter_vec_avx2(
    input: &[u32],
    range: RangeInclusive<u32>,
    output: &mut Vec<u32>) {
    let mut output_tail = output;
    assert_eq!(input.len() % NUM_LANES, 0);j
    // As we read our input 8 ints at a time, we will store
    // their respective indexes in `ids`.
    let mut ids: __m128i = from_u32s([0, 1, 2, 3, 4, 5, 6, 7]);
    const SHIFT: __m128i = from_u32s([NUM_LANES as u32; NUM_LANES]]);
    for _ in 0..input.len() / 8 {
        // Load 8 ints into our SIMD register.
        let els: __m256i = load_unaligned(input[i * 8..].as_ptr());
        // Identify the elements in the register that we
        // want to keep as 8-bit bitset.
        let keeper_bitset: u8 = compute_filter_bitset(els, range);
        // Compact our elements, putting all of the elements
        // retained on the left.
        let filtered_els: __m256i = compact(els, keeper_bitset);
        // Write all our 8 elements into the output...
        store_unaligned(output_tail, filtered_els);

        let added_len = keeper_bitset.count_ones();
        output_tail = output_tail.offset(added_len);

        idx = op_add(ids, SHIFT);
    }
    let output_len = output_tail.offset_from(output) as usize;
    output.set_len(output_len);
}
Вход в полноэкранный режим Выход из полноэкранного режима

Давайте распакуем это.

Вполне ожидаемо, мы используем тот же трюк, что и в решении без ветвления.
На каждой итерации цикла мы загружаем и храним данные в одном SIMD-регистре, то есть 8 целых чисел.
Однако мы увеличиваем указатель output_tail в зависимости от количества сохраняемых элементов.

Мы храним ids наших текущих последовательных 8 элементов
в 256-битном регистре.

Затем нам нужен способ

  • определить набор значений, которые мы хотим сохранить, и вернуть его в виде 8-битного набора. Мы назвали эту операцию compute_filter_bitset.
  • переместить значения, сохраняемые этим набором битов, и переместить их в левую часть нашего регистра. Мы назвали эту операцию compact.

К сожалению, AVX2 не предлагает никаких очевидных способов реализации.

Программирование SIMD действительно похоже на игру в лего с мешком кирпичиков лего странной формы.
У Intel есть очень хороший справочный веб-сайт, на котором перечислены различные доступные инструкции.

Это длинный список. Давайте посмотрим, сможем ли мы отфильтровать этот список, чтобы сделать его более удобоваримым для нас.

Хотя регистр может быть использован как для операций с плавающей точкой, так и для целочисленных операций, в целочисленных операциях интринструкции будут ссылаться на связанный тип данных как __m256i. Это хорошее ключевое слово для нас (входящая ирония!).

Компактный

Начнем с compact.
В AVX2 нет инструкции для этого. В 128-битном мире,

позволяет вам применить перестановку к байтам вашего регистра.

Эта инструкция — самая популярная игрушка для всех. Она не только невероятно универсальна, но и имеет пропускную способность 1 и
задержка равна 1.

К сожалению, гримуар инструкций AVX2 не содержит такой инструкции.
Эквивалентная инструкция существует и называется vpshufd, но есть одна загвоздка: она применяет только
применяет только две перестановки в пределах двух 128-битных полос, что совершенно бесполезно для нас.
Это очень распространенный паттерн в инструкциях AVX2. Инструкции, пересекающие эту страшную 128-битную полосу, встречаются редко.
Это частый источник головной боли для разработчиков.

К счастью, применение перестановки над u32s (что нам и нужно) действительно возможно,
с помощью инструкции VPERMPS __mm256_permutevar8x32_epi32.

Эта инструкция имеет большую задержку, чем PSHUB (она имеет задержку 3), но это не повлияло на общую пропускную способность.

Другая проблема, с которой мы столкнулись, заключается в том, что она ожидает отображение в качестве входных данных, в то время как то, что мы имеем в руках, является набором битов.
Нам нужен какой-то способ эффективно преобразовать наши keeper_bitsetsets в отображение, которое ожидает наша инструкция.

Это обычная проблема с обычным решением. Наша 8-битная битовая маска может принимать только 256 значений, поэтому мы
можем просто предварительно вычислить это отображение в массив из 256 значений и передать его вместе с кодом. Преобразование набора битов в отображение будет состоять из простого перебора.

Теперь наш код выглядит следующим образом:


const BITSET_TO_MAPPING: [__m256i] = /* ... */;

#[inline]
unsafe fn compact(data: __m256i, mask: u8) -> DataType {
    let vperm_mask = BITSET_TO_MAPPING[mask as usize];
    _mm256_permutevar8x32_epi32(data, vperm_mask)
}
Вход в полноэкранный режим Выход из полноэкранного режима

На этом этапе нам все еще нужно объявить наш массив const BITSET_TO_MAPPING размером __m256i, но rust не дает нам возможности написать для него литерал.

К счастью, @burntsushi (спасибо) поделился обходным решением в выпуске rust-lang на github. Хитрость заключается в том, чтобы переинтерпретировать (нетегированный) союз для этого.

const fn from_u32x8(vals: [u32; 8]) -> __m256i {
    union U8x32 {
        vector: DataType,
        vals: [u32; __m256i],
    }
    unsafe { U8x32 { vals }.vector }
}
Войти в полноэкранный режим Выход из полноэкранного режима

compute_filter_bitset

Набор битов вычислительного фильтра, как ни странно, вызвал у меня больше всего проблем.

В AVX256 можно сравнивать 8 пар i32 за раз, используя интринсику use _mm256_cmpgt_epi32.

Первое препятствие: мы бы предпочли сравнивать беззнаковые инты. Для моего случая использования мне не требовалась поддержка всего диапазона u32, поэтому я просто добавил в код утверждение для ограничения значений до 0..i32::MAX.
Я считаю, что распространенным обходным решением является переворачивание старшего бита перед сравнением с помощью xor.
Эта операция действительно является монотонным отображением из u32 в i32.

Второе препятствие, наш оператор на самом деле строго больше, а мы бы предпочли больше или равно.
Для решения этой проблемы я использовал отрицание по маске. $$neg left(a > b right) Leftrightarrow a leq b$$$.

Последняя проблема заключается в том, что выход __mm256_cmpgt_epi32 не является 8-битной маской.
Вместо этого, он оставляет нам значение __m256, где все биты в данной 32-битной полосе равны 1, если
левый операнд был больше правого операнда для данной полосы.

Я не смог найти операцию для сбора одного бита из каждой полосы для формирования однобайтовой маски.
Я использовал _mm256_movemask_epi8 для сбора msb всех байтов, а затем несколько строк скалярного кода
чтобы извлечь из этого целого числа мое целое число с маской 8 бит 32 бита. Это было медленно и казалось неубедительным. Наверняка был способ получше.

Тогда я сделал то, что должен сделать каждый здравомыслящий инженер. Я спросил у Twitter (ладно, я немного вру, на момент написания твита я играл с SSE2).

Твит превратился в волшебный поток, в котором многие разработчики начали делиться своими трюками.
Если вам столько же лет, сколько и мне, вы, возможно, помните, как зарождался интернет.
Во Франции интернет-провайдеры крутили рекламные ролики, в которых рассказывали о том, что интернет — не их услуга в частности — просто потрясающий. Этот поток имеет те же вибрации, что и эти рекламные ролики.

В конце концов, Джефф Лэнгдейл пришел с идеальным ответом. Идеальная инструкция существует. Инструкция, которая извлекает самое важное из наших 8 x 32 битовых полос и возвращает один байт. Инструкция называется VMOVMSKPS. Я не мог найти ее, потому что она представлена как инструкция для извлечения знака из кучи 32-битных плавающих чисел с плавающей запятой.
Конечно, я и не думал заглядывать в интринсики float!

Большой привет Джеффу. Если вы дочитали до этого места, вам, вероятно, стоит заглянуть в его удивительный блог, в котором он обсуждает оптимизацию процессора и программирование SIMD.

Вот наша функция в конце.

unsafe fn compute_filter_bitset(
    val: __m256i,
    range: RangeInclusive<__m256i>) -> u8 {
    let too_low = op_greater(*range.start(), val);
    let too_high = op_greater(val,*range.end());
    let outside: _m256 = transmute::<_m256i, _m256>(op_or(too_low, too_high));
    255u8 - _mm256_movemask_ps(outside) as u8
}
Войти в полноэкранный режим Выход из полноэкранного режима

Собираем все вместе

use std::arch::x86_64::_mm256_add_epi32 as op_add;
use std::arch::x86_64::_mm256_cmpgt_epi32 as op_greater;
use std::arch::x86_64::_mm256_lddqu_si256 as load_unaligned;
use std::arch::x86_64::_mm256_storeu_si256 as store_unaligned;
use std::arch::x86_64::_mm256_or_si256 as op_or;
use std::arch::x86_64::_mm256_set1_epi32 as set1;
use std::arch::x86_64::_mm256_movemask_ps as extract_msb;
use std::arch::x86_64::_mm256_permutevar8x32_epi32 as permute;
use std::arch::x86_64::{__m256, __m256i};
use std::ops::RangeInclusive;
use std::mem::transmute;

const NUM_LANES: usize = 8;

pub fn filter_vec(
    input: &[u32],
    range: RangeInclusive<u32>,
    output: &mut Vec<u32>) {
    assert_eq!(input.len() % NUM_LANES, 0);
    // We restrict the accepted boundary, because unsigned integers & SIMD don't
    // play well.
    let accepted_range = 0u32..(i32::MAX as u32);
    assert!(accepted_range.contains(range.start()));
    assert!(accepted_range.contains(range.end()));
    output.clear();
    output.reserve(input.len());
    let num_words = input.len() / NUM_LANES;
    unsafe {
        let output_len = filter_vec_avx2_aux(
            input.as_ptr() as *const __m256i,
            range,
            output.as_mut_ptr(),
            num_words,
        );
        output.set_len(output_len);
    }
}

unsafe fn filter_vec_avx2_aux(
    mut input: *const __m256i,
    range: RangeInclusive<u32>,
    output: *mut u32,
    num_words: usize,
) -> usize {
    let mut output_tail = output;
    let range_simd =
        set1(*range.start() as i32)..=set1(*range.end() as i32);
    let mut ids = from_u32x8([0, 1, 2, 3, 4, 5, 6, 7]);
    const SHIFT: __m256i = from_u32x8([NUM_LANES as u32; NUM_LANES]);
    for _ in 0..num_words {
        let word = load_unaligned(input);
        let keeper_bitset = compute_filter_bitset(word, range_simd.clone());
        let added_len = keeper_bitset.count_ones();
        let filtered_doc_ids = compact(ids, keeper_bitset);
        store_unaligned(
            output_tail as *mut __m256i,
            filtered_doc_ids,
        );
        output_tail = output_tail.offset(added_len as isize);
        ids = op_add(ids, SHIFT);
        input = input.offset(1);
    }
    output_tail.offset_from(output) as usize
}

#[inline]
unsafe fn compact(data: __m256i, mask: u8) -> __m256i {
    let vperm_mask = BITSET_TO_MAPPING[mask as usize];
    permute(data, vperm_mask)
}

#[inline]
unsafe fn compute_filter_bitset(
    val: __m256i,
    range: RangeInclusive<__m256i>) -> u8 {
    let too_low = op_greater(*range.start(), val);
    let too_high = op_greater(val,*range.end());
    let outside: __m256 = transmute::<__m256i, __m256>(op_or(too_low, too_high));
    255u8 - extract_msb(outside) as u8
}

const fn from_u32x8(vals: [u32; NUM_LANES]) -> __m256i {
    union U8x32 {
        vector: __m256i,
        vals: [u32; NUM_LANES],
    }
    unsafe { U8x32 { vals }.vector }
}

const BITSET_TO_MAPPING: [__m256i; 256] = [
    from_u32x8([0, 0, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 0, 0, 0, 0, 0, 0, 0]),
    from_u32x8([1, 0, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 1, 0, 0, 0, 0, 0, 0]),
    from_u32x8([2, 0, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 2, 0, 0, 0, 0, 0, 0]),
    from_u32x8([1, 2, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 1, 2, 0, 0, 0, 0, 0]),
    from_u32x8([3, 0, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 3, 0, 0, 0, 0, 0, 0]),
    from_u32x8([1, 3, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 1, 3, 0, 0, 0, 0, 0]),
    from_u32x8([2, 3, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 2, 3, 0, 0, 0, 0, 0]),
    from_u32x8([1, 2, 3, 0, 0, 0, 0, 0]),
    from_u32x8([0, 1, 2, 3, 0, 0, 0, 0]),
    from_u32x8([4, 0, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 4, 0, 0, 0, 0, 0, 0]),
    from_u32x8([1, 4, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 1, 4, 0, 0, 0, 0, 0]),
    from_u32x8([2, 4, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 2, 4, 0, 0, 0, 0, 0]),
    from_u32x8([1, 2, 4, 0, 0, 0, 0, 0]),
    from_u32x8([0, 1, 2, 4, 0, 0, 0, 0]),
    from_u32x8([3, 4, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 3, 4, 0, 0, 0, 0, 0]),
    from_u32x8([1, 3, 4, 0, 0, 0, 0, 0]),
    from_u32x8([0, 1, 3, 4, 0, 0, 0, 0]),
    from_u32x8([2, 3, 4, 0, 0, 0, 0, 0]),
    from_u32x8([0, 2, 3, 4, 0, 0, 0, 0]),
    from_u32x8([1, 2, 3, 4, 0, 0, 0, 0]),
    from_u32x8([0, 1, 2, 3, 4, 0, 0, 0]),
    from_u32x8([5, 0, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 5, 0, 0, 0, 0, 0, 0]),
    from_u32x8([1, 5, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 1, 5, 0, 0, 0, 0, 0]),
    from_u32x8([2, 5, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 2, 5, 0, 0, 0, 0, 0]),
    from_u32x8([1, 2, 5, 0, 0, 0, 0, 0]),
    from_u32x8([0, 1, 2, 5, 0, 0, 0, 0]),
    from_u32x8([3, 5, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 3, 5, 0, 0, 0, 0, 0]),
    from_u32x8([1, 3, 5, 0, 0, 0, 0, 0]),
    from_u32x8([0, 1, 3, 5, 0, 0, 0, 0]),
    from_u32x8([2, 3, 5, 0, 0, 0, 0, 0]),
    from_u32x8([0, 2, 3, 5, 0, 0, 0, 0]),
    from_u32x8([1, 2, 3, 5, 0, 0, 0, 0]),
    from_u32x8([0, 1, 2, 3, 5, 0, 0, 0]),
    from_u32x8([4, 5, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 4, 5, 0, 0, 0, 0, 0]),
    from_u32x8([1, 4, 5, 0, 0, 0, 0, 0]),
    from_u32x8([0, 1, 4, 5, 0, 0, 0, 0]),
    from_u32x8([2, 4, 5, 0, 0, 0, 0, 0]),
    from_u32x8([0, 2, 4, 5, 0, 0, 0, 0]),
    from_u32x8([1, 2, 4, 5, 0, 0, 0, 0]),
    from_u32x8([0, 1, 2, 4, 5, 0, 0, 0]),
    from_u32x8([3, 4, 5, 0, 0, 0, 0, 0]),
    from_u32x8([0, 3, 4, 5, 0, 0, 0, 0]),
    from_u32x8([1, 3, 4, 5, 0, 0, 0, 0]),
    from_u32x8([0, 1, 3, 4, 5, 0, 0, 0]),
    from_u32x8([2, 3, 4, 5, 0, 0, 0, 0]),
    from_u32x8([0, 2, 3, 4, 5, 0, 0, 0]),
    from_u32x8([1, 2, 3, 4, 5, 0, 0, 0]),
    from_u32x8([0, 1, 2, 3, 4, 5, 0, 0]),
    from_u32x8([6, 0, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 6, 0, 0, 0, 0, 0, 0]),
    from_u32x8([1, 6, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 1, 6, 0, 0, 0, 0, 0]),
    from_u32x8([2, 6, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 2, 6, 0, 0, 0, 0, 0]),
    from_u32x8([1, 2, 6, 0, 0, 0, 0, 0]),
    from_u32x8([0, 1, 2, 6, 0, 0, 0, 0]),
    from_u32x8([3, 6, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 3, 6, 0, 0, 0, 0, 0]),
    from_u32x8([1, 3, 6, 0, 0, 0, 0, 0]),
    from_u32x8([0, 1, 3, 6, 0, 0, 0, 0]),
    from_u32x8([2, 3, 6, 0, 0, 0, 0, 0]),
    from_u32x8([0, 2, 3, 6, 0, 0, 0, 0]),
    from_u32x8([1, 2, 3, 6, 0, 0, 0, 0]),
    from_u32x8([0, 1, 2, 3, 6, 0, 0, 0]),
    from_u32x8([4, 6, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 4, 6, 0, 0, 0, 0, 0]),
    from_u32x8([1, 4, 6, 0, 0, 0, 0, 0]),
    from_u32x8([0, 1, 4, 6, 0, 0, 0, 0]),
    from_u32x8([2, 4, 6, 0, 0, 0, 0, 0]),
    from_u32x8([0, 2, 4, 6, 0, 0, 0, 0]),
    from_u32x8([1, 2, 4, 6, 0, 0, 0, 0]),
    from_u32x8([0, 1, 2, 4, 6, 0, 0, 0]),
    from_u32x8([3, 4, 6, 0, 0, 0, 0, 0]),
    from_u32x8([0, 3, 4, 6, 0, 0, 0, 0]),
    from_u32x8([1, 3, 4, 6, 0, 0, 0, 0]),
    from_u32x8([0, 1, 3, 4, 6, 0, 0, 0]),
    from_u32x8([2, 3, 4, 6, 0, 0, 0, 0]),
    from_u32x8([0, 2, 3, 4, 6, 0, 0, 0]),
    from_u32x8([1, 2, 3, 4, 6, 0, 0, 0]),
    from_u32x8([0, 1, 2, 3, 4, 6, 0, 0]),
    from_u32x8([5, 6, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 5, 6, 0, 0, 0, 0, 0]),
    from_u32x8([1, 5, 6, 0, 0, 0, 0, 0]),
    from_u32x8([0, 1, 5, 6, 0, 0, 0, 0]),
    from_u32x8([2, 5, 6, 0, 0, 0, 0, 0]),
    from_u32x8([0, 2, 5, 6, 0, 0, 0, 0]),
    from_u32x8([1, 2, 5, 6, 0, 0, 0, 0]),
    from_u32x8([0, 1, 2, 5, 6, 0, 0, 0]),
    from_u32x8([3, 5, 6, 0, 0, 0, 0, 0]),
    from_u32x8([0, 3, 5, 6, 0, 0, 0, 0]),
    from_u32x8([1, 3, 5, 6, 0, 0, 0, 0]),
    from_u32x8([0, 1, 3, 5, 6, 0, 0, 0]),
    from_u32x8([2, 3, 5, 6, 0, 0, 0, 0]),
    from_u32x8([0, 2, 3, 5, 6, 0, 0, 0]),
    from_u32x8([1, 2, 3, 5, 6, 0, 0, 0]),
    from_u32x8([0, 1, 2, 3, 5, 6, 0, 0]),
    from_u32x8([4, 5, 6, 0, 0, 0, 0, 0]),
    from_u32x8([0, 4, 5, 6, 0, 0, 0, 0]),
    from_u32x8([1, 4, 5, 6, 0, 0, 0, 0]),
    from_u32x8([0, 1, 4, 5, 6, 0, 0, 0]),
    from_u32x8([2, 4, 5, 6, 0, 0, 0, 0]),
    from_u32x8([0, 2, 4, 5, 6, 0, 0, 0]),
    from_u32x8([1, 2, 4, 5, 6, 0, 0, 0]),
    from_u32x8([0, 1, 2, 4, 5, 6, 0, 0]),
    from_u32x8([3, 4, 5, 6, 0, 0, 0, 0]),
    from_u32x8([0, 3, 4, 5, 6, 0, 0, 0]),
    from_u32x8([1, 3, 4, 5, 6, 0, 0, 0]),
    from_u32x8([0, 1, 3, 4, 5, 6, 0, 0]),
    from_u32x8([2, 3, 4, 5, 6, 0, 0, 0]),
    from_u32x8([0, 2, 3, 4, 5, 6, 0, 0]),
    from_u32x8([1, 2, 3, 4, 5, 6, 0, 0]),
    from_u32x8([0, 1, 2, 3, 4, 5, 6, 0]),
    from_u32x8([7, 0, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 7, 0, 0, 0, 0, 0, 0]),
    from_u32x8([1, 7, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 1, 7, 0, 0, 0, 0, 0]),
    from_u32x8([2, 7, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 2, 7, 0, 0, 0, 0, 0]),
    from_u32x8([1, 2, 7, 0, 0, 0, 0, 0]),
    from_u32x8([0, 1, 2, 7, 0, 0, 0, 0]),
    from_u32x8([3, 7, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 3, 7, 0, 0, 0, 0, 0]),
    from_u32x8([1, 3, 7, 0, 0, 0, 0, 0]),
    from_u32x8([0, 1, 3, 7, 0, 0, 0, 0]),
    from_u32x8([2, 3, 7, 0, 0, 0, 0, 0]),
    from_u32x8([0, 2, 3, 7, 0, 0, 0, 0]),
    from_u32x8([1, 2, 3, 7, 0, 0, 0, 0]),
    from_u32x8([0, 1, 2, 3, 7, 0, 0, 0]),
    from_u32x8([4, 7, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 4, 7, 0, 0, 0, 0, 0]),
    from_u32x8([1, 4, 7, 0, 0, 0, 0, 0]),
    from_u32x8([0, 1, 4, 7, 0, 0, 0, 0]),
    from_u32x8([2, 4, 7, 0, 0, 0, 0, 0]),
    from_u32x8([0, 2, 4, 7, 0, 0, 0, 0]),
    from_u32x8([1, 2, 4, 7, 0, 0, 0, 0]),
    from_u32x8([0, 1, 2, 4, 7, 0, 0, 0]),
    from_u32x8([3, 4, 7, 0, 0, 0, 0, 0]),
    from_u32x8([0, 3, 4, 7, 0, 0, 0, 0]),
    from_u32x8([1, 3, 4, 7, 0, 0, 0, 0]),
    from_u32x8([0, 1, 3, 4, 7, 0, 0, 0]),
    from_u32x8([2, 3, 4, 7, 0, 0, 0, 0]),
    from_u32x8([0, 2, 3, 4, 7, 0, 0, 0]),
    from_u32x8([1, 2, 3, 4, 7, 0, 0, 0]),
    from_u32x8([0, 1, 2, 3, 4, 7, 0, 0]),
    from_u32x8([5, 7, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 5, 7, 0, 0, 0, 0, 0]),
    from_u32x8([1, 5, 7, 0, 0, 0, 0, 0]),
    from_u32x8([0, 1, 5, 7, 0, 0, 0, 0]),
    from_u32x8([2, 5, 7, 0, 0, 0, 0, 0]),
    from_u32x8([0, 2, 5, 7, 0, 0, 0, 0]),
    from_u32x8([1, 2, 5, 7, 0, 0, 0, 0]),
    from_u32x8([0, 1, 2, 5, 7, 0, 0, 0]),
    from_u32x8([3, 5, 7, 0, 0, 0, 0, 0]),
    from_u32x8([0, 3, 5, 7, 0, 0, 0, 0]),
    from_u32x8([1, 3, 5, 7, 0, 0, 0, 0]),
    from_u32x8([0, 1, 3, 5, 7, 0, 0, 0]),
    from_u32x8([2, 3, 5, 7, 0, 0, 0, 0]),
    from_u32x8([0, 2, 3, 5, 7, 0, 0, 0]),
    from_u32x8([1, 2, 3, 5, 7, 0, 0, 0]),
    from_u32x8([0, 1, 2, 3, 5, 7, 0, 0]),
    from_u32x8([4, 5, 7, 0, 0, 0, 0, 0]),
    from_u32x8([0, 4, 5, 7, 0, 0, 0, 0]),
    from_u32x8([1, 4, 5, 7, 0, 0, 0, 0]),
    from_u32x8([0, 1, 4, 5, 7, 0, 0, 0]),
    from_u32x8([2, 4, 5, 7, 0, 0, 0, 0]),
    from_u32x8([0, 2, 4, 5, 7, 0, 0, 0]),
    from_u32x8([1, 2, 4, 5, 7, 0, 0, 0]),
    from_u32x8([0, 1, 2, 4, 5, 7, 0, 0]),
    from_u32x8([3, 4, 5, 7, 0, 0, 0, 0]),
    from_u32x8([0, 3, 4, 5, 7, 0, 0, 0]),
    from_u32x8([1, 3, 4, 5, 7, 0, 0, 0]),
    from_u32x8([0, 1, 3, 4, 5, 7, 0, 0]),
    from_u32x8([2, 3, 4, 5, 7, 0, 0, 0]),
    from_u32x8([0, 2, 3, 4, 5, 7, 0, 0]),
    from_u32x8([1, 2, 3, 4, 5, 7, 0, 0]),
    from_u32x8([0, 1, 2, 3, 4, 5, 7, 0]),
    from_u32x8([6, 7, 0, 0, 0, 0, 0, 0]),
    from_u32x8([0, 6, 7, 0, 0, 0, 0, 0]),
    from_u32x8([1, 6, 7, 0, 0, 0, 0, 0]),
    from_u32x8([0, 1, 6, 7, 0, 0, 0, 0]),
    from_u32x8([2, 6, 7, 0, 0, 0, 0, 0]),
    from_u32x8([0, 2, 6, 7, 0, 0, 0, 0]),
    from_u32x8([1, 2, 6, 7, 0, 0, 0, 0]),
    from_u32x8([0, 1, 2, 6, 7, 0, 0, 0]),
    from_u32x8([3, 6, 7, 0, 0, 0, 0, 0]),
    from_u32x8([0, 3, 6, 7, 0, 0, 0, 0]),
    from_u32x8([1, 3, 6, 7, 0, 0, 0, 0]),
    from_u32x8([0, 1, 3, 6, 7, 0, 0, 0]),
    from_u32x8([2, 3, 6, 7, 0, 0, 0, 0]),
    from_u32x8([0, 2, 3, 6, 7, 0, 0, 0]),
    from_u32x8([1, 2, 3, 6, 7, 0, 0, 0]),
    from_u32x8([0, 1, 2, 3, 6, 7, 0, 0]),
    from_u32x8([4, 6, 7, 0, 0, 0, 0, 0]),
    from_u32x8([0, 4, 6, 7, 0, 0, 0, 0]),
    from_u32x8([1, 4, 6, 7, 0, 0, 0, 0]),
    from_u32x8([0, 1, 4, 6, 7, 0, 0, 0]),
    from_u32x8([2, 4, 6, 7, 0, 0, 0, 0]),
    from_u32x8([0, 2, 4, 6, 7, 0, 0, 0]),
    from_u32x8([1, 2, 4, 6, 7, 0, 0, 0]),
    from_u32x8([0, 1, 2, 4, 6, 7, 0, 0]),
    from_u32x8([3, 4, 6, 7, 0, 0, 0, 0]),
    from_u32x8([0, 3, 4, 6, 7, 0, 0, 0]),
    from_u32x8([1, 3, 4, 6, 7, 0, 0, 0]),
    from_u32x8([0, 1, 3, 4, 6, 7, 0, 0]),
    from_u32x8([2, 3, 4, 6, 7, 0, 0, 0]),
    from_u32x8([0, 2, 3, 4, 6, 7, 0, 0]),
    from_u32x8([1, 2, 3, 4, 6, 7, 0, 0]),
    from_u32x8([0, 1, 2, 3, 4, 6, 7, 0]),
    from_u32x8([5, 6, 7, 0, 0, 0, 0, 0]),
    from_u32x8([0, 5, 6, 7, 0, 0, 0, 0]),
    from_u32x8([1, 5, 6, 7, 0, 0, 0, 0]),
    from_u32x8([0, 1, 5, 6, 7, 0, 0, 0]),
    from_u32x8([2, 5, 6, 7, 0, 0, 0, 0]),
    from_u32x8([0, 2, 5, 6, 7, 0, 0, 0]),
    from_u32x8([1, 2, 5, 6, 7, 0, 0, 0]),
    from_u32x8([0, 1, 2, 5, 6, 7, 0, 0]),
    from_u32x8([3, 5, 6, 7, 0, 0, 0, 0]),
    from_u32x8([0, 3, 5, 6, 7, 0, 0, 0]),
    from_u32x8([1, 3, 5, 6, 7, 0, 0, 0]),
    from_u32x8([0, 1, 3, 5, 6, 7, 0, 0]),
    from_u32x8([2, 3, 5, 6, 7, 0, 0, 0]),
    from_u32x8([0, 2, 3, 5, 6, 7, 0, 0]),
    from_u32x8([1, 2, 3, 5, 6, 7, 0, 0]),
    from_u32x8([0, 1, 2, 3, 5, 6, 7, 0]),
    from_u32x8([4, 5, 6, 7, 0, 0, 0, 0]),
    from_u32x8([0, 4, 5, 6, 7, 0, 0, 0]),
    from_u32x8([1, 4, 5, 6, 7, 0, 0, 0]),
    from_u32x8([0, 1, 4, 5, 6, 7, 0, 0]),
    from_u32x8([2, 4, 5, 6, 7, 0, 0, 0]),
    from_u32x8([0, 2, 4, 5, 6, 7, 0, 0]),
    from_u32x8([1, 2, 4, 5, 6, 7, 0, 0]),
    from_u32x8([0, 1, 2, 4, 5, 6, 7, 0]),
    from_u32x8([3, 4, 5, 6, 7, 0, 0, 0]),
    from_u32x8([0, 3, 4, 5, 6, 7, 0, 0]),
    from_u32x8([1, 3, 4, 5, 6, 7, 0, 0]),
    from_u32x8([0, 1, 3, 4, 5, 6, 7, 0]),
    from_u32x8([2, 3, 4, 5, 6, 7, 0, 0]),
    from_u32x8([0, 2, 3, 4, 5, 6, 7, 0]),
    from_u32x8([1, 2, 3, 4, 5, 6, 7, 0]),
    from_u32x8([0, 1, 2, 3, 4, 5, 6, 7]),
];
Вход в полноэкранный режим Выход из полноэкранного режима

Наш код AVX2 обрабатывает 3,65 миллиарда элементов в секунду.
Здорово. Это в 8 раз быстрее, чем скалярное решение без ветвей.

AVX-512 на помощь

Последние процессоры Intel и грядущие процессоры AMD Zen 4 работают с новым набором инструкций
под названием AVX-512. Он оперирует целыми числами 16 x 32 бита за раз.

Но AVX-512 — это не просто удвоение числа дорожек для игры!
Он добавляет множество новых инструкций, чтобы облегчить нам жизнь.

Например, _mm512_cmple_epi32_mask — это внутренняя инструкция, которая сравнивает целые числа 16 x 32 бита.
двух __m512i. Теперь на выходе получается именно то, что мы ожидали: это 16-битный набор бит.

Что насчет сжатия и нашей уродливой карты? AVX-512 поставляется с инструкцией (intrinsics называется _mm512_mask_compressstoreu_epi32), которая делает именно то, что мы хотим!

use std::arch::x86_64::_mm512_add_epi32 as op_add;
use std::arch::x86_64::_mm512_cmple_epi32_mask as op_less_or_equal;
use std::arch::x86_64::_mm512_loadu_epi32 as load_unaligned;
use std::arch::x86_64::_mm512_set1_epi32 as set1;
use std::arch::x86_64::_mm512_mask_compressstoreu_epi32 as compress;
use std::arch::x86_64::__m512i;
use std::ops::RangeInclusive;

const NUM_LANES: usize = 16;

pub fn filter_vec(input: &[u32], range: RangeInclusive<u32>, output: &mut Vec<u32>) {
    assert_eq!(input.len() % NUM_LANES, 0);
    // We restrict the accepted boundary, because unsigned integers & SIMD don't
    // play well.
    let accepted_range = 0u32..(i32::MAX as u32);
    assert!(accepted_range.contains(range.start()));
    assert!(accepted_range.contains(range.end()));
    output.clear();
    output.reserve(input.len());
    let num_words = input.len() / NUM_LANES;
    unsafe {
        let output_len = filter_vec_aux(
            input.as_ptr(),
            range,
            output.as_mut_ptr(),
            num_words,
        );
        output.set_len(output_len);
    }
}

pub unsafe fn filter_vec_aux(
    mut input: *const u32,
    range: RangeInclusive<u32>,
    output: *mut u32,
    num_words: usize,
) -> usize {
    let mut output_end = output;
    let range_simd =
        set1(*range.start() as i32)..=set1(*range.end() as i32);
    let mut ids = from_u32x16([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
    const SHIFT: __m512i = from_u32x16([NUM_LANES as u32; NUM_LANES]);
    for _ in 0..num_words {
        let word = load_unaligned(input as *const i32);
        let keeper_bitset = compute_filter_bitset(word, range_simd.clone());
        compress(output_end as *mut u8, keeper_bitset, ids);
        let added_len = keeper_bitset.count_ones();
        output_end = output_end.offset(added_len as isize);
        ids = op_add(ids, SHIFT);
        input = input.offset(1);
    }
    output_end.offset_from(output) as usize
}

#[inline]
unsafe fn compute_filter_bitset(
    val: __m512i,
    range: RangeInclusive<__m512i>) -> u16 {
    let low = op_less_or_equal(*range.start(), val);
    let high = op_less_or_equal(val, *range.end());
    low & high
}

const fn from_u32x16(vals: [u32; NUM_LANES]) -> __m512i {
    union U8x64 {
        vector: __m512i,
        vals: [u32; NUM_LANES],
    }
    unsafe { U8x64 { vals }.vector }
}

Вход в полноэкранный режим Выход из полноэкранного режима

Эта версия значительно проще и теперь работает со скоростью 8,6 миллиардов целых чисел в секунду.

Заключение

AVX512 действительно высасывает удовольствие из SIMD-программирования.
В итоге мой код стал значительно проще и быстрее.

Спасибо всем, кто прочитал эту статью в блоге! Изначально она была написана нашим соучредителем Полом Мазурелом.

Оцените статью
devanswers.ru
Добавить комментарий