Luogu P5900 无标号无根树计数

题目描述

https://www.luogu.com.cn/problem/P5900

简要题意:求 $n$ 个点的无标号无根树数量,答案对 $998244353$ 取模

$n \le 2\times 10^5$

Solution

首先我们考虑无标号有根树计数,令 $F(x)$ 为无标号有根树的生成函数

我们枚举根节点,容易得到 $F(x)=x\times Euler(F(x))=x\times exp(\sum_{n=1}\frac{F(x^n)}{n})$

我们考虑求 $F(x)$,容易想到使用牛顿迭代,我们令 $H(F_k(x))=F_k(x)-x\times exp(\sum_{n=1}\frac{F_k(x^n)}{n})\equiv 0(\bmod x^{2^k})$,注意到 $exp$ 里的求和 有一个 $F_k(x^n)$,因为 $F_k(x)$ 的前 $2^{k-1}$ 项一定与 $F_{k-1}(x)$ 相同,所以当 $n\ge 2$ 时,$\frac{F_k(x^n)}{n}$ 只与 $F_{k-1}(x)$ 有关,换句话说这个东西是常数,那么我们现在能够得到 $H(F_k(x))=F_k(x)-x\times exp(\sum_{n=2}\frac{F_k(x^n)}{n})\times exp F_k(x),H’(F_k(x))=1-x\times exp(\sum_{n=2}\frac{F_k(x^n)}{n})\times exp F_k(x)$,根据牛顿迭代的公式,我们能够得到 $F_k(x)=F_{k-1}(x)-\frac{F_k(x)-x\times exp(\sum_{n=2}\frac{F_k(x^n)}{n})\times exp F_k(x)}{1-x\times exp(\sum_{n=2}\frac{F_k(x^n)}{n})\times exp F_k(x)}$

时间复杂度为 $O(n\log n)$,常数较大

然后我们考虑无标号无根树计数,相对于无标号有根树计数,我们只需要在根为重心的时候统计一次即可,如果 $n$ 是奇数,那么树的重心有且仅有一个,我们枚举根的一个大于 $\lceil \frac{n}{2}\rceil$ 的子树的大小即可,不合法的答案为 $\sum_{i=\lceil\frac{n}{2}\rceil}^{n-1}f_{i}f_{n-i}$,如果 $n$ 是偶数,那么有可能存在有两个重心的情况,如果将两个重心中间的边断开,形成两棵树,这两棵树如果完全相同,则这些方案我们只会统计一次,不会算重,否则我们恰好算了两次,那么我们把多算的减掉即可,即减掉 $\binom{f_{\frac{n}{2}}}{2}$ 次

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
#include <iostream>
#include <vector>
#include <algorithm>
#define maxn 200010
#define ll long long
using namespace std;

const int p = 998244353;

ll pow_mod(ll x, ll n) {
ll s = 1;
for (; n; n >>= 1, x = x * x % p)
if (n & 1) s = s * x % p;
return s;
}

