From c5ae55c3f1ad89d46143873482695d8f695b9376 Mon Sep 17 00:00:00 2001 From: manoflearning <77jwk0724@gmail.com> Date: Wed, 19 Nov 2025 20:38:38 +0900 Subject: [PATCH] feat: add ntt --- src/6-math/fft.cpp | 46 ++++++++++++++++++++++++++++++++++++++++++- src/common/common.hpp | 1 + 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/src/6-math/fft.cpp b/src/6-math/fft.cpp index 22a9265..7c75d35 100644 --- a/src/6-math/fft.cpp +++ b/src/6-math/fft.cpp @@ -1,5 +1,4 @@ #include "../common/common.hpp" - namespace fft { using real_t = double; using cpx = complex; @@ -78,3 +77,48 @@ vector multiply_mod(const vector &a, const vector &b, const ull mod) return ret; } } // namespace fft +namespace ntt { +constexpr ll MOD = (119 << 23) + 1, root = 3; // = 998244353 +// For p < 2^30 there is also e.g. 5 << 25, 7 << 26, 479 << 21 +// and 483 << 21 (same root). The last two are > 10^9. +ll modpow(ll b, ll e) { + ll ans = 1; + for (; e; b = b * b % MOD, e /= 2) + if (e & 1) ans = ans * b % MOD; + return ans; +} +void ntt(vector &a) { + int n = sz(a), L = 31 - __builtin_clz(n); + static vector rt(2, 1); + for (static int k = 2, s = 2; k < n; k *= 2, s++) { + rt.resize(n); + ll z[] = {1, modpow(root, MOD >> s)}; + for (int i = k; i < 2 * k; i++) + rt[i] = rt[i / 2] * z[i & 1] % MOD; + } + vector rev(n); + for (int i = 0; i < n; i++) rev[i] = (rev[i / 2] | (i & 1) << L) / 2; + for (int i = 0; i < n; i++) + if (i < rev[i]) swap(a[i], a[rev[i]]); + for (int k = 1; k < n; k *= 2) + for (int i = 0; i < n; i += 2 * k) + for (int j = 0; j < k; j++) { + ll z = rt[j + k] * a[i + j + k] % MOD, &ai = a[i + j]; + a[i + j + k] = ai - z + (z > ai ? MOD : 0); + ai += (ai + z >= MOD ? z - MOD : z); + } +} +vector multiply(const vector &a, const vector &b) { + if (a.empty() || b.empty()) return {}; + int s = sz(a) + sz(b) - 1, B = 32 - __builtin_clz(s), + n = 1 << B; + int inv = modpow(n, MOD - 2); + vector L(a), R(b), out(n); + L.resize(n), R.resize(n); + ntt(L), ntt(R); + for (int i = 0; i < n; i++) + out[-i & (n - 1)] = (ll)L[i] * R[i] % MOD * inv % MOD; + ntt(out); + return {out.begin(), out.begin() + s}; +} +} // namespace ntt \ No newline at end of file diff --git a/src/common/common.hpp b/src/common/common.hpp index c085fc3..e28326e 100644 --- a/src/common/common.hpp +++ b/src/common/common.hpp @@ -5,6 +5,7 @@ using namespace std; using ll = long long; using ld = long double; +using ull = unsigned long long; using pii = pair; using pll = pair;