CF 1613F Tree Coloring

题目描述

https://codeforces.com/problemset/problem/1613/F

简要题意:给定一棵以 $1$ 为根 $n$ 个点的有根树,现在要为每个点赋一个权值 $p_i$,要求 $p_i$ 为 $[1,n]$ 的排列,且对于 $u$ 为 $v$ 的父亲 $p_u\neq p_v+1$,求方案数

$n\le 2.5\times 10^5$

Solution

我们考虑组合容斥,令 $f_k=x^k!\prod_{i=1}^n(1+d_ix)$,其中 $d_i$ 表示 $i$ 的儿子的数量,这个式子大概的意思就是选 $k$ 条不合法的边,注意到每个点 $i$ 和其儿子的连边中我们只能选择一条,令 $g_k$ 表示恰好有 $k$ 条不合法边的方案,容易得到 $f_k=\sum_{i=k}^n\binom{i}{k}g_i,g_k=\sum_{i=1}^k(-1)^{i-k}\binom{i}{k}f_i$,所求即为 $g_0$

对于 $f_k$,容易想到 $O(n\log^2 n)$ 的分治 $NTT$ 的做法,但是我们知道 $\sum_{i=1}^nd_i=n-1=O(n)$,我们考虑优化,首先对于相同的 $d_i$,不妨假设有 $k$ 个,我们显然可以做到 $O(k)$,那么我们将 $d_i$ 从大到小排序,对于相同的线性做,不同的暴力乘起来,这样的时间复杂度为 $O(n\log n)$,证明如下,令 $S$ 表示 $d_i$ 的可重集合,容易发现时间复杂度为 $O((\sum_{i\in S}\sum_{j=1}^n[i\ge j])\log n)=O(n\log n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#include <iostream>
#include <vector>
#include <algorithm>
#define maxn 250010
using namespace std;

const int p = 998244353;
inline int add(int x, int y) { return (x += y) >= p ? x - p : x; }
inline int mul(int x, int y) { return 1ll * x * y % p; }
inline int add(initializer_list<int> lst) { int s = 0; for (auto t : lst) s = add(s, t); return s; }
inline int mul(initializer_list<int> lst) { int s = 1; for (auto t : lst) s = mul(s, t); return s; }
int pow_mod(int x, int n) {
int s = 1;
for (; n; n >>= 1, x = mul(x, x))
if (n & 1) s = mul(s, x);
return s;
}

typedef std::vector<int> Poly;
namespace Pol {
inline int add(int a, int b) { return (a += b) >= p ? a -= p : a; }
inline int mul(int a, int b) { return 1ll * a * b % p; }

const int N = 4200000;
int a[N], b[N];

const int G = 3;
const int Gi = pow_mod(G, p - 2);

int P[N];
void init_P(int n) {
int l = 0; while ((1 << l) < n) ++l;
for (int i = 0; i < n; ++i) P[i] = (P[i >> 1] >> 1) | ((i & 1) << l - 1);
}
vector<int> init_W(int n) {
vector<int> w(n); w[1] = 1;
for (int i = 2; i < n; i <<= 1) {
auto w0 = w.begin() + i / 2, w1 = w.begin() + i;
int wn = pow_mod(G, (p - 1) / (i << 1));
for (int j = 0; j < i; j += 2)
w1[j] = w0[j >> 1], w1[j + 1] = mul(w1[j], wn);
}
return w;
} auto w = init_W(1 << 21);
void DIT(int *a, int n) {
for (int k = n >> 1; k; k >>= 1)
for (int i = 0; i < n; i += k << 1)
for (int j = 0; j < k; ++j) {
int x = a[i + j], y = a[i + j + k];
a[i + j + k] = mul(add(x, p - y), w[k + j]), a[i + j] = add(x, y);
}
}
void DIF(int *a, int n) {
for (int k = 1; k < n; k <<= 1)
for (int i = 0; i < n; i += k << 1)
for (int j = 0; j < k; ++j) {
int x = a[i + j], y = mul(a[i + j + k], w[k + j]);
a[i + j + k] = add(x, p - y), a[i + j] = add(x, y);
}
int inv = pow_mod(n, p - 2);
for (int i = 0; i < n; ++i) a[i] = mul(a[i], inv);
reverse(a + 1, a + n);
}
Poly Mul(const Poly &A, const Poly &B) {
int n = 1, n1 = A.size(), n2 = B.size(); while (n < n1 + n2 - 1) n <<= 1; init_P(n);
for (int i = 0; i < n1; ++i) a[i] = add(A[i], p);
for (int i = 0; i < n2; ++i) b[i] = add(B[i], p);
fill(a + n1, a + n, 0), fill(b + n2, b + n, 0);
DIT(a, n); DIT(b, n);
for (int i = 0; i < n; ++i) a[i] = mul(a[i], b[i]);
DIF(a, n); Poly ans(n1 + n2 - 1); copy_n(a, n1 + n2 - 1, ans.begin());
return ans;
}
Poly MMul(const Poly &A, const Poly &B, int len) { //减法卷积
int n = 1; while (n < 2 * len - 1) n <<= 1; init_P(n);
for (int i = 0; i < len; ++i) a[i] = add(A[i], p);
for (int i = 0; i < len; ++i) b[i] = add(B[i], p);
fill(a + len, a + n, 0), fill(b + len, b + n, 0); reverse(b, b + len);
DIT(a, n); DIT(b, n);
for (int i = 0; i < n; ++i) a[i] = mul(a[i], b[i]);
DIF(a, n); Poly ans(len);
copy_n(a, len, ans.begin()); reverse(ans.begin(), ans.end());
return ans;
}
} // namespace Pol

int n;
int du[maxn];

int fac[maxn];
void init_fac(int n) { fac[0] = 1; for (int i = 1; i <= n; ++i) fac[i] = mul(fac[i - 1], i); }

int inv[maxn];
void init_inv(int n) { inv[1] = 1; for (int i = 2; i <= n; ++i) inv[i] = mul(p - p / i, inv[p % i]); }

int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);

cin >> n; init_fac(n); init_inv(n);
for (int i = 1; i < n; ++i) {
int x, y; cin >> x >> y;
++du[x]; ++du[y];
}
for (int i = 2; i <= n; ++i) --du[i];
sort(du + 1, du + n + 1, greater<int>());
Poly res{1};
for (int i = 1, p = 1; i <= n; i = p) {
while (p <= n && du[i] == du[p]) ++p;
int C = 1, ki = 1, k = du[i]; Poly A(p - i + 1); A[0] = 1;
for (int j = 1; j <= p - i; ++j) {
C = mul({ C, (p - i) - j + 1, inv[j] });
ki = mul(ki, k);
A[j] = mul(C, ki);
} res = Pol::Mul(res, A);
} int ans = 0;
for (int i = 0, opt = 1; i <= n; ++i, opt = p - opt)
ans = add(ans, mul({ opt, res[i], fac[n - i] }));
cout << ans << "\n";
return 0;
}