#define Poly vector<int>
#define len(A) ((int) A.size())
namespace Pol {
inline int add(int a, int b) { return (a += b) >= p ? a -= p : a; }
inline int mul(int a, int b) { return 1ll * a * b % p; }
Poly operator - (const int &v, const Poly &a) {
Poly res(a);
for (int i = 0; i < len(res); ++i) res[i] = p - res[i];
res[0] = add(res[0], v); return res;
}
Poly operator - (const Poly &a, const int &v) {
Poly res(a); res[0] = add(res[0], p - v); return res;
}
Poly operator * (const Poly &a, const int &v) {
Poly res(a);
for (int i = 0; i < len(res); ++i) res[i] = mul(res[i], v);
return res;
}

const int N = 4200000;
int P[N];
void init_P(int n) {
int l = 0; while ((1 << l) < n) ++l;
for (int i = 0; i < n; ++i) P[i] = (P[i >> 1] >> 1) | ((i & 1) << l - 1);
}
void NTT(Poly &a, int type) {
static int w[N]; ll G = 3, Gi = pow_mod(G, p - 2); int n = len(a);
for (int i = 0; i < n; ++i) if (i < P[i]) swap(a[i], a[P[i]]);
for (int i = 2, m = 1; i <= n; m = i, i *= 2) {
ll wn = pow_mod(type > 0 ? G : Gi, (p - 1) / i);
w[0] = 1; for (int j = 1; j < m; ++j) w[j] = wn * w[j - 1] % p;
for (int j = 0; j < n; j += i)
for (int k = 0; k < m; ++k) {
int t1 = a[j + k], t2 = 1ll * a[j + k + m] * w[k] % p;
a[j + k] = add(t1, t2);
a[j + k + m] = add(t1, p - t2);
}
}
if (type < 0) {
int inv = pow_mod(n, p - 2);
for (int i = 0; i < n; ++i) a[i] = mul(a[i], inv);
}
}
Poly operator * (const Poly &A, const Poly &B) {
int n = 1, n1 = len(A), n2 = len(B); while (n < n1 + n2 - 1) n <<= 1; init_P(n);
Poly a(n), b(n);
for (int i = 0; i < n1; ++i) a[i] = add(A[i], p);
for (int i = 0; i < n2; ++i) b[i] = add(B[i], p);
NTT(a, 1); NTT(b, 1);
for (int i = 0; i < n; ++i) a[i] = mul(a[i], b[i]);
NTT(a, -1); return a;
}
Poly Der(const Poly &a) {
Poly res(a);
for (int i = 0; i < len(a) - 1; ++i) res[i] = mul(i + 1, res[i + 1]);
res[len(a) - 1] = 0; return res;
}
Poly Int(const Poly &a) {
static int inv[N];
if (!inv[1]) {
inv[1] = 1;
for (int i = 2; i < N; ++i) inv[i] = mul(p - p / i, inv[p % i]);
}
Poly res(a); res.resize(len(a) + 1);
for (int i = len(a); i; --i) res[i] = mul(res[i - 1], inv[i]);
res[0] = 0; return res;
}
Poly Inv(const Poly &a) {
Poly res(1, pow_mod(a[0], p - 2));
int n = 1; while (n < len(a)) n <<= 1;
for (int k = 2; k <= n; k <<= 1) {
int L = 2 * k; init_P(L); Poly t(L);
copy_n(a.begin(), min(k, len(a)), t.begin());
t.resize(L); res.resize(L);
NTT(res, 1); NTT(t, 1);
for (int i = 0; i < L; ++i) res[i] = mul(res[i], add(2, p - mul(t[i], res[i])));
NTT(res, -1); res.resize(k);
} res.resize(len(a)); return res;
}
pair<Poly, Poly> Divide(const Poly &a, const Poly &b) {
int n = len(a), m = len(b);
Poly t1(a.rbegin(), a.rbegin() + n - m + 1), t2(b.rbegin(), b.rend()); t2.resize(n - m + 1);
Poly Q = Inv(t2) * t1; Q.resize(n - m + 1); reverse(Q.begin(), Q.end());
Poly R = Q * b; R.resize(m - 1); for (int i = 0; i < len(R); ++i) R[i] = add(a[i], p - R[i]);
return make_pair(Q, R);
}
Poly Ln(const Poly &a) {
Poly res = Int(Der(a) * Inv(a));
res.resize(len(a)); return res;
}
Poly Exp(const Poly &a) {
Poly res(1, 1);
int n = 1; while (n < len(a)) n <<= 1;
for (int k = 2; k <= n; k <<= 1) {
Poly t(res.begin(), res.end()); t.resize(k); t = Ln(t);
for (int i = 0; i < min(len(a), k); ++i) t[i] = add(a[i], p - t[i]); t[0] = add(t[0], 1);
res = res * t; res.resize(k);
} res.resize(len(a)); return res;
}
Poly Sqrt(const Poly &a) { // a[0] = 1
Poly res(1, 1); ll inv2 = pow_mod(2, p - 2);
int n = 1; while (n < len(a)) n <<= 1;
for (int k = 2; k <= n; k <<= 1) {
Poly t(res.begin(), res.end()), ta(a.begin(), a.begin() + min(len(a), k));
t.resize(k); t = Inv(t) * ta;
res.resize(k); for (int i = 0; i < k; ++i) res[i] = mul(add(res[i], t[i]), inv2);
} res.resize(len(a)); return res;
}
Poly Pow(const Poly &a, int k) { // a[0] = 1
return Exp(Ln(a) * k);
}
Poly Pow(const Poly &a, int k, int kk) {
int n = len(a), t = n, m, v, inv, powv; Poly res(n);
for (int i = n - 1; ~i; --i) if (a[i]) t = i, v = a[i];
if (k && t >= (n + k - 1) / k) return res;
if (t == n) { if (!k) res[0] = 1; return res; }
m = n - t * k; res.resize(m);
inv = pow_mod(v, p - 2); powv = pow_mod(v, kk);
for (int i = 0; i < m; ++i) res[i] = mul(a[i + (k > 0) * t], inv);
res = Exp(Ln(res) * k); res.resize(n);
for (int i = m - 1; ~i; --i) {
ll tmp = mul(res[i], powv);
res[i] = 0, res[i + t * k] = tmp;
}
return res;
}

Poly Newton(int lena) {
Poly res(1, 0);
int n = 1; while (n < lena) n <<= 1;
for (int k = 2; k <= n; k <<= 1) {
Poly t1(res), t2; t1.resize(k);
for (int i = 2; i < k; ++i) {
int inv = pow_mod(i, p - 2);
for (int j = 1; i * j < k; ++j) t1[i * j] = add(t1[i * j], mul(res[j], inv));
} t1 = Exp(t1); for (int i = k - 1; i; --i) t1[i] = t1[i - 1]; t1[0] = 0;
t2 = Inv(1 - t1); t1 = 0 - t1;
for (int i = 0; i < k / 2; ++i) t1[i] = add(t1[i], res[i]);
t1 = t1 * t2; t1.resize(k); t1 = 0 - t1;
for (int i = 0; i < k / 2; ++i) t1[i] = add(t1[i], res[i]); res = t1;
} res.resize(lena); return res;
}
} // namespace Pol

int n;

int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);

cin >> n; Poly res = Pol::Newton(n + 1);
ll ans = res[n];
for (int i = n / 2 + 1; i < n; ++i) ans = (ans - 1ll * res[i] * res[n - i]) % p;
if (n % 2 == 0) ans = (ans - 1ll * res[n / 2] * (res[n / 2] - 1) % p * pow_mod(2, p - 2)) % p;
cout << (ans + p) % p << "\n";
return 0;
}