多项式的卷积

卷积

概念

令 $A,B,C$ 分别为三个多项式

形如 $c[k]=\sum_{i\oplus j=k}a[i]b[j]$ 的式子称为卷积,其中 $\oplus$ 为某种运算

注意到多项式的乘法就是加法卷积

化成卷积形式的小技巧

  1. 减法卷积,下面默认 $A,B,C$ 的长度都是 $n$

    $C_k=\sum_{i=k}^nA_{i-k}B_i=\sum_{i=0}^{n-k}A_{n-k-i}B'_{i}$,其中 $B'_k=B_{n-k}$

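    证明:在第一个和式中作代换 $i\to n-i$,得 $\sum_{i=k}^{n}A_{i-k}B_i=\sum_{i=0}^{n-k}A_{n-i-k}B_{n-i}$,再把 $B_{n-i}$ 记作翻转后的序列的第 $i$ 项即得右式,这正是一个加法卷积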

FFT

简介

$FFT$ 最早用于解决多项式乘法,也就是加法卷积

可以将 $O(n^2)$ 的多项式乘法加速到 $O(n\log n)$

与多项式有关的几个概念和约定

多项式的系数表示法:$A(x)=\sum_{i=0}^{n-1}a_ix^i$,共 $n$ 项系数,次数为 $n-1$(下文为了方便,有时直接把这样的多项式称为"$n$ 次多项式")

多项式的点值表示法:$(x_1,y_1),(x_2,y_2),\cdots ,(x_n, y_n)$,一个有 $n$ 项系数(次数不超过 $n-1$)的多项式可以被 $n$ 个横坐标互不相同的点所唯一确定

单位根

在复平面中,$x$ 表示实数,$y$ 表示虚数,从原点 $(0,0)$ 到 $(a,b)$ 的向量表示复数 $a+bi$

模长:复数 $a+bi$ 的模长为 $\sqrt {a^2+b^2}$,同时这也是其向量的长度

幅角:假设以逆时针为正方向,从 $x$ 轴正半轴到已知向量的转角的有向角叫做幅角

在代数中 $z^n=1$,则称 $z$ 为 $n$ 次单位根,我们将其扩展到复数域,用 $\omega_{n}$ 表示 $n$ 次单位根

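至于 $\omega_n$ 的值,我们考虑用欧拉公式计算:$\omega_n=e^{\frac{2\pi i}{n}}=\cos\frac{2\pi}{n}+i\sin\frac{2\pi}{n}$,于是 $\omega_n^k=\cos\frac{2\pi k}{n}+i\sin\frac{2\pi k}{n}$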

然后我们能够发现关于单位根的三个显然的性质

  1. $w_{dn}^{dk}=w_n^k$
  2. $w_n^{k+\frac{n}{2}}=-w_n^k$,此处 $2|n$
  3. $\sum_{k=0}^{n-1}w_n^{k}=[n=1]$

FFT的原理

能够发现用系数表示的两个 $n$ 次多项式相乘的时间复杂度为 $O(n^2)$

而用点值表示的两个 $n$ 次多项式相乘只需要把对应位置的点值相乘,时间复杂度为 $O(n)$

那么我们考虑是否存在低于 $O(n^2)$ 的算法可以将一个多项式在系数表达和点值表达之间转换呢?

我们考虑将 $n$ 次单位根作为点值表达,然后考虑化简

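不妨假设 $n$ 为 $2$ 的幂,将 $A(x)$ 按下标奇偶拆成 $A_0(x)=a_0+a_2x+a_4x^2+\cdots$ 和 $A_1(x)=a_1+a_3x+a_5x^2+\cdots$,则 $A(x)=A_0(x^2)+xA_1(x^2)$,于是 $A(\omega_n^k)=A_0(\omega_{n/2}^k)+\omega_n^kA_1(\omega_{n/2}^k)$,$A(\omega_n^{k+n/2})=A_0(\omega_{n/2}^k)-\omega_n^kA_1(\omega_{n/2}^k)$,其中 $k$ 是表示我们要计算将 $\omega_{n}^k$ 带入的值,不妨令 $k<\frac{n}{2}$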

能够发现我们将 $n$ 的规模缩小了一半,且形式完全一样,所以我们可以递归计算,时间复杂度 $O(n\log n)$

现在还剩一个问题,我们还需要将点值表达转换为系数表达

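考虑这样一个矩阵方程:$\begin{pmatrix}1&1&1&\cdots&1\\1&w_n^{1}&w_n^{2}&\cdots&w_n^{n-1}\\1&w_n^{2}&w_n^{4}&\cdots&w_n^{2(n-1)}\\\vdots&\vdots&\vdots&\ddots&\vdots\\1&w_n^{n-1}&w_n^{2(n-1)}&\cdots&w_n^{(n-1)(n-1)}\end{pmatrix}\begin{pmatrix}a_0\\a_1\\a_2\\\vdots\\a_{n-1}\end{pmatrix}=\begin{pmatrix}y_0\\y_1\\y_2\\\vdots\\y_{n-1}\end{pmatrix}$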

我们现在知道 $y_0$ 到 $y_{n-1}$

那么我们只需要乘前面那个矩阵的逆矩阵就能得到 $a_0$ 到 $a_{n-1}$ 了

首先前面那个矩阵的第 $i$ 行第 $j$ 列的元素是 $w_n^{ij}$,假定从 $0$ 到 $n-1$ 编号

通过一些不为人知的方法 我们得到逆矩阵是 $\frac{w_n^{-ij}}{n}$

我们来验证一下,$a_{i,i}=\sum_{k=0}^{n-1}w_{n}^{ik}\frac{w_n^{-ik}}{n}=1$

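当 $i\ne j$ 时,$a_{i,j}=\sum_{k=0}^{n-1}w_n^{ik}\frac{w_n^{-jk}}{n}=\frac{1}{n}\sum_{k=0}^{n-1}\left(w_n^{i-j}\right)^{k}$

由于 $0<|i-j|<n$,有 $w_n^{i-j}\ne 1$,由等比数列求和得 $\sum_{k=0}^{n-1}\left(w_n^{i-j}\right)^k=\frac{w_n^{(i-j)n}-1}{w_n^{i-j}-1}=0$,那么 $a_{i,j}=0$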

所以现在我们的系数变成 $w_n^{-k}$ 了,再做一次 FFT 就好了

FFT的实现

递归

1
2
3
4
5
6
7
8
9
10
11
void FFT(int n, Complex *a, int type) {
if (n == 1) return ;
Complex a1[n / 2], a2[n / 2];
for (int i = 0; i < n / 2; ++i) a1[i] = a[2 * i], a2[i] = a[2 * i + 1];
FFT(n / 2, a1, type); FFT(n / 2, a2, type);
Complex wn = { cos(2 * pi / n), type * sin(2 * pi / n) }, w = { 1, 0 };
for (int i = 0; i < n / 2; ++i, w = w * wn) {
a[i] = a1[i] + w * a2[i];
a[i + n / 2] = a1[i] - w * a2[i];
}
}

注意到递归版本的常数太大,我们考虑使用非递归的版本

大概就是直接将所有系数换到递归底层的位置,然后合并即可

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
// Minimal complex number: x is the real part, y the imaginary part.
// It has no constructors, so it can be brace-initialised as { re, im }.
struct Complex {
    double x, y;

    // Component-wise sum.
    friend Complex operator + (const Complex &lhs, const Complex &rhs) {
        Complex sum = { lhs.x + rhs.x, lhs.y + rhs.y };
        return sum;
    }
    // Component-wise difference.
    friend Complex operator - (const Complex &lhs, const Complex &rhs) {
        Complex diff = { lhs.x - rhs.x, lhs.y - rhs.y };
        return diff;
    }
    // Complex product (re1*re2 - im1*im2, re1*im2 + im1*re2).
    friend Complex operator * (const Complex &lhs, const Complex &rhs) {
        Complex prod = { lhs.x * rhs.x - lhs.y * rhs.y, lhs.x * rhs.y + lhs.y * rhs.x };
        return prod;
    }
} A[maxn], B[maxn], b[maxn]; // b is the scratch copy used by Mul; A and B are presumably the caller's inputs

// P[i] holds i with its lowest l bits reversed, where l = ceil(log2(n)).
int P[maxn];
// Fills P[0..n). Relies on P[0] being 0 (true for a zero-initialised global).
void init_P(int n) {
    int l = 0; while ((1 << l) < n) ++l;
    // Value contributed by the lowest bit of i once reversed. For n <= 1 there
    // are no bits at all; the previous expression shifted by l - 1 == -1 in
    // that case, which is undefined behaviour.
    const int top = l > 0 ? 1 << (l - 1) : 0;
    for (int i = 0; i < n; ++i) P[i] = (P[i >> 1] >> 1) | ((i & 1) ? top : 0);
}

// Iterative in-place radix-2 FFT of length n (a power of two).
// init_P(n) must have been called so that P holds the bit-reversal table.
//   type : +1 forward transform, -1 inverse transform (result scaled by 1/n).
void FFT(Complex *a, int n, int type) {
    // Move every coefficient to the slot it would reach at the bottom of the recursion.
    for (int i = 0; i < n; ++i) if (i < P[i]) swap(a[i], a[P[i]]);
    // i is the current block length, m half of it.
    for (int i = 2, m = 1; i <= n; m = i, i *= 2) {
        Complex wn = { cos(2 * pi / i), type * sin(2 * pi / i) };
        for (int j = 0; j < n; j += i) {
            Complex w = { 1, 0 };
            for (int k = 0; k < m; ++k, w = w * wn) {
                Complex t1 = a[j + k], t2 = a[j + k + m] * w;
                a[j + k] = t1 + t2; a[j + k + m] = t1 - t2;
            }
        }
    }
    // Scale both components: previously only the real part was divided by n,
    // which left the imaginary part of the inverse transform n times too large.
    if (type < 0) for (int i = 0; i < n; ++i) a[i].x /= n, a[i].y /= n;
}

// Multiplies the polynomial stored in a[0..n1) by B[0..n2).
// The product's coefficients are left in the real parts a[i].x; a must have
// room for the padded power-of-two length. B is copied into the global b first.
void Mul(Complex *a, Complex *B, int n1, int n2) {
    int n = 1;
    while (n < n1 + n2 - 1) n <<= 1;
    init_P(n);

    const Complex zero = { 0, 0 };
    for (int i = 0; i < n2; ++i) b[i] = B[i];
    fill(a + n1, a + n, zero);
    fill(b + n2, b + n, zero);

    FFT(a, n, 1);
    FFT(b, n, 1);
    for (int i = 0; i < n; ++i) {
        a[i] = a[i] * b[i];
    }
    FFT(a, n, -1);
}

精度比较好的写法,开long double可以跑1e15

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
typedef std::vector<int> Poly;
namespace Pol {
//const int N = 4200000; // 200MB
const int N = 2100000; // 100MB
const double pi = std::acos(-1.0);

// Complex number: x is the real part, y the imaginary part.
struct Complex {
    double x, y;

    Complex(double x = 0, double y = 0) : x(x), y(y) {}

    friend Complex operator + (const Complex &u, const Complex &v) { return Complex(u.x + v.x, u.y + v.y); }
    friend Complex operator - (const Complex &u, const Complex &v) { return Complex(u.x - v.x, u.y - v.y); }
    friend Complex operator * (const Complex &u, const Complex &v) {
        return Complex(u.x * v.x - u.y * v.y, u.x * v.y + u.y * v.x);
    }
} a[N], b[N]; // scratch buffers shared by Mul and MMul

// Bit-reversal table. DIT/DIF below do not read it; it is still filled by
// Mul/MMul as before.
int P[N];
void init_P(int n) {
    int l = 0; while ((1 << l) < n) ++l;
    // For n <= 1 the old code shifted by l - 1 == -1 (undefined behaviour).
    const int top = l > 0 ? 1 << (l - 1) : 0;
    for (int i = 0; i < n; ++i) P[i] = (P[i >> 1] >> 1) | ((i & 1) ? top : 0);
}

// Builds the twiddle table: for every power of two k < n and 0 <= j < k,
// w[k + j] = cos(pi * j / k) + i * sin(pi * j / k). Each entry is derived from
// the previous level with one multiplication.
std::vector<Complex> init_W(int n) {
    std::vector<Complex> w(n);
    if (n > 1) w[1] = 1;
    for (int i = 2; i < n; i <<= 1) {
        auto w0 = w.begin() + i / 2, w1 = w.begin() + i;
        Complex wn(std::cos(pi / i), std::sin(pi / i));
        for (int j = 0; j < i; j += 2)
            w1[j] = w0[j >> 1], w1[j + 1] = w1[j] * wn;
    }
    return w;
}
auto w = init_W(1 << 21); // supports transform lengths up to 2^21

// Forward transform of a[0..n), n a power of two. No bit-reversal pass is
// done; the result is left in the order that DIF expects as input.
void DIT(Complex *a, int n) {
    for (int k = n >> 1; k; k >>= 1)
        for (int i = 0; i < n; i += k << 1)
            for (int j = 0; j < k; ++j) {
                Complex x = a[i + j], y = a[i + j + k];
                a[i + j + k] = (x - y) * w[k + j], a[i + j] = x + y;
            }
}

// Inverse of DIT. Every element is divided by extra * n.
// The default extra = 4 matches the packed Mul below, whose real part carries
// 4 * A * B; callers that multiplied two separate spectra must pass extra = 1.
void DIF(Complex *a, int n, double extra = 4) {
    for (int k = 1; k < n; k <<= 1)
        for (int i = 0; i < n; i += k << 1)
            for (int j = 0; j < k; ++j) {
                Complex x = a[i + j], y = a[i + j + k] * w[k + j];
                a[i + j + k] = x - y, a[i + j] = x + y;
            }
    const double denom = extra * n;
    for (int i = 0; i < n; ++i) a[i].x /= denom, a[i].y /= denom;
    // The butterflies above use the same twiddles as DIT; reversing
    // a[1..n) turns that into the inverse transform.
    std::reverse(a + 1, a + n);
}

// Rounds to the nearest integer. floor(v + 0.5) is also right for negative
// values, unlike the truncating (int)(v + 0.5).
inline int round_to_int(double v) { return (int) std::floor(v + 0.5); }

// Product of A[0..n1) and B[0..n2) using two transforms instead of three:
// with c = (A + B) + i(A - B), Re(c^2) = 4 * A * B.
// Returns n1 + n2 - 1 coefficients (empty if either input is empty).
Poly Mul(const Poly &A, const Poly &B, int n1, int n2) {
    if (n1 <= 0 || n2 <= 0) return Poly();
    int n = 1; while (n < n1 + n2 - 1) n <<= 1; init_P(n);
    std::fill(a, a + n, Complex(0, 0));
    for (int i = 0; i < n1; ++i) a[i].x += A[i], a[i].y += A[i];
    for (int i = 0; i < n2; ++i) a[i].x += B[i], a[i].y -= B[i];
    DIT(a, n);
    for (int i = 0; i < n; ++i) a[i] = a[i] * a[i];
    DIF(a, n);
    Poly ans(n1 + n2 - 1);
    for (int i = 0; i < n1 + n2 - 1; ++i) ans[i] = round_to_int(a[i].x);
    return ans;
}
// A plain three-transform product is also possible: load A into a and B into
// b, DIT both, multiply point-wise, then DIF(a, n, 1).

// "Subtraction convolution": returns ans[t] = sum over i of A[i] * B[i + t]
// for 0 <= t < len, obtained by convolving A with the reversed B.
Poly MMul(const Poly &A, const Poly &B, int len) {
    if (len <= 0) return Poly();
    int n = 1; while (n < 2 * len - 1) n <<= 1; init_P(n);
    for (int i = 0; i < len; ++i) a[i] = Complex(A[i], 0);
    for (int i = 0; i < len; ++i) b[i] = Complex(B[i], 0);
    std::fill(a + len, a + n, Complex(0, 0)); std::fill(b + len, b + n, Complex(0, 0));
    std::reverse(b, b + len);
    DIT(a, n); DIT(b, n);
    for (int i = 0; i < n; ++i) a[i] = a[i] * b[i];
    // extra = 1: this is an ordinary product of two spectra, so the default
    // divisor 4n would return a quarter of the correct values.
    DIF(a, n, 1);
    Poly ans(len);
    for (int i = 0; i < len; ++i) ans[i] = round_to_int(a[i].x);
    std::reverse(ans.begin(), ans.end());
    return ans;
}
} // namespace Pol

NTT

$FFT$ 使用的是复数,不仅运算慢,而且还有可能产生精度损失

所以我们考虑取模意义下的一种新的单位根

经过前人的努力,我们得知这种数是原根

实际上 $w_n=g^{\frac{p-1}{n}}$(即 $w_n^k=g^{\frac{k(p-1)}{n}}$),这也就表示了 $p-1$ 必须能被 $n$ 整除,也就是 $p-1$ 必须能被 $2$ 的较大次幂整除

常用的模数是 $998244353=119*2^{23}+1$ 原根是 $3$

还有 $1004535809$ 原根是 $3$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>
#define ll long long
using namespace std;

const int p = 998244353;
// Returns x^n mod p by binary exponentiation; n must be non-negative.
// x is first reduced into [0, p): without that, a base of about 3.04e9 or more
// overflows 64 bits when squared, and a negative base gives a negative result.
long long pow_mod(long long x, long long n) {
    x %= p;
    if (x < 0) x += p;
    long long s = 1;
    for (; n; n >>= 1, x = x * x % p)
        if (n & 1) s = s * x % p;
    return s;
}

typedef std::vector<int> Poly;
namespace Pol {
// Modular sum; for operands in [0, p] the result is in [0, p).
inline int add(int a, int b) { return (a += b) >= p ? a -= p : a; }
// Modular product via a 64-bit intermediate.
inline int mul(int a, int b) { return 1ll * a * b % p; }

const int N = 4200000;
int a[N], b[N]; // scratch buffers shared by Mul and MMul

const int G = 3;                  // primitive root used for the transform
const int Gi = pow_mod(G, p - 2); // its modular inverse

// P[i] holds i with its lowest l bits reversed, l = ceil(log2(n)).
int P[N];
void init_P(int n) {
    int l = 0; while ((1 << l) < n) ++l;
    // For n <= 1 the old code shifted by l - 1 == -1 (undefined behaviour).
    const int top = l > 0 ? 1 << (l - 1) : 0;
    for (int i = 0; i < n; ++i) P[i] = (P[i >> 1] >> 1) | ((i & 1) ? top : 0);
}

// In-place number-theoretic transform of a[0..n), n a power of two.
// init_P(n) must have been called. type > 0: forward; type < 0: inverse,
// including the multiplication by n^-1.
void NTT(int *a, int n, int type) {
    static int w[N]; // powers of the current level's root
    for (int i = 0; i < n; ++i) if (i < P[i]) std::swap(a[i], a[P[i]]);
    for (int i = 2, m = 1; i <= n; m = i, i *= 2) {
        int wn = pow_mod(type > 0 ? G : Gi, (p - 1) / i);
        w[0] = 1; for (int j = 1; j < m; ++j) w[j] = mul(wn, w[j - 1]);
        for (int j = 0; j < n; j += i)
            for (int k = 0; k < m; ++k) {
                int t1 = a[j + k], t2 = mul(a[j + k + m], w[k]);
                a[j + k] = add(t1, t2);
                a[j + k + m] = add(t1, p - t2);
            }
    }
    if (type < 0) {
        long long inv = pow_mod(n, p - 2);
        for (int i = 0; i < n; ++i) a[i] = inv * a[i] % p;
    }
}

// Product of A[0..n1) and B[0..n2) modulo p; returns n1 + n2 - 1 coefficients
// (empty if either input is empty). add(v, p) maps inputs from (-p, p) into [0, p).
Poly Mul(const Poly &A, const Poly &B, int n1, int n2) {
    if (n1 <= 0 || n2 <= 0) return Poly();
    int n = 1; while (n < n1 + n2 - 1) n <<= 1; init_P(n);
    for (int i = 0; i < n1; ++i) a[i] = add(A[i], p);
    for (int i = 0; i < n2; ++i) b[i] = add(B[i], p);
    std::fill(a + n1, a + n, 0), std::fill(b + n2, b + n, 0);
    NTT(a, n, 1); NTT(b, n, 1);
    for (int i = 0; i < n; ++i) a[i] = mul(a[i], b[i]);
    NTT(a, n, -1); Poly ans(n1 + n2 - 1); std::copy_n(a, n1 + n2 - 1, ans.begin());
    return ans;
}

// "Subtraction convolution": ans[t] = sum over i of A[i] * B[i + t] (mod p)
// for 0 <= t < len, obtained by convolving A with the reversed B.
Poly MMul(const Poly &A, const Poly &B, int len) {
    if (len <= 0) return Poly();
    int n = 1; while (n < 2 * len - 1) n <<= 1; init_P(n);
    for (int i = 0; i < len; ++i) a[i] = add(A[i], p);
    for (int i = 0; i < len; ++i) b[i] = add(B[i], p);
    std::fill(a + len, a + n, 0), std::fill(b + len, b + n, 0); std::reverse(b, b + len);
    NTT(a, n, 1); NTT(b, n, 1);
    for (int i = 0; i < n; ++i) a[i] = mul(a[i], b[i]);
    NTT(a, n, -1); Poly ans(len);
    std::copy_n(a, len, ans.begin()); std::reverse(ans.begin(), ans.end());
    return ans;
}
} // namespace Pol

比较快的 $NTT$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
typedef std::vector<int> Poly;
namespace Pol {
// Modular sum; for operands in [0, p] the result is in [0, p).
inline int add(int a, int b) { return (a += b) >= p ? a -= p : a; }
// Modular product via a 64-bit intermediate.
inline int mul(int a, int b) { return 1ll * a * b % p; }

const int N = 4200000;
int a[N], b[N]; // scratch buffers shared by Mul and MMul

const int G = 3;
const int Gi = pow_mod(G, p - 2);

// Bit-reversal table. DIT/DIF below do not read it; it is still filled by
// Mul/MMul as before.
int P[N];
void init_P(int n) {
    int l = 0; while ((1 << l) < n) ++l;
    // For n <= 1 the old code shifted by l - 1 == -1 (undefined behaviour).
    const int top = l > 0 ? 1 << (l - 1) : 0;
    for (int i = 0; i < n; ++i) P[i] = (P[i >> 1] >> 1) | ((i & 1) ? top : 0);
}

// Builds the twiddle table: for every power of two k < n and 0 <= j < k,
// w[k + j] = g_k^j where g_k = G^((p - 1) / (2k)).
std::vector<int> init_W(int n) {
    std::vector<int> w(n);
    if (n > 1) w[1] = 1;
    for (int i = 2; i < n; i <<= 1) {
        auto w0 = w.begin() + i / 2, w1 = w.begin() + i;
        int wn = pow_mod(G, (p - 1) / (i << 1));
        for (int j = 0; j < i; j += 2)
            w1[j] = w0[j >> 1], w1[j + 1] = mul(w1[j], wn);
    }
    return w;
}
// 2^22 entries: the buffers above admit transforms of length 2^22, and a
// transform of length n reads w[0..n). The previous 2^21 table was read out of
// bounds for n = 2^22.
auto w = init_W(1 << 22);

// Forward transform of a[0..n), n a power of two. No bit-reversal pass is
// done; the result is left in the order that DIF expects as input.
void DIT(int *a, int n) {
    for (int k = n >> 1; k; k >>= 1)
        for (int i = 0; i < n; i += k << 1)
            for (int j = 0; j < k; ++j) {
                int x = a[i + j], y = a[i + j + k];
                a[i + j + k] = mul(add(x, p - y), w[k + j]), a[i + j] = add(x, y);
            }
}

// Inverse of DIT, including the multiplication by n^-1.
void DIF(int *a, int n) {
    for (int k = 1; k < n; k <<= 1)
        for (int i = 0; i < n; i += k << 1)
            for (int j = 0; j < k; ++j) {
                int x = a[i + j], y = mul(a[i + j + k], w[k + j]);
                a[i + j + k] = add(x, p - y), a[i + j] = add(x, y);
            }
    int inv = pow_mod(n, p - 2);
    for (int i = 0; i < n; ++i) a[i] = mul(a[i], inv);
    // The butterflies use the same roots as DIT; reversing a[1..n) turns the
    // result into the inverse transform.
    std::reverse(a + 1, a + n);
}

// Product of A[0..n1) and B[0..n2) modulo p; returns n1 + n2 - 1 coefficients
// (empty if either input is empty). add(v, p) maps inputs from (-p, p) into [0, p).
Poly Mul(const Poly &A, const Poly &B, int n1, int n2) {
    if (n1 <= 0 || n2 <= 0) return Poly();
    int n = 1; while (n < n1 + n2 - 1) n <<= 1; init_P(n);
    for (int i = 0; i < n1; ++i) a[i] = add(A[i], p);
    for (int i = 0; i < n2; ++i) b[i] = add(B[i], p);
    std::fill(a + n1, a + n, 0), std::fill(b + n2, b + n, 0);
    DIT(a, n); DIT(b, n);
    for (int i = 0; i < n; ++i) a[i] = mul(a[i], b[i]);
    DIF(a, n); Poly ans(n1 + n2 - 1); std::copy_n(a, n1 + n2 - 1, ans.begin());
    return ans;
}

// "Subtraction convolution": ans[t] = sum over i of A[i] * B[i + t] (mod p)
// for 0 <= t < len, obtained by convolving A with the reversed B.
Poly MMul(const Poly &A, const Poly &B, int len) {
    if (len <= 0) return Poly();
    int n = 1; while (n < 2 * len - 1) n <<= 1; init_P(n);
    for (int i = 0; i < len; ++i) a[i] = add(A[i], p);
    for (int i = 0; i < len; ++i) b[i] = add(B[i], p);
    std::fill(a + len, a + n, 0), std::fill(b + len, b + n, 0); std::reverse(b, b + len);
    DIT(a, n); DIT(b, n);
    for (int i = 0; i < n; ++i) a[i] = mul(a[i], b[i]);
    DIF(a, n); Poly ans(len);
    std::copy_n(a, len, ans.begin()); std::reverse(ans.begin(), ans.end());
    return ans;
}
} // namespace Pol

不知道为什么有奇怪错误的 $NTT$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
typedef std::vector<int> Poly;
typedef unsigned long long ull;
namespace Pol {
// Modular sum / difference for operands in [0, p).
int add(int x, int y) { return (x += y) >= p ? x -= p : x; }
int sub(int x, int y) { return (x -= y) < 0 ? x += p : x; }

const int N = 4200000;
// tmp: 64-bit working copy used by DIT/DIF; gw: twiddle table; a, b: scratch
// buffers of Mul.
ull tmp[N]; int gw[N], a[N], b[N];

// Fills gw: for every power of two l, gw[l | j] = g_l^j with
// g_l = 3^((p - 1) / (2l)), 0 <= j < l. The table covers every l with
// 2l <= N, i.e. all transform lengths the buffers can hold (the previous
// bound stopped at l = 2^20 and left length-2^22 transforms without twiddles).
void init() {
    for (int l = 1; (l << 1) <= N; l <<= 1) {
        gw[l] = 1;
        int gn = pow_mod(3, (p - 1) / (l << 1));
        for (int j = 1; j < l; ++j) {
            gw[l | j] = 1ll * gw[l | (j - 1)] * gn % p;
        }
    }
}

// Brings every lazily accumulated value in t[0..n) back into [0, p).
inline void reduce_all(ull *t, int n) {
    for (int i = 0; i < n; ++i) t[i] %= p;
}

// Inverse transform (counterpart of DIF below), including the 1/n scaling.
// flag is unused and kept only for interface compatibility.
//
// Values in tmp are accumulated without reduction and can grow by up to p per
// level. A product of such a value with a twiddle (< p) fits in 64 bits only
// while the value stays below about 18.5 * p, so everything is reduced once
// every 16 levels. Without this the transform silently overflowed for long
// inputs.
void DIT(int *a, int n, bool flag) {
    (void) flag;
    for (int i = 0; i < n; ++i) tmp[i] = a[i];
    int level = 0;
    for (int l = 1; l < n; l <<= 1, ++level) {
        if (level > 0 && level % 16 == 0) reduce_all(tmp, n);
        ull *k = tmp;
        for (int i = 0; i < n; i += (l << 1), k += (l << 1)) {
            ull *x = k;
            for (int j = 0, *g = gw + l; j < l; ++j, ++x, ++g) {
                ull o = x[l] * *g % p;
                x[l] = *x + p - o, *x += o;
            }
        }
    }
    int inv = pow_mod(n, p - 2);
    for (int i = 0; i < n; ++i) a[i] = tmp[i] % p * inv % p;
    std::reverse(a + 1, a + n);
}

// Forward transform of a[0..n), n a power of two; entries must be in [0, p).
// flag is unused. The periodic reduction is explained at DIT.
void DIF(int *a, int n, bool flag) {
    (void) flag;
    for (int i = 0; i < n; ++i) tmp[i] = a[i];
    int level = 0;
    for (int l = n / 2; l >= 1; l >>= 1, ++level) {
        if (level > 0 && level % 16 == 0) reduce_all(tmp, n);
        ull *k = tmp;
        for (int i = 0; i < n; i += (l << 1), k += (l << 1)) {
            ull *x = k;
            for (int j = 0, *g = gw + l; j < l; ++j, ++x, ++g) {
                ull o = x[l] % p;
                x[l] = (*x + p - o) * *g % p, *x += o;
            }
        }
    }
    for (int i = 0; i < n; ++i) a[i] = tmp[i] % p;
}

// Product of A[0..n1) and B[0..n2) modulo p; returns n1 + n2 - 1 coefficients
// (empty if either input is empty).
Poly Mul(const Poly &A, const Poly &B, int n1, int n2) {
    if (n1 <= 0 || n2 <= 0) return Poly();
    if (!gw[1]) init(); // gw[1] == 1 once the table has been built
    int n = 1; while (n < n1 + n2 - 1) n<<= 1;
    // add(v, p) maps inputs from (-p, p) into [0, p); a negative value would
    // otherwise turn into a huge number when widened to ull in DIF.
    for (int i = 0; i < n1; ++i) a[i] = add(A[i], p);
    for (int i = 0; i < n2; ++i) b[i] = add(B[i], p);
    std::fill(a + n1, a + n, 0);
    std::fill(b + n2, b + n, 0);
    DIF(a, n, false), DIF(b, n, false);
    for (int i = 0; i < n; ++i) a[i] = 1ll * a[i] * b[i] % p;
    DIT(a, n, true);
    Poly ans(n1 + n2 - 1);
    std::copy_n(a, n1 + n2 - 1, ans.begin());
    return ans;
}
} // namespace Pol

不知道为什么偶尔会出错的 $NTT$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
typedef std::vector<int> Poly;
typedef unsigned long long ull;
namespace Pol {
// Modular helpers for operands in [0, p).
inline int add(int a, int b) { return (a += b) >= p ? a -= p : a; }
inline int sub(int a, int b) { return (a -= b) < 0 ? a += p : a; }
inline void inc(int &a, int b) { (a += b) >= p ? a -= p : a; }
inline void dec(int &a, int b) { (a -= b) < 0 ? a += p : a; }

const int N = 3000000;

// tmp: 64-bit working copy used by DIT/DIF; gw: twiddle table indexed by
// block number; a, b: scratch buffers of mult.
ull tmp[N]; int gw[N], a[N], b[N];

// Builds gw for transforms of length up to roughly n.
// gw at the power-of-two indices is obtained by repeated squaring of a power
// of 31; every other entry is the product of the entries of its set bits.
// NOTE(review): presumably 31 has order 2^23 modulo p - confirm.
void init(int n = N - 10) {
    int t = 1;
    while ((1 << t) < n) ++t;
    t = std::min(t - 1, 21);
    gw[0] = 1, gw[1 << t] = pow_mod(31, 1 << (21 - t));
    for (int i = t; i; --i) gw[1 << (i - 1)] = 1ll * gw[1 << i] * gw[1 << i] % p;
    for (int i = 1; i < (1 << t); ++i) gw[i] = 1ll * gw[i & (i - 1)] * gw[i & -i] % p;
}

// Brings every lazily accumulated value in t[0..n) back into [0, p).
inline void reduce_all(ull *t, int n) {
    for (int i = 0; i < n; ++i) t[i] %= p;
}

// Inverse transform (counterpart of DIF below), including the 1/n scaling.
//
// Values in tmp are accumulated without reduction and can grow by up to p per
// level. A product of such a value with a twiddle (< p) fits in 64 bits only
// while the value stays below about 18.5 * p, so everything is reduced once
// every 16 levels. Without this the result was occasionally wrong for long
// inputs.
void DIT(int *a, int n) {
    for (int i = 0; i < n; ++i) tmp[i] = a[i];
    int level = 0;
    for (int l = 1; l < n; l <<= 1, ++level) {
        if (level > 0 && level % 16 == 0) reduce_all(tmp, n);
        ull *k = tmp;
        for (int *g = gw; k < tmp + n; k += (l << 1), ++g) {
            for (ull *x = k; x < k + l; ++x) {
                int o = x[l] % p;
                x[l] = 1ll * (*x + p - o) * *g % p, *x += o;
            }
        }
    }
    int iv = pow_mod(n, p - 2);
    for (int i = 0; i < n; ++i) a[i] = 1ll * tmp[i] % p * iv % p;
    std::reverse(a + 1, a + n);
}

// Forward transform of a[0..n), n a power of two; entries must be in [0, p).
// The periodic reduction is explained at DIT.
void DIF(int *a, int n) {
    for (int i = 0; i < n; ++i) tmp[i] = a[i];
    int level = 0;
    for (int l = n / 2; l >= 1; l >>= 1, ++level) {
        if (level > 0 && level % 16 == 0) reduce_all(tmp, n);
        ull *k = tmp;
        for (int *g = gw; k < tmp + n; k += (l << 1), ++g) {
            for (ull *x = k; x < k + l; ++x) {
                int o = 1ll * x[l] * *g % p;
                x[l] = *x + p - o, *x += o;
            }
        }
    }
    for (int i = 0; i < n; ++i) a[i] = tmp[i] % p;
}

// Product of A[0..n1) and B[0..n2) modulo p; returns n1 + n2 - 1 coefficients
// (empty if either input is empty).
Poly mult(const Poly &A, const Poly &B, int n1, int n2) {
    if (n1 <= 0 || n2 <= 0) return Poly();
    if (!gw[0]) init(); // gw[0] == 1 once the table has been built
    int n = 1;
    while (n < n1 + n2 - 1) n <<= 1;
    // add(v, p) maps inputs from (-p, p) into [0, p); a negative value would
    // otherwise turn into a huge number when widened to ull in DIF.
    for (int i = 0; i < n1; ++i) a[i] = add(A[i], p);
    for (int i = 0; i < n2; ++i) b[i] = add(B[i], p);
    std::fill(a + n1, a + n, 0);
    std::fill(b + n2, b + n, 0);
    DIF(a, n), DIF(b, n);
    for (int i = 0; i < n; ++i) a[i] = 1ll * a[i] * b[i] % p;
    DIT(a, n);
    Poly res(n1 + n2 - 1);
    std::copy_n(a, n1 + n2 - 1, res.begin());
    return res;
}
}

三模NTT

注意精度在 $10^{26}$ 左右, 正常多项式乘法的极值是 $10^{23}$ 足以通过

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
int p; // modulus of the final answer; read by Int::get()
// Returns x^n modulo the given p by binary exponentiation; n must be
// non-negative. x is first reduced into [0, p) so that large bases cannot
// overflow when squared and negative bases give a non-negative result.
long long pow_mod(long long x, long long n, int p) {
    x %= p;
    if (x < 0) x += p;
    long long s = 1;
    for (; n; n >>= 1, x = x * x % p)
        if (n & 1) s = s * x % p;
    return s;
}

const int mod1 = 998244353, mod2 = 1004535809, mod3 = 469762049, G = 3;
const long long mod12 = 1ll * mod1 * mod2;
// inv1 = mod1^-1 (mod mod2); inv2 = (mod1 * mod2)^-1 (mod mod3).
const int inv1 = pow_mod(mod1, mod2 - 2, mod2), inv2 = pow_mod(mod12 % mod3, mod3 - 2, mod3);

// One value represented by its residues modulo mod1, mod2 and mod3.
struct Int {
    int x, y, z; // residues modulo mod1, mod2, mod3

    Int() { x = y = z = 0; }
    // Takes three residues as given; they are not reduced.
    Int(int x, int y, int z) : x(x), y(y), z(z) {}
    // Builds the residues of v. v is reduced per modulus: mod3 is smaller than
    // the other two, so storing v verbatim left z (and possibly x) out of
    // range, and negative v was not handled.
    Int(int v) : x(norm(v, mod1)), y(norm(v, mod2)), z(norm(v, mod3)) {}

    // Reduces v into [0, m).
    static inline int norm(long long v, int m) { v %= m; return (int) (v < 0 ? v + m : v); }
    // Sum modulo p for operands in [0, p].
    static inline int add(int x, int y, int p) { return (x += y) >= p ? x - p : x; }

    inline friend Int operator + (const Int &u, const Int &v) {
        return Int(add(u.x, v.x, mod1), add(u.y, v.y, mod2), add(u.z, v.z, mod3));
    }
    inline friend Int operator - (const Int &u, const Int &v) {
        return Int(add(u.x, mod1 - v.x, mod1), add(u.y, mod2 - v.y, mod2), add(u.z, mod3 - v.z, mod3));
    }
    inline friend Int operator * (const Int &u, const Int &v) {
        return Int(1ll * u.x * v.x % mod1, 1ll * u.y * v.y % mod2, 1ll * u.z * v.z % mod3);
    }

    // Combines the three residues and returns the value modulo the global p.
    inline int get() const {
        // v: the value modulo mod1 * mod2, built from x and y.
        long long v = 1ll * add(y, mod2 - x, mod2) * inv1 % mod2 * mod1 + x;
        // Add the multiple of mod1 * mod2 that makes the result match z, working modulo p.
        return (1ll * add(z, mod3 - v % mod3, mod3) * inv2 % mod3 * (mod12 % p) % p + v) % p;
    }
};

typedef std::vector<Int> Poly;
namespace Pol {
const int N = 4200000;
Int a[N], b[N], w[N]; // a, b: scratch buffers of Mul; w: root powers of the current level

// Inverses of G modulo each of the three moduli.
const int Gi1 = pow_mod(G, mod1 - 2, mod1);
const int Gi2 = pow_mod(G, mod2 - 2, mod2);
const int Gi3 = pow_mod(G, mod3 - 2, mod3);

// P[i] holds i with its lowest l bits reversed, l = ceil(log2(n)).
int P[N];
void init_P(int n) {
    int l = 0; while ((1 << l) < n) ++l;
    // For n <= 1 the old code shifted by l - 1 == -1 (undefined behaviour).
    const int top = l > 0 ? 1 << (l - 1) : 0;
    for (int i = 0; i < n; ++i) P[i] = (P[i >> 1] >> 1) | ((i & 1) ? top : 0);
}

// In-place transform of a[0..n) carried out modulo all three primes at once.
// init_P(n) must have been called. type > 0: forward; type < 0: inverse,
// including the multiplication by n^-1.
void NTT(Int *a, int n, int type) {
    for (int i = 0; i < n; ++i) if (i < P[i]) std::swap(a[i], a[P[i]]);
    for (int i = 2, m = 1; i <= n; m = i, i *= 2) {
        Int wn = Int(pow_mod(type > 0 ? G : Gi1, (mod1 - 1) / i, mod1),
                     pow_mod(type > 0 ? G : Gi2, (mod2 - 1) / i, mod2),
                     pow_mod(type > 0 ? G : Gi3, (mod3 - 1) / i, mod3));
        w[0] = Int(1); for (int j = 1; j < m; ++j) w[j] = wn * w[j - 1];
        for (int j = 0; j < n; j += i)
            for (int k = 0; k < m; ++k) {
                Int t1 = a[j + k], t2 = a[j + k + m] * w[k];
                a[j + k] = t1 + t2;
                a[j + k + m] = t1 - t2;
            }
    }
    if (type < 0) {
        Int inv = Int(pow_mod(n, mod1 - 2, mod1), pow_mod(n, mod2 - 2, mod2), pow_mod(n, mod3 - 2, mod3));
        for (int i = 0; i < n; ++i) a[i] = a[i] * inv;
    }
}

// Product of A[0..n1) and B[0..n2). Each output coefficient is reduced to the
// global modulus with Int::get() and stored back as an Int.
// Returns n1 + n2 - 1 coefficients (empty if either input is empty).
Poly Mul(const Poly &A, const Poly &B, int n1, int n2) {
    if (n1 <= 0 || n2 <= 0) return Poly();
    int n = 1; while (n < n1 + n2 - 1) n <<= 1; init_P(n);
    std::copy_n(A.begin(), n1, a), std::fill(a + n1, a + n, Int(0));
    std::copy_n(B.begin(), n2, b), std::fill(b + n2, b + n, Int(0));
    NTT(a, n, 1); NTT(b, n, 1);
    for (int i = 0; i < n; ++i) a[i] = a[i] * b[i];
    NTT(a, n, -1); Poly ans(n1 + n2 - 1);
    for (int i = 0; i < n1 + n2 - 1; ++i) ans[i] = a[i].get();
    return ans;
}
} // namespace Pol

MTT

四次 $FFT$,$1e9$ 内模数均可

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#define Poly vector<int>
#define len(A) ((int) A.size())
namespace Pol {
const int N = 4200000;
const double pi = acos(-1);

// Complex number: x is the real part, y the imaginary part.
struct Complex {
    double x, y;

    Complex(double x = 0, double y = 0) : x(x), y(y) {}

    friend Complex operator + (const Complex &u, const Complex &v) { return Complex(u.x + v.x, u.y + v.y); }
    friend Complex operator - (const Complex &u, const Complex &v) { return Complex(u.x - v.x, u.y - v.y); }
    friend Complex operator * (const Complex &u, const Complex &v) {
        return Complex(u.x * v.x - u.y * v.y, u.x * v.y + u.y * v.x);
    }
}; typedef vector<Complex> Vcp;

// Bit-reversal table. DIT/DIF below do not read it; it is still filled by Mul
// as before.
int P[N];
void init_P(int n) {
    int l = 0; while ((1 << l) < n) ++l;
    // For n <= 1 the old code shifted by l - 1 == -1 (undefined behaviour).
    const int top = l > 0 ? 1 << (l - 1) : 0;
    for (int i = 0; i < n; ++i) P[i] = (P[i >> 1] >> 1) | ((i & 1) ? top : 0);
}

// Builds the twiddle table: for every power of two k < n and 0 <= j < k,
// w[k + j] = cos(pi * j / k) + i * sin(pi * j / k).
vector<Complex> init_W(int n) {
    vector<Complex> w(n);
    if (n > 1) w[1] = 1;
    for (int i = 2; i < n; i <<= 1) {
        auto w0 = w.begin() + i / 2, w1 = w.begin() + i;
        Complex wn(cos(pi / i), sin(pi / i));
        for (int j = 0; j < i; j += 2)
            w1[j] = w0[j >> 1], w1[j + 1] = w1[j] * wn;
    }
    return w;
}
auto w = init_W(1 << 21); // supports transform lengths up to 2^21

// Forward transform of a[0..n), n a power of two. No bit-reversal pass is
// done; the result is left in the order that DIF expects as input.
void DIT(Vcp& a, int n) {
    for (int k = n >> 1; k; k >>= 1)
        for (int i = 0; i < n; i += k << 1)
            for (int j = 0; j < k; ++j) {
                Complex x = a[i + j], y = a[i + j + k];
                a[i + j + k] = (x - y) * w[k + j], a[i + j] = x + y;
            }
}

// Inverse of DIT, including the 1/n scaling. a.size() must equal n, because
// the final reversal runs up to a.end().
void DIF(Vcp &a, int n) {
    for (int k = 1; k < n; k <<= 1)
        for (int i = 0; i < n; i += k << 1)
            for (int j = 0; j < k; ++j) {
                Complex x = a[i + j], y = a[i + j + k] * w[k + j];
                a[i + j + k] = x - y, a[i + j] = x + y;
            }
    const double inv = 1. / n;
    for (int i = 0; i < n; ++i) a[i].x *= inv, a[i].y *= inv;
    reverse(a.begin() + 1, a.end());
}

// Product of A[0..n1) and B[0..n2) modulo p using four transforms.
// Every coefficient is split into its low 15 bits (real part) and the
// remaining high bits (imaginary part), so inputs must be non-negative.
// Returns n1 + n2 - 1 coefficients (empty if either input is empty).
Poly Mul(const Poly &A, const Poly &B, int n1, int n2) {
    if (n1 <= 0 || n2 <= 0) return Poly();
    // A length-1 transform never enters the pairing loop below, so the 1x1
    // case is answered directly, before any buffer is allocated.
    if (n1 == 1 && n2 == 1) return Poly { (int) (1ll * A[0] * B[0] % p) };
    int n = 1; while (n < n1 + n2 - 1) n <<= 1; init_P(n);
    Vcp a(n), b(n), c0(n), c1(n);
    for (int i = 0; i < n1; ++i) a[i] = Complex(A[i] & 0x7fff, A[i] >> 15);
    for (int i = 0; i < n2; ++i) b[i] = Complex(B[i] & 0x7fff, B[i] >> 15);
    DIT(a, n), DIT(b, n);
    // j = i ^ (k - 1) is the partner slot of i inside the block [k, 2k) (slots
    // 0 and 1 pair with themselves). The two combinations of a[i] and a[j]
    // below recover the spectra of A's low and high halves, each multiplied by b[i].
    for (int k = 1, i = 0, j; k < n; k <<= 1)
        for (; i < k * 2; ++i) {
            j = i ^ (k - 1);
            c0[i] = Complex(a[i].x + a[j].x, a[i].y - a[j].y) * b[i] * 0.5;
            c1[i] = Complex(a[i].y + a[j].y, -a[i].x + a[j].x) * b[i] * 0.5;
        }
    DIF(c0, n), DIF(c1, n); Poly ans(n1 + n2 - 1);
    for (int i = 0; i < n1 + n2 - 1; i++) {
        // Round the four partial products and recombine them with weights 1, 2^15, 2^30.
        long long c00 = c0[i].x + 0.5, c01 = c0[i].y + 0.5, c10 = c1[i].x + 0.5, c11 = c1[i].y + 0.5;
        ans[i] = (c00 + ((c01 + c10) % p << 15) + (c11 % p << 30)) % p;
    } return ans;
}
} // namespace Pol

模x^2乘法NTT

单位元 $(0,1)$,$(x,y)$ 逆元 $(-\frac{x}{y^2},\frac{1}{y})$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
// Pair (x, y) with arithmetic modulo p. The product rule below makes (0, 1)
// the multiplicative identity: y multiplies like an ordinary number and x
// follows the product rule x1*y2 + y1*x2.
struct Int {
    int x, y;

    Int() : x(0), y(0) {}
    Int(int x, int y) : x(x), y(y) {}

    // Component-wise sum.
    inline friend Int operator + (const Int &u, const Int &v) {
        const int sx = add(u.x, v.x);
        const int sy = add(u.y, v.y);
        return Int(sx, sy);
    }
    // Component-wise difference; p - v keeps the operands non-negative.
    inline friend Int operator - (const Int &u, const Int &v) {
        const int dx = add(u.x, p - v.x);
        const int dy = add(u.y, p - v.y);
        return Int(dx, dy);
    }
    // (x1, y1) * (x2, y2) = (x1*y2 + y1*x2, y1*y2).
    inline friend Int operator * (const Int &u, const Int &v) {
        const int cross = add(mul(u.x, v.y), mul(u.y, v.x));
        return Int(cross, mul(u.y, v.y));
    }
    // (x1, y1) / (x2, y2) = ((x1*y2 - y1*x2) / y2^2, y1 / y2).
    inline friend Int operator / (const Int &u, const Int &v) {
        ll inv = pow_mod(v.y, p - 2); // inverse of v.y via Fermat's little theorem
        const int numer = add(mul(u.x, v.y), p - mul(u.y, v.x));
        return Int(numer * inv % p * inv % p,
                   mul(u.y, inv));
    }
};

typedef std::vector<Int> Poly;
namespace Pol {
const int N = 4200000;
Int a[N], b[N], w[N];

const int G = 3;
const int Gi = pow_mod(G, p - 2);

int P[N];
void init_P(int n) {
int l = 0; while ((1 << l) < n) ++l;
for (int i = 0; i < n; ++i) P[i] = (P[i >> 1] >> 1) | ((i & 1) << l - 1);
}

void NTT(Int *a, int n, int type) {
for (int i = 0; i < n; ++i) if (i < P[i]) swap(a[i], a[P[i]]);
for (int i = 2, m = 1; i <= n; m = i, i *= 2) {
Int wn = Int(0, pow_mod(type > 0 ? G : Gi, (p - 1) / i));
w[0] = Int(0, 1); for (int j = 1; j < m; ++j) w[j] = wn * w[j - 1];
for (int j = 0; j < n; j += i)
for (int k = 0; k < m; ++k) {
Int t1 = a[j + k], t2 = a[j + k + m] * w[k];
a[j + k] = t1 + t2;
a[j + k + m] = t1 - t2;
}
}
if (type < 0) {
Int inv = Int(0, pow_mod(n, p - 2));
for (int i = 0; i < n; ++i) a[i] = a[i] * inv;
}
}

Poly Mul(const Poly &A, const Poly &B, int n1, int n2) {
int n = 1; while (n < n1 + n2 - 1) n <<= 1; init_P(n);
copy_n(A.begin(), n1, a), fill(a + n1, a + n, Int(0, 0));
copy_n(B.begin(), n2, b), fill(b + n2, b + n, Int(0, 0));
NTT(a, n, 1); NTT(b, n, 1);
for (int i = 0; i < n; ++i) a[i] = a[i] * b[i];
NTT(a, n, -1); Poly ans(n1 + n2 - 1);
for (int i = 0; i < n1 + n2 - 1; ++i) ans[i] = a[i];
return ans;
}
} // namespace Pol

集合运算卷积

FWT

大概就是求这个东西 $C_k=\sum_{i\oplus j=k}A_iB_j$,其中 $\oplus$ 表示某种位运算

$FMT$ 一般用于解决 $or$ 和 $and$ 的情况

首先我们考虑 $or$ 的情况,类似 $FFT$ 的做法,我们考虑构造出一种类似点值表达式的东西

在 $or$ 中我们使用的是这个东西,$A'_i=\sum_{j|i=i}A_j$,这样我们就可以直接将 $A'$ 和 $B'$ 直接对应位相乘,然后再变换回去

实际上这个变换就是求一个高维前缀和,但是我们也可以参照 $FFT$ 的写法来实现,本质上是一样的

$and$​ 的做法同理

$FWT$ 用来解决 $xor$ 的问题,大概用不到过程吧,这里就不写了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
typedef std::vector<int> Poly;
namespace Pol {
inline int add(int x, int y) { return (x += y) >= p ? x - p : x; }

const int n = 20;
int a[1 << n], b[1 << n];

void FWT_or(int *a, int n, int type) {
for (int i = 2, m = 1; i <= n; m = i, i *= 2)
for (int j = 0; j < n; j += i)
for (int k = 0; k < m; ++k) {
int t1 = a[j + k], t2 = a[j + k + m];
a[j + k] = t1; a[j + k + m] = type > 0 ? add(t1, t2) : add(p - t1, t2);
}
}

void FWT_and(int *a, int n, int type) {
for (int i = 2, m = 1; i <= n; m = i, i *= 2)
for (int j = 0; j < n; j += i)
for (int k = 0; k < m; ++k) {
int t1 = a[j + k], t2 = a[j + k + m];
a[j + k] = type > 0 ? add(t1, t2) : add(t1, p - t2); a[j + k + m] = t2;
}
}

ll inv = pow_mod(2, p - 2);
void FWT_xor(int *a, int n, int type) {
for (int i = 2, m = 1; i <= n; m = i, i *= 2)
for (int j = 0; j < n; j += i)
for (int k = 0; k < m; ++k) {
ll t1 = a[j + k], t2 = a[j + k + m];
a[j + k] = type > 0 ? add(t1, t2) : add(t1, t2) * inv % p;
a[j + k + m] = type > 0 ? add(t1, p - t2) : add(t1, p - t2) * inv % p;
}
}

Poly Mul(const Poly &A, const Poly &B, int n) {
copy_n(A.begin(), n, a); copy_n(B.begin(), n, b);
FWT(a, n, 1); FWT(b, n, 1);
for (int i = 0; i < n; ++i) a[i] = 1ll * a[i] * b[i] % p;
FWT(a, n, -1);
Poly ans(n); copy_n(a, n, ans.begin());
return ans;
}
}

子集卷积

大概就是求 $C_{k}=\sum_{i~or~j=k\,\wedge\, i~and~j=0}A_iB_j$

我们考虑一般的或卷积,注意到第二个限制相当于 $|i|+|j|=|i~or~j|$,所以我们再开一维记录大小即可

时间复杂度 $O(n\log^2 n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
typedef std::vector<int> Poly;
namespace Pol {
// Modular sum for operands in [0, p].
inline int add(int x, int y) { return (x += y) >= p ? x - p : x; }

const int n = 20;
// a, b, c are indexed [popcount][mask]; cnt[mask] caches popcount(mask).
int a[n + 1][1 << n], b[n + 1][1 << n], c[n + 1][1 << n], cnt[1 << n];

// Or-transform of a[0..n), n a power of two: type > 0 adds the lower half of
// every block into the upper half, otherwise it subtracts it again.
void FWT(int *a, int n, int type) {
    for (int half = 1; 2 * half <= n; half *= 2) {
        for (int base = 0; base < n; base += 2 * half) {
            for (int k = 0; k < half; ++k) {
                const int lo = a[base + k];
                int &hi = a[base + k + half];
                if (type > 0) hi = add(lo, hi);
                else hi = add(p - lo, hi);
            }
        }
    }
}

// Subset convolution of A and B over n = 2^m masks:
// ans[k] = sum of A[i] * B[j] over i | j == k with i & j == 0 (mod p).
// The inputs are split by popcount, or-transformed layer by layer, combined so
// that the popcounts add up, and transformed back.
Poly Mul(const Poly &A, const Poly &B, int n, int m) {
    for (int mask = 0; mask < n; ++mask) cnt[mask] = __builtin_popcount(mask);
    for (int layer = 0; layer <= m; ++layer) {
        fill(a[layer], a[layer] + n, 0);
        fill(b[layer], b[layer] + n, 0);
        fill(c[layer], c[layer] + n, 0);
    }
    for (int mask = 0; mask < n; ++mask) {
        const int layer = cnt[mask];
        a[layer][mask] = A[mask];
        b[layer][mask] = B[mask];
    }
    for (int layer = 0; layer <= m; ++layer) {
        FWT(a[layer], n, 1);
        FWT(b[layer], n, 1);
    }
    for (int total = 0; total <= m; ++total) {
        int *out = c[total];
        for (int j = 0; j <= total; ++j) {
            const int *lhs = a[j];
            const int *rhs = b[total - j];
            for (int k = 0; k < n; ++k) out[k] = add(out[k], 1ll * lhs[k] * rhs[k] % p);
        }
    }
    for (int layer = 0; layer <= m; ++layer) FWT(c[layer], n, -1);
    Poly ans(n);
    for (int mask = 0; mask < n; ++mask) ans[mask] = c[cnt[mask]][mask];
    return ans;
}
}

分治FFT以及分治FWT

分治 $FFT$,我们给定 $f_0$,以及 $f_k=\sum_{i=0}^{k-1}f_ig_{k-i}$,求 $f$

1
2
3
4
5
6
7
8
// Divide-and-conquer evaluation on the half-open range [l, r): after the left
// half [l, m) is final, its contribution to every A[i] in [m, r) is added with
// one multiplication by B[0 .. r - l). B must have at least r - l entries.
void solve(Poly &A, const Poly &B, int l, int r) {
    // Ranges with fewer than two elements have nothing to propagate. The old
    // test l + 1 == r let an empty range (l == r) recurse forever.
    if (r - l <= 1) return ;
    const int m = (l + r) >> 1;
    solve(A, B, l, m);
    Poly t1(A.begin() + l, A.begin() + m), t2(B.begin(), B.begin() + r - l);
    Poly t = Pol::Mul(t1, t2, m - l, r - l);
    // t[i - l] is the sum of A[j] * B[i - j] over j in [l, m).
    for (int i = m; i < r; ++i) A[i] = (A[i] + t[i - l]) % p;
    solve(A, B, m, r);
}

分治 $FWT$​,我们给定 $f_0$​,以及 $f_k=\sum_{i=0}^{k-1}f_ig_{k\oplus i}$,其中 $\oplus$ 表示某一种位运算

1
2
3
4
5
6
7
8
9
10
11
12
13
// Divide-and-conquer on the half-open range [l, r). The right half is solved
// first and its contribution is then pushed into the left half.
void solve(int l, int r) {
    const int mid = (l + r) >> 1;
    const int half = (r - l) / 2;
    const int M = half - 1; // mask that maps an index to its offset inside a half
    if (r - l == 2) {
        // Two elements: the right one contributes directly to the left one.
        A[l] = (A[l] + A[l + 1] * B[l ^ (l + 1)]) % p;
        return ;
    }
    solve(mid, r);
    for (int i = mid; i < r; ++i) {
        t1[i & M] = A[i];
    }
    for (int i = 0; i < half; ++i) {
        t2[i] = B[i + half];
    }
    Xor(t1, t2, half);
    for (int i = l; i < mid; ++i) {
        A[i] = (A[i] + t1[i & M]) % p;
    }
    solve(l, mid);
}

例题

  1. 简要题意:对于 $S$ 的每个长度为 $|T|$ 的子串求是否和 $T$ 完美匹配,$S$ 和 $T$ 均含有通配符,$|S|\le 3\times 10^5$

    简要题解:构造函数 $(S_i-T_i)^2\times S_i \times T_i$

    利用这个函数直接将 $S$ 和翻转后的 $T$ 卷起来

    Luogu P4173 残缺的字符串

  2. 简要题意:对于 $S$ 的每个长度为 $|T|$​ 的子串中,有多少子串满足与 $T$ 串 $k,k\in [0,m]$ 匹配,$k$ 匹配定义为有至多 $k$​ 个位置不同

    简要题解:构造函数 $(S_i-T_i)^2\times S_i \times T_i$

    对于 $T$ 的每一种字符单独求失配个数

    2021杭电多校3 C Forgiving Matching

  3. 简要题意:现在有 $n$ 个人,第 $i$ 个人手里有一个数字 $i$,现在开始传数字,第 $i$ 个人会将他手里的数字传给 $p_i$,$p_i$ 是一个排列,同时现在有 $m$ 次询问,每次询问给出一个整数 $x$,求传了 $x$ 轮后,每个人的编号乘上他手里的数字的和

    $n\le 2\times 10^5,m\le 10^5,x\le 10^9$

    简要题解:注意到传球形成了若干个简单环,我们对每个简单环单独计算贡献

    我们设当前环的大小为 $n$,第 $i$ 个人的编号为 $a_i$,如果传了 $k$ 轮,那么总贡献是类似于 $\sum a_i\times a_{i+k}$ 的东西,稍作考虑后可以发现这是一个类似减法卷积的东西,将长度扩大到 $2n$,然后再稍作处理直接做减法卷积即可,注意到总贡献刚好是 $1e15$ 级别的,所以我们是可以用 $FFT$ 的

    然后考虑处理询问,注意到大小不同的环的个数只有 $O(\sqrt n)$,所以每次询问暴力所有环的大小即可

    时间复杂度 $O(n\log n+q\sqrt n)$

    第46屆ICPC 東亞洲區域賽(澳門) E Pass the Ball

  4. 简要题意:给定两个长度均为 $n$ 的序列 $a_i,b_i$,定义一个数对 $(a_i,b_j)$ 的价值为 $a_i\times b_j$,现在每个 $a_i$ 和 $b_i$ 都只能使用一次,求对于 $k\in[1,n]$,选出 $k$ 个数对的最大价值和以及最小价值和

    $n\le 10^5,-10^4\le a_i,b_i\le 10^4$

    我们只需要考虑最大价值和,求最小价值和只需要将其中一个序列全部取反用同样的方法做即可

    我们首先考虑如果没有负数,那么最大价值和一定是将 $a$ 和 $b$ 排序后对应位置匹配,如果有负数的话,能够发现,我们一定先凑价值为正的数对,即正正匹配,负负匹配,注意到这里的匹配也是按照绝对值的大小对应位置匹配

    我们现在考虑剩下的那一段,可以发现,一定是一个序列全负,一个序列全正,我们按照绝对值递增来排序,可以发现我们如果我们需要在这里选 $k$ 对,那么匹配一定是 $\sum_{i=0}^ka_ib_{k-i}$,很显然这是一个卷积的形式,观察一下权值范围为 $10^{13}$,可以直接使用 $FFT$,时间复杂度 $O(n\log n)$

    2021-2022 ACM-ICPC Latin American Regional Programming Contest B Because, Art!

  5. 简要题意:数轴上有 $n + 1$ 个点,分别为 $[0,n]$,你现在在 $0$,从 $i$ 走到 $i+1$ 需要花费 $c_i$ 的时间,有 $p_i$ 的成功概率,$(1-p_i)\frac{w_j}{\sum_{k=1}^iw_k}$ 的概率走到 $i-j(1\le j\le i)$,求走到 $n$ 的期望时间

    $n\le 10^5$

    简要题解:令 $f_i$ 为从 $i$ 走到 $n$ 的期望时间,容易得到转移为 $f_i=c_i+p_if_{i+1}+\frac{1-p_i}{\sum_{j=1}^iw_j}\sum_{j=0}^{i-1}f_jw_{i-j}$

    注意到后面的 $sum$ 很像是一个分治 $NTT$ 的形式,但是注意到我们的转移成环,但是对于这个形式,我们可以简单移项解决,我们可以得到 $f_{i+1}=\frac{1}{p_i}(f_i-c_i-\frac{1-p_i}{\sum_{j=1}^iw_j}\sum_{j=0}^{i-1}f_jw_{i-j})$,同时我们已知 $f_n=0$,要求 $f_0$,那么我们可以设 $f_i=a_if_0+b_i$,注意到 $a_i$ 和 $b_i$ 之间没有联系,可以直接拆成两个式子,这两个式子都是分治 $NTT$ 的形式,直接做即可,时间复杂度 $O(n\log^2n)$

    2022杭电多校3 A Equipment Upgrade

  6. 简要题意:给定一个长度为 $n$ 的序列 $a_i$,求 $\sum_{i=1}^n\sum_{j=i+1}^n(a_i\times a_j\bmod P)$,注意只对乘积取模,$P=200003$

    $n\le 2\times 10^5$

    简要题解:我们注意到如果不是只对乘积取模,那么显然可以直接 $FFT$,那么现在这种情况我们只能通过枚举值为 $x$ 和值为 $y$ 的乘积取模再乘上方案数来计算答案

    注意到 $P$ 是素数存在原根,那么我们转为枚举 $g^x$ 和 $g^y$,我们知道 $g^x\times g^y=g^{x+y}$ 这个东西符合卷积的形式,使用 $FFT$ 即可,时间复杂度 $O(n\log n)$

    AGC 047C Product Modulo