某场模拟赛-B(容斥)

题目描述

简要题意:给定一个 $n\times m$ 矩阵,矩阵中每个点可以染成 $k$ 种颜色中的某一种,求没有任意一行或任意一列颜色相同的方案数

$n,m\le 10^6,k\le 10^9$

Solution

我们令 $g_{r,c}$ 表示恰好有 $r$ 行和 $c$ 列颜色相同,答案即为 $g_{0,0}$

我们考虑求 $f_{r,c}=\sum_{i=r}^n\sum_{j=c}^m\binom{i}{r}\binom{j}{c}g_{i,j}\Leftrightarrow g_{r,c}=\sum_{i=r}^n\sum_{j=c}^m(-1)^{i-r}\binom{i}{r}(-1)^{j-c}\binom{j}{c}f_{i,j}$,$f_{r,c}$ 求法如下

1
2
3
4
if(!r && !c) f[r][c] = pow_mod(k, n * m);
else if(r > 0 && !c) f[r][c] = C(n, r) * pow_mod(k, r) * pow_mod(k, (n - r) * m);
else if(c > 0 && !r) f[r][c] = C(m, c) * pow_mod(k, c) * pow_mod(k, n * (m - c));
else f[r][c] = C(n, r) * C(m, c) * k * pow_mod(k, (n - r) * (m - c)); //行和列有重合,所以必须是一个颜色

我们分情况讨论一下,得到下面这个式子

$g_{0,0}=k^{nm}+\sum_{i=1}^n(-1)^i\binom{n}{i}k^{m(n-i)+i}+\sum_{i=1}^m(-1)^i\binom{m}{i}k^{n(m-i)+i}+\sum_{i=1}^n\sum_{j=1}^m(-1)^{i+j}\binom{n}{i}\binom{m}{j}k^{(n-i)(m-j)}$

后面那两个求和,可以利用二项式定理化成 $\sum_{i=1}^n(-1)^i\binom{n}{i}[(k^{n-i}-1)^m-k^{(n-i)m}]$

然后就可以 $O(n\log n)$ 了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include <iostream>
#include <cstdio>
#include <iomanip>
#define maxn 1000010
#define ll long long
using namespace std;

const int p = 998244353;

ll pow_mod(ll x, ll n) {
ll s = 1;
for (; n; n >>= 1) {
if (n & 1) s = s * x % p;
x = x * x % p;
}
return s;
}

ll n, k, m;

ll fac[maxn], inv[maxn];
void init_C(int n) {
fac[0] = 1; for (int i = 1; i <= n; ++i) fac[i] = fac[i - 1] * i % p;
inv[n] = pow_mod(fac[n], p - 2); for (int i = n - 1; ~i; --i) inv[i] = inv[i + 1] * (i + 1) % p;
}

ll C(int n, int m) { return n < m ? 0 : fac[n] * inv[m] % p * inv[n - m] % p; }

ll ans;
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);

cin >> n >> m >> k; init_C(max(n, m));
ans = pow_mod(k, n * m);
for (int i = 1, t = -1; i <= n; ++i, t = -t)
ans = (ans + t * C(n, i) * pow_mod(k, i) % p * pow_mod(k, (n - i) * m)) % p;
for (int i = 1, t = -1; i <= m; ++i, t = -t)
ans = (ans + t * C(m, i) * pow_mod(k, i) % p * pow_mod(k, n * (m - i))) % p;
for (int i = 1, t = -1; i <= n; ++i, t = -t)
ans = (ans + t * k * C(n, i) % p * (pow_mod(pow_mod(k, n - i) - 1, m) - pow_mod(k, (n - i) * m))) % p;
cout << (ans + p) % p << "\n";
return 0;
}