NTT

2023 年 1 月 20 日

目录

多项式 NTT

#ifndef ALGO_MATH_POLY_NTT
#define ALGO_MATH_POLY_NTT

#include "../../base.hpp"
#include "../../other/modint/modint-concept.hpp"

#include <algorithm>
#include <bit>
#include <cassert>
#include <span>
#include <vector>

namespace detail {

// Running total of the span sizes handed to ntt() / intt(); each call adds
// f.size() to it.
//
// Declared `inline` because this definition lives in a header: without it,
// every translation unit that includes the header would emit its own
// definition and the program would fail to link (ODR violation).
inline u32 ntt_size = 0;

} // namespace detail

/////////////////////

// Backend selection.
//
// Exactly one implementation family is chosen at preprocessing time:
//   - neither ALGO_DISABLE_NTT_CLASSICAL nor ALGO_DISABLE_NTT_RADIX_4 defined
//       -> classical radix-4 (the default)
//   - ALGO_DISABLE_NTT_RADIX_4 defined (and classical not disabled)
//       -> classical radix-2
//   - ALGO_DISABLE_NTT_CLASSICAL defined
//       -> twisted radix-2 (ALGO_DISABLE_NTT_RADIX_4 is not consulted)
//
// Each branch defines ALGO_DETAIL_NTT / ALGO_DETAIL_INTT and includes the
// matching "basic" header. Unless ALGO_DISABLE_SIMD_AVX2 is defined, it also
// defines ALGO_DETAIL_NTT_AVX / ALGO_DETAIL_INTT_AVX and includes the matching
// AVX header. The macros are consumed by ntt() / intt() below.
#ifndef ALGO_DISABLE_NTT_CLASSICAL

#ifndef ALGO_DISABLE_NTT_RADIX_4

// classical-radix-4
#define ALGO_DETAIL_NTT detail::ntt_classical_basic4
#define ALGO_DETAIL_INTT detail::intt_classical_basic4
#include "ntt-classical-radix-4-basic.hpp"
#ifndef ALGO_DISABLE_SIMD_AVX2
#define ALGO_DETAIL_NTT_AVX detail::ntt_classical_avx4
#define ALGO_DETAIL_INTT_AVX detail::intt_classical_avx4
#include "ntt-classical-radix-4-avx.hpp"
#endif // ALGO_DISABLE_SIMD_AVX2

#else // ALGO_DISABLE_NTT_RADIX_4 is defined

// classical-radix-2
#define ALGO_DETAIL_NTT detail::ntt_classical_basic
#define ALGO_DETAIL_INTT detail::intt_classical_basic
#include "ntt-classical-radix-2-basic.hpp"
#ifndef ALGO_DISABLE_SIMD_AVX2
#define ALGO_DETAIL_NTT_AVX detail::ntt_classical_avx
#define ALGO_DETAIL_INTT_AVX detail::intt_classical_avx
#include "ntt-classical-radix-2-avx.hpp"
#endif // ALGO_DISABLE_SIMD_AVX2

#endif // ALGO_DISABLE_NTT_RADIX_4

#else // ALGO_DISABLE_NTT_CLASSICAL is defined

// twisted-radix-2
#define ALGO_DETAIL_NTT detail::ntt_twisted_basic
#define ALGO_DETAIL_INTT detail::intt_twisted_basic
#include "ntt-twisted-radix-2-basic.hpp"
#ifndef ALGO_DISABLE_SIMD_AVX2
#define ALGO_DETAIL_NTT_AVX detail::ntt_twisted_avx
#define ALGO_DETAIL_INTT_AVX detail::intt_twisted_avx
#include "ntt-twisted-radix-2-avx.hpp"
#endif // ALGO_DISABLE_SIMD_AVX2

#endif // ALGO_DISABLE_NTT_CLASSICAL

/// Dispatches the forward NTT over `f` to the backend selected by the
/// ALGO_DETAIL_NTT / ALGO_DETAIL_NTT_AVX macros.
///
/// @tparam ModT     element type of the span (a modular-integer type).
/// @tparam aligned  forwarded as the second template argument of the AVX
///                  kernel; not used when the scalar kernel is taken or when
///                  ALGO_DISABLE_SIMD_AVX2 is defined.
/// @param  f        data to transform; its size must be a power of two
///                  (checked with assert). The kernels receive the span
///                  itself, so the transform is presumably done in place.
///
/// Side effect: adds f.size() to detail::ntt_size.
template <class ModT, bool aligned = true>
void ntt(std::span<ModT> f) {
  assert(std::has_single_bit(f.size()));
  detail::ntt_size += f.size();
#ifndef ALGO_DISABLE_SIMD_AVX2
  // Whether ModT is a Montgomery modint is known at compile time, so gate the
  // AVX path with `if constexpr`: the AVX kernel is then instantiated only
  // for types that satisfy montgomery_modint_concept, instead of for every
  // ModT as a plain runtime `if` would require.
  if constexpr (montgomery_modint_concept<ModT>) {
    // Inputs of 16 elements or fewer stay on the scalar kernel.
    if (f.size() > 16) {
      ALGO_DETAIL_NTT_AVX<ModT, aligned>(f);
      return;
    }
  }
#endif
  ALGO_DETAIL_NTT(f);
}

/// Dispatches the inverse NTT over `f` to the backend selected by the
/// ALGO_DETAIL_INTT / ALGO_DETAIL_INTT_AVX macros.
///
/// @tparam ModT     element type of the span (a modular-integer type).
/// @tparam aligned  forwarded as the second template argument of the AVX
///                  kernel; not used when the scalar kernel is taken or when
///                  ALGO_DISABLE_SIMD_AVX2 is defined.
/// @param  f        data to transform; its size must be a power of two
///                  (checked with assert). The kernels receive the span
///                  itself, so the transform is presumably done in place.
///
/// Side effect: adds f.size() to detail::ntt_size.
template <class ModT, bool aligned = true>
void intt(std::span<ModT> f) {
  assert(std::has_single_bit(f.size()));
  detail::ntt_size += f.size();
#ifndef ALGO_DISABLE_SIMD_AVX2
  // Whether ModT is a Montgomery modint is known at compile time, so gate the
  // AVX path with `if constexpr`: the AVX kernel is then instantiated only
  // for types that satisfy montgomery_modint_concept, instead of for every
  // ModT as a plain runtime `if` would require.
  if constexpr (montgomery_modint_concept<ModT>) {
    // Inputs of 16 elements or fewer stay on the scalar kernel.
    if (f.size() > 16) {
      ALGO_DETAIL_INTT_AVX<ModT, aligned>(f);
      return;
    }
  }
#endif
  ALGO_DETAIL_INTT(f);
}

// The dispatch macros are only needed by ntt() / intt() above; remove them so
// they do not leak into code that includes this header. The *_AVX macros may
// not have been defined (ALGO_DISABLE_SIMD_AVX2); #undef on an undefined name
// is a no-op.
#undef ALGO_DETAIL_NTT
#undef ALGO_DETAIL_NTT_AVX
#undef ALGO_DETAIL_INTT
#undef ALGO_DETAIL_INTT_AVX

#endif // ALGO_MATH_POLY_NTT