diff --git a/src/6-math/basic_sqrt_time_algorithms.cpp b/src/6-math/0_sqrt_time_algorithms.hpp similarity index 52% rename from src/6-math/basic_sqrt_time_algorithms.cpp rename to src/6-math/0_sqrt_time_algorithms.hpp index 4a5149c..9e1b926 100644 --- a/src/6-math/basic_sqrt_time_algorithms.cpp +++ b/src/6-math/0_sqrt_time_algorithms.hpp @@ -1,35 +1,32 @@ +#pragma once #include "../common/common.hpp" - // 1. Finding Divisors in O(sqrt(x)) // INPUT: Given a natural number x. // OUTPUT: Find all the divisors of x. // TIME COMPLEXITY: O(sqrt(x)). -vector di; -void findingDivisors(int x) { - for (int i = 1; i <= sqrt(x); i++) { +vector get_div(ll x) { + vector ret; + for (ll i = 1; i * i <= x; i++) { if (x % i == 0) { - di.push_back(i); - if (x / i != i) di.push_back(x / i); + ret.push_back(i); + if (x / i != i) ret.push_back(x / i); } } - sort(di.begin(), di.end()); + sort(all(ret)); + return ret; } - // 2. Finding Prime Factorization in O(sqrt(x)) // INPUT: Given a natural number x. // OUTPUT: Find the result of the prime factorization of x. // TIME COMPLEXITY: O(sqrt(x)). -vector p; -void primeFactorization(int x) { - while (x % 2 == 0) { - x /= 2; - p.push_back(2); - } +vector factorize(ll x) { + vector ret; + while (x % 2 == 0) + x /= 2, ret.push_back(2); for (int i = 3; i <= sqrt(x); i += 2) { - while (x % i == 0) { - x /= i; - p.push_back(i); - } + while (x % i == 0) + x /= i, ret.push_back(i); } - if (x > 1) p.push_back(x); + if (x > 1) ret.push_back(x); + return ret; } \ No newline at end of file diff --git a/src/6-math/1_modint.hpp b/src/6-math/1_modint.hpp new file mode 100644 index 0000000..d323362 --- /dev/null +++ b/src/6-math/1_modint.hpp @@ -0,0 +1,69 @@ +#pragma once +#include "../common/common.hpp" +template +class modint { + static_assert(m > 0, "Modulus must be positive"); + +public: + static constexpr int mod() { return m; } + constexpr modint(ll y = 0) : x(y >= 0 ? y % m : (y % m + m) % m) {} + constexpr int val() const { return x; } + constexpr modint &operator+=(const modint &r) { + if ((x += r.x) >= m) x -= m; + return *this; + } + constexpr modint &operator-=(const modint &r) { + if ((x += m - r.x) >= m) x -= m; + return *this; + } + constexpr modint &operator*=(const modint &r) { + x = static_cast(1LL * x * r.x % m); + return *this; + } + constexpr modint &operator/=(const modint &r) { return *this *= r.inv(); } + constexpr bool operator==(const modint &r) const { return x == r.x; } + constexpr modint operator+() const { return *this; } + constexpr modint operator-() const { return modint{-x}; } + constexpr friend modint operator+(const modint &l, const modint &r) { + return modint{l} += r; + } + constexpr friend modint operator-(const modint &l, const modint &r) { + return modint{l} -= r; + } + constexpr friend modint operator*(const modint &l, const modint &r) { + return modint{l} *= r; + } + constexpr friend modint operator/(const modint &l, const modint &r) { + return modint{l} /= r; + } + constexpr modint inv() const { + int a = x, b = m, u = 1, v = 0; + while (b > 0) { + int t = a / b; + swap(a -= t * b, b); + swap(u -= t * v, v); + } + return modint{u}; + } + constexpr modint pow(ll n) const { + modint ret(1), mul(x); + while (n > 0) { + if (n & 1) ret *= mul; + mul *= mul; + n >>= 1; + } + return ret; + } + friend ostream &operator<<(ostream &os, const modint &r) { + return os << r.x; + } + friend istream &operator>>(istream &is, modint &r) { + ll t; + is >> t; + r = modint{t}; + return is; + } + +private: + int x; +}; diff --git a/src/6-math/euclidean_algorithms.cpp b/src/6-math/2_extgcd.hpp similarity index 81% rename from src/6-math/euclidean_algorithms.cpp rename to src/6-math/2_extgcd.hpp index daf9354..862d713 100644 --- a/src/6-math/euclidean_algorithms.cpp +++ b/src/6-math/2_extgcd.hpp @@ -1,3 +1,4 @@ +#pragma once #include "../common/common.hpp" // GCD, LCM @@ -40,8 +41,6 @@ ll lcm(ll a, ll b) { // TIME COMPLEXITY: O(log(AB)) -// BOJ 14565 AC Code -// https://www.acmicpc.net/problem/14565 pair egcd(ll a, ll b) { ll s = 0, olds = 1; ll t = 1, oldt = 0; @@ -59,21 +58,22 @@ pair egcd(ll a, ll b) { // oldr = gcd(a, b) return {{olds, oldt}, oldr}; } -ll linearCongruence(ll a, ll b, ll n) { // Find x such that ax = b (mod n). +ll linear_congruence(ll a, ll b, ll n) { // Find x such that ax = b (mod n). pair res = egcd(a, n); ll g = res.sc; // ax + ny = b has a solution iff gcd(a,n) | b. if (b % g) return -1; return (res.fr.fr * (b / g) % n + n) % n; } -ll modInv(ll a, ll p) { // Find x such that ax = 1 (mod p). - pair res = egcd(a, p); - // Modular inverse exists iff gcd(a, p) = 1. - if (res.sc == 1) return (res.fr.fr + p) % p; - else return -1; +ll inv_mod(ll a, ll p) { // Find x such that ax = 1 (mod p). + ll b = p, u = 1, v = 0; + while (b) { + ll t = a / b; + a -= t * b; + swap(a, b); + u -= t * v; + swap(u, v); + } + if (u < 0) u += p; + return u; } -int main() { - ll N, A; - cin >> N >> A; - cout << N - A << ' ' << modInv(A, N); -} \ No newline at end of file diff --git a/src/6-math/3_ntt.hpp b/src/6-math/3_ntt.hpp new file mode 100644 index 0000000..53a67b4 --- /dev/null +++ b/src/6-math/3_ntt.hpp @@ -0,0 +1,111 @@ +#pragma once +#include "../common/common.hpp" +#include "1_modint.hpp" +#include "2_extgcd.hpp" +constexpr int get_primitive_root(int mod) { + if (mod == 167772161) return 3; + if (mod == 469762049) return 3; + if (mod == 754974721) return 11; + if (mod == 998244353) return 3; + if (mod == 1224736769) return 3; +} +template +void ntt(vector &a) { + constexpr int mod = mint::mod(); + constexpr mint primitive_root = get_primitive_root(mod); + + const int n = sz(a); + for (int m = n; m > 1; m >>= 1) { + mint omega = primitive_root.pow((mod - 1) / m); + for (int s = 0; s < n / m; ++s) { + mint w = 1; + for (int i = 0; i < m / 2; ++i) { + mint l = a[s * m + i]; + mint r = a[s * m + i + m / 2]; + a[s * m + i] = l + r; + a[s * m + i + m / 2] = (l - r) * w; + w *= omega; + } + } + } +} +template +void intt(vector &a) { + constexpr int mod = mint::mod(); + constexpr mint primitive_root = get_primitive_root(mod); + + const int n = sz(a); + for (int m = 2; m <= n; m <<= 1) { + mint omega = primitive_root.pow((mod - 1) / m).inv(); + for (int s = 0; s < n / m; ++s) { + mint w = 1; + for (int i = 0; i < m / 2; ++i) { + mint l = a[s * m + i]; + mint r = a[s * m + i + m / 2] * w; + a[s * m + i] = l + r; + a[s * m + i + m / 2] = l - r; + w *= omega; + } + } + } +} +template +vector convolution(vector a, vector b) { + const int size = sz(a) + sz(b) - 1; + int n = 1; + while (n < size) n <<= 1; + a.resize(n), b.resize(n); + ntt(a), ntt(b); + for (int i = 0; i < n; ++i) a[i] *= b[i]; + intt(a); + a.resize(size); + mint n_inv = mint(n).inv(); + for (int i = 0; i < size; ++i) a[i] *= n_inv; + return a; +} +vector convolution_ll(const vector &a, + const vector &b) { + static constexpr ll m0 = 167772161; + static constexpr ll m1 = 469762049; + static constexpr ll m2 = 754974721; + static constexpr ll m01 = m0 * m1; + static constexpr __int128_t m012 = (__int128_t)m01 * m2; + + static const ll inv_m0_mod_m1 = inv_mod(m0 % m1, m1); + static const ll inv_m01_mod_m2 = inv_mod(m01 % m2, m2); + + using mint0 = modint; + using mint1 = modint; + using mint2 = modint; + + vector a0(all(a)), b0(all(b)); + vector a1(all(a)), b1(all(b)); + vector a2(all(a)), b2(all(b)); + + auto c0 = convolution(a0, b0); + auto c1 = convolution(a1, b1); + auto c2 = convolution(a2, b2); + + vector c(sz(c0)); + for (int i = 0; i < sz(c1); ++i) { + ll r0 = c0[i].val(); + ll r1 = c1[i].val(); + ll r2 = c2[i].val(); + + ll t1 = (r1 - r0) % m1; + if (t1 < 0) t1 += m1; + ll k1 = (ll)((__int128_t)t1 * inv_m0_mod_m1 % m1); + __int128_t x01 = r0 + (__int128_t)m0 * k1; + + ll x01_mod_m2 = (ll)(x01 % m2); + ll t2 = (r2 - x01_mod_m2) % m2; + if (t2 < 0) t2 += m2; + ll k2 = (ll)((__int128_t)t2 * inv_m01_mod_m2 % m2); + __int128_t x = x01 + (__int128_t)m01 * k2; + + if (x >= m012 / 2) x -= m012; + + c[i] = (ll)x; + } + return c; +} diff --git a/src/6-math/fft.cpp b/src/6-math/4_fft.hpp similarity index 62% rename from src/6-math/fft.cpp rename to src/6-math/4_fft.hpp index 7c75d35..4e2800d 100644 --- a/src/6-math/fft.cpp +++ b/src/6-math/4_fft.hpp @@ -77,48 +77,3 @@ vector multiply_mod(const vector &a, const vector &b, const ull mod) return ret; } } // namespace fft -namespace ntt { -constexpr ll MOD = (119 << 23) + 1, root = 3; // = 998244353 -// For p < 2^30 there is also e.g. 5 << 25, 7 << 26, 479 << 21 -// and 483 << 21 (same root). The last two are > 10^9. -ll modpow(ll b, ll e) { - ll ans = 1; - for (; e; b = b * b % MOD, e /= 2) - if (e & 1) ans = ans * b % MOD; - return ans; -} -void ntt(vector &a) { - int n = sz(a), L = 31 - __builtin_clz(n); - static vector rt(2, 1); - for (static int k = 2, s = 2; k < n; k *= 2, s++) { - rt.resize(n); - ll z[] = {1, modpow(root, MOD >> s)}; - for (int i = k; i < 2 * k; i++) - rt[i] = rt[i / 2] * z[i & 1] % MOD; - } - vector rev(n); - for (int i = 0; i < n; i++) rev[i] = (rev[i / 2] | (i & 1) << L) / 2; - for (int i = 0; i < n; i++) - if (i < rev[i]) swap(a[i], a[rev[i]]); - for (int k = 1; k < n; k *= 2) - for (int i = 0; i < n; i += 2 * k) - for (int j = 0; j < k; j++) { - ll z = rt[j + k] * a[i + j + k] % MOD, &ai = a[i + j]; - a[i + j + k] = ai - z + (z > ai ? MOD : 0); - ai += (ai + z >= MOD ? z - MOD : z); - } -} -vector multiply(const vector &a, const vector &b) { - if (a.empty() || b.empty()) return {}; - int s = sz(a) + sz(b) - 1, B = 32 - __builtin_clz(s), - n = 1 << B; - int inv = modpow(n, MOD - 2); - vector L(a), R(b), out(n); - L.resize(n), R.resize(n); - ntt(L), ntt(R); - for (int i = 0; i < n; i++) - out[-i & (n - 1)] = (ll)L[i] * R[i] % MOD * inv % MOD; - ntt(out); - return {out.begin(), out.begin() + s}; -} -} // namespace ntt \ No newline at end of file diff --git a/src/6-math/5_sieve.hpp b/src/6-math/5_sieve.hpp new file mode 100644 index 0000000..b964411 --- /dev/null +++ b/src/6-math/5_sieve.hpp @@ -0,0 +1,47 @@ +#pragma once +#include "../common/common.hpp" +// 1. Sieve of Eratosthenes +// TIME COMPLEXITY: O(Nlog(log(N))) +namespace sieve { +constexpr int N = 1010101; +bool is_prime[N]; +vector prime(1, 2); +void get_prime() { + fill(is_prime + 2, is_prime + N, 1); + for (ll i = 4; i < N; i += 2) + is_prime[i] = 0; + for (ll i = 3; i < N; i += 2) { + if (!is_prime[i]) continue; + prime.push_back(i); + for (ll j = i * i; j < N; j += i * 2) + is_prime[j] = 0; + } +} +}; // namespace sieve +// 2. Linear Sieve +namespace linear_sieve { +constexpr int N = 5050505; +vector sp(N); +vector prime; +void linear_sieve() { // Determine prime numbers between 1 and MAXN in O(MAXN) + for (int i = 2; i < N; i++) { + if (!sp[i]) { + prime.push_back(i); + sp[i] = i; + } + for (auto j : prime) { + if (i * j >= N) break; + sp[i * j] = j; + if (i % j == 0) break; + } + } +} +vector factorize(int x) { // factorization in O(log x) + vector ret; + while (x > 1) { + ret.push_back(sp[x]); + x /= sp[x]; + } + return ret; +} +}; // namespace linear_sieve diff --git a/src/6-math/6_binom.hpp b/src/6-math/6_binom.hpp new file mode 100644 index 0000000..c1452c1 --- /dev/null +++ b/src/6-math/6_binom.hpp @@ -0,0 +1,21 @@ +#pragma once +#include "../common/common.hpp" +namespace binom { +// nCr mod p in O(1) +constexpr int MOD = 1e9 + 7; +constexpr int N = 4040404; +ll fac[N], inv[N], facInv[N]; +void build() { // You must run build() before you call binom(int n, int r). + fac[0] = fac[1] = inv[1] = 1; + facInv[0] = facInv[1] = 1; + for (int i = 2; i < N; i++) { + fac[i] = i * fac[i - 1] % MOD; + inv[i] = -(MOD / i) * inv[MOD % i] % MOD; + if (inv[i] < 0) inv[i] += MOD; + facInv[i] = facInv[i - 1] * inv[i] % MOD; + } +} +ll binom(int n, int r) { + return fac[n] * facInv[r] % MOD * facInv[n - r] % MOD; +} +}; // namespace binom diff --git a/src/6-math/binomial_coefficient.cpp b/src/6-math/binomial_coefficient.cpp deleted file mode 100644 index ba7eee0..0000000 --- a/src/6-math/binomial_coefficient.cpp +++ /dev/null @@ -1,58 +0,0 @@ -#include "../common/common.hpp" - -// nCr in O(r) -// Beware of integer overflow -ll binom(int n, int r) { - if (r < 0 || n < r) return 0; - r = min(r, n - r); - ll ret = 1; - for (ll i = 1; i <= r; i++) { - ret *= n + 1 - i; - ret /= i; - } - return ret; -} -// nCr (Pascal’s Rule) -ll binomDP[1010][1010]; -void init() { - for (int i = 0; i < 1010; i++) { - for (int j = 0; j < 1010; j++) { - binomDP[i][j] = -1; - } - } -} -ll binom(int n, int r) { - if (r < 0 || n < r) return 0; - ll &ret = binomDP[n][r]; - if (ret != -1) return ret; - if (n == 1) return ret = 1; - return binom(n - 1, r - 1) + binom(n - 1, r); -} -// nCr mod p in O(1) -const int MOD = 1e9 + 7; -const int MAXN = 4040404; -ll fac[MAXN], inv[MAXN], facInv[MAXN]; -ll binom(int n, int r) { - return fac[n] * facInv[r] % MOD * facInv[n - r] % MOD; -} -int main() { - cin.tie(NULL), cout.tie(NULL); - ios_base::sync_with_stdio(false); - // Preprocessing in O(N) - fac[0] = fac[1] = inv[1] = 1; - facInv[0] = facInv[1] = 1; - for (int i = 2; i < MAXN; i++) { - fac[i] = i * fac[i - 1] % MOD; - inv[i] = -(MOD / i) * inv[MOD % i] % MOD; - if (inv[i] < 0) inv[i] += MOD; - facInv[i] = facInv[i - 1] * inv[i] % MOD; - } - // Answer each query in O(1) - int q; - cin >> q; - while (q--) { - int n, r; - cin >> n >> r; - cout << binom(n, r) << '\n'; - } -} \ No newline at end of file diff --git a/src/6-math/sieve.cpp b/src/6-math/sieve.cpp deleted file mode 100644 index c96bf79..0000000 --- a/src/6-math/sieve.cpp +++ /dev/null @@ -1,55 +0,0 @@ -#include "../common/common.hpp" - -// Sieve of Eratosthenes -// TIME COMPLEXITY: O(Nlog(log(N))) -const int MAX = 1e6; -bool isPrime[MAX + 1]; -vector prime(1, 2); -void getPrime() { - fill(isPrime + 2, isPrime + MAX + 1, 1); - for (ll i = 4; i <= MAX; i += 2) - isPrime[i] = 0; - for (ll i = 3; i <= MAX; i += 2) { - if (!isPrime[i]) continue; - prime.push_back(i); - for (ll j = i * i; j <= MAX; j += i * 2) - isPrime[j] = 0; - } -} - -// Linear Sieve -const int MAXN = 5000000; -vector sp(MAXN + 1); -vector prime; -// Determine prime numbers between 1 and MAXN in O(MAXN) -void linearSieve() { - for (int i = 2; i <= MAXN; i++) { - if (!sp[i]) { - prime.push_back(i); - sp[i] = i; - } - for (auto j : prime) { - if (i * j > MAXN) break; - sp[i * j] = j; - if (i % j == 0) break; - } - } -} -// factorization in O(log x) -void factorization(int x) { - while (x > 1) { - cout << sp[x] << ' '; - x /= sp[x]; - } - cout << '\n'; -} -int main() { - linearSieve(); - int n; - cin >> n; - while (n--) { - int x; - cin >> x; - factorization(x); - } -} \ No newline at end of file