2022杭电多校8 M Shattrath City

题目描述

https://acm.hdu.edu.cn/showproblem.php?pid=7232

简要题意:给定 $n,m$,求有多少长度为 $m$ 的序列 $a_i$,满足 $a_i\in[1,n]$,且不存在一个长度为 $n$ 的子区间恰好是 $1$ 到 $n$ 的排列,答案对 $998244353$ 取模。

$n,m\le 2\times 10^5$

Solution

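我们考虑用补集转换,先计算出现过 $[1,n]$ 的排列的序列数,答案即为 $n^m$ 减去它。对于每个出现过 $[1,n]$ 的排列的序列 $a_i$,我们在排列第一次出现的位置统计贡献。令 $f_k$ 表示长度为 $k+n-1$ 的序列中,起点在 $[1,k-1]$ 内的长度为 $n$ 的子区间都不是 $[1,n]$ 的排列,且 $[k,k+n-1]$ 是一个 $[1,n]$ 的排列的方案数。先不考虑前面的限制,方案数为 $n^{k-1}n!$,再减去更早出现过排列的情况,枚举第一个排列的起点 $i<k$:若 $i+n-1<k$(两个子区间不相交),中间位置任意填,方案数为 $f_i\,n^{k-i-n}\,n!$;若 $i+n-1\ge k$(两个子区间相交),则位置 $[i+n,k+n-1]$ 上的 $k-i$ 个数必须恰好是位置 $[i,k-1]$ 上的数的一个排列,方案数为 $f_i\,(k-i)!$。于是
$$f_k=n^{k-1}n!-\sum_{i=1}^{k-n}f_i\,n^{k-i-n}\,n!-\sum_{i=\max(1,k-n+1)}^{k-1}f_i\,(k-i)!$$
第一个和式等于 $n^{k-n}n!\sum_{i\le k-n}f_i\,n^{-i}$,可以用前缀和维护;第二个和式是卷积形式,符合分治 $NTT$ 的形式,直接计算即可,时间复杂度 $O(m\log^2 m)$。最终答案为 $n^m-\sum_{k=1}^{m-n+1}f_k\,n^{m-n+1-k}$(当 $m<n$ 时答案就是 $n^m$)。另外这个式子也可以化简成多项式求逆的形式。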

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#include <iostream>
#include <vector>
#include <algorithm>
#define maxn 200010
#define ll long long
using namespace std;

// NTT-friendly prime modulus: p - 1 = 119 * 2^23.
const int p = 998244353;

// Modular addition; both operands are expected to already lie in [0, p).
inline int add(int x, int y) {
    int s = x + y;
    return s < p ? s : s - p;
}

// Modular multiplication; the product is formed in 64 bits before reducing.
inline int mul(int x, int y) {
    return static_cast<int>(static_cast<long long>(x) * y % p);
}

// Left fold of add() over a brace list, e.g. add({a, b, c}).
inline int add(std::initializer_list<int> lst) {
    int total = 0;
    for (int v : lst) total = add(total, v);
    return total;
}

// Left fold of mul() over a brace list, e.g. mul({a, b, c}).
inline int mul(std::initializer_list<int> lst) {
    int prod = 1;
    for (int v : lst) prod = mul(prod, v);
    return prod;
}

// x^n mod p by binary exponentiation; n is expected to be non-negative.
int pow_mod(int x, long long n) {
    int result = 1;
    while (n != 0) {
        if (n & 1) result = mul(result, x);
        n >>= 1;
        x = mul(x, x);
    }
    return result;
}

typedef std::vector<int> Poly;
// Number-theoretic-transform (NTT) helpers modulo p with primitive root G = 3.
// Every routine below uses the shared scratch buffers a/b, so calls are not
// reentrant.
namespace Pol {
const int N = 4200000;
int a[N], b[N];  // scratch buffers shared by Mul / MMul

const int G = 3;
const int Gi = pow_mod(G, p - 2);  // G^{-1} mod p (not used below)

int P[N];
// Fills P[0..n) with the bit-reversal permutation for a power-of-two size n.
// DIT/DIF below work in place without it, so Mul/MMul no longer call this;
// it is kept only for compatibility. For n == 1 the permutation is trivial and
// the early return avoids the undefined shift by (l - 1) == -1.
void init_P(int n) {
    int l = 0;
    while ((1 << l) < n) ++l;
    if (l == 0) {
        if (n > 0) P[0] = 0;
        return;
    }
    for (int i = 0; i < n; ++i) P[i] = (P[i >> 1] >> 1) | ((i & 1) << (l - 1));
}

// Builds the twiddle table: for each power of two k < n and 0 <= j < k,
// w[k + j] = g^j with g = G^((p - 1) / (2k)), a primitive 2k-th root of unity.
// w[0] is unused.
std::vector<int> init_W(int n) {
    std::vector<int> w(n);
    w[1] = 1;
    for (int i = 2; i < n; i <<= 1) {
        auto w0 = w.begin() + i / 2, w1 = w.begin() + i;
        int wn = pow_mod(G, (p - 1) / (i << 1));
        for (int j = 0; j < i; j += 2)
            w1[j] = w0[j >> 1], w1[j + 1] = mul(w1[j], wn);
    }
    return w;
}
// Twiddle table; supports transform sizes up to 1 << 21.
auto w = init_W(1 << 21);

// Forward NTT of a[0..n) in place (n a power of two, n <= 1 << 21).
// Butterflies run from stride n/2 down to 1: (x, y) -> (x + y, (x - y) * w).
// Input is in natural order; the output is left in bit-reversed order, which
// is exactly the order DIF consumes, so no explicit permutation is needed.
void DIT(int *a, int n) {
    for (int k = n >> 1; k; k >>= 1)
        for (int i = 0; i < n; i += k << 1)
            for (int j = 0; j < k; ++j) {
                int x = a[i + j], y = a[i + j + k];
                a[i + j + k] = mul(add(x, p - y), w[k + j]), a[i + j] = add(x, y);
            }
}

// Inverse NTT of a[0..n) in place, consuming DIT's bit-reversed output.
// Butterflies run from stride 1 up to n/2: (x, y) -> (x + y*w, x - y*w).
// The forward roots are reused; scaling by 1/n and reversing a[1..n) turns
// the result into the inverse transform in natural order.
void DIF(int *a, int n) {
    for (int k = 1; k < n; k <<= 1)
        for (int i = 0; i < n; i += k << 1)
            for (int j = 0; j < k; ++j) {
                int x = a[i + j], y = mul(a[i + j + k], w[k + j]);
                a[i + j + k] = add(x, p - y), a[i + j] = add(x, y);
            }
    int inv = pow_mod(n, p - 2);
    for (int i = 0; i < n; ++i) a[i] = mul(a[i], inv);
    std::reverse(a + 1, a + n);
}

// Ordinary convolution: returns C of length n1 + n2 - 1 with
// C[k] = sum_{i + j = k} A[i] * B[j] (mod p), reading A[0..n1) and B[0..n2).
// Returns an empty polynomial if either operand length is non-positive.
// The padded size (next power of two >= n1 + n2 - 1) must not exceed 1 << 21.
Poly Mul(const Poly &A, const Poly &B, int n1, int n2) {
    if (n1 <= 0 || n2 <= 0) return Poly();
    int n = 1;
    while (n < n1 + n2 - 1) n <<= 1;
    // add(x, p) maps values in [-p, p) into [0, p).
    for (int i = 0; i < n1; ++i) a[i] = add(A[i], p);
    for (int i = 0; i < n2; ++i) b[i] = add(B[i], p);
    std::fill(a + n1, a + n, 0), std::fill(b + n2, b + n, 0);
    DIT(a, n); DIT(b, n);
    for (int i = 0; i < n; ++i) a[i] = mul(a[i], b[i]);
    DIF(a, n);
    Poly ans(n1 + n2 - 1);
    std::copy_n(a, n1 + n2 - 1, ans.begin());
    return ans;
}

// "Subtraction" convolution (correlation): returns C of length len with
// C[k] = sum_{i=0}^{len-1-k} A[i] * B[i + k] (mod p), reading A[0..len) and
// B[0..len). Returns an empty polynomial for len <= 0.
Poly MMul(const Poly &A, const Poly &B, int len) {
    if (len <= 0) return Poly();
    int n = 1;
    while (n < 2 * len - 1) n <<= 1;
    for (int i = 0; i < len; ++i) a[i] = add(A[i], p);
    for (int i = 0; i < len; ++i) b[i] = add(B[i], p);
    std::fill(a + len, a + n, 0), std::fill(b + len, b + n, 0);
    std::reverse(b, b + len);  // turns the correlation into a plain convolution
    DIT(a, n); DIT(b, n);
    for (int i = 0; i < n; ++i) a[i] = mul(a[i], b[i]);
    DIF(a, n);
    Poly ans(len);
    std::copy_n(a, len, ans.begin());
    std::reverse(ans.begin(), ans.end());
    return ans;
}
} // namespace Pol

// fac[i] = i! mod p and inv[i] = (i!)^{-1} mod p, filled by init_C.
int fac[maxn];
int inv[maxn];

// Precomputes fac[0..n] and inv[0..n]; n must be smaller than maxn.
// Only one modular inverse is taken (of n!); the rest follow from
// (i-1)!^{-1} = i!^{-1} * i.
void init_C(int n) {
    fac[0] = 1;
    for (int i = 1; i <= n; ++i)
        fac[i] = mul(fac[i - 1], i);
    inv[n] = pow_mod(fac[n], p - 2);
    for (int i = n; i > 0; --i)
        inv[i - 1] = mul(inv[i], i);
}

// Parameters of the current test case, read in work().
int n;
int m;

// pp[i] = n^i and invpp[i] = n^{-i} (mod p), filled in work();
// pre[k] = sum_{i=1}^{k} f_i * n^{-i} (mod p), filled in solve().
int pp[maxn];
int invpp[maxn];
int pre[maxn];
// CDQ divide and conquer (online convolution) over the half-open range [l, r).
// On return, A[k] for k in [l, r) holds f_k mod p: the number of sequences of
// length k + n - 1 over [1, n] whose first length-n window that is a
// permutation of [1, n] starts at position k (f_0 = 0).
//
// Preconditions:
//  - B[d] = d! for 1 <= d < n and B[d] = 0 for d >= n, with at least r - l
//    entries; convolving f with B yields the overlapping-window term
//    sum_{k-n < i < k} f_i * (k - i)!.
//  - On entry, A[k] for k in [l, r) already holds the contributions of all
//    indices below l, and pre[] is valid up to l - 1.
//  - pp/invpp/fac and the global n are initialised (see work()).
// An empty range (l >= r) is a no-op.
void solve(Poly &A, const Poly &B, int l, int r) {
    if (l >= r) return;  // previously recursed forever on empty/negative ranges
    if (l + 1 == r) {
        if (!l) A[l] = 0;
        else {
            // f_l = n^{l-1} * n! - (overlapping term accumulated in A[l]) - (disjoint term).
            A[l] = add(mul(pp[l - 1], fac[n]), p - A[l]);
            // Disjoint term: sum_{i <= l-n} f_i * n^{l-i-n} * n! = pre[l-n] * n^{l-n} * n!.
            if (l > n) A[l] = add(A[l], p - mul({ pre[l - n], pp[l - n], fac[n] }));
            pre[l] = add(pre[l - 1], mul(A[l], invpp[l]));
        }
        return;
    }
    int mid = (l + r) >> 1;  // renamed from `m`, which shadowed the global m
    solve(A, B, l, mid);
    // Push the contributions of f_l .. f_{mid-1} into A[mid .. r).
    Poly head(A.begin() + l, A.begin() + mid), kernel(B.begin(), B.begin() + (r - l));
    Poly conv = Pol::Mul(head, kernel, mid - l, r - l);
    for (int i = mid; i < r; ++i) A[i] = add(A[i], conv[i - l]);
    solve(A, B, mid, r);
}

// Reads one test case (n, m) and prints, mod p, the number of length-m
// sequences over [1, n] with no length-n window that is a permutation of
// [1, n]. The answer is n^m - sum_{k=1}^{m-n+1} f_k * n^{m-n+1-k}, where f_k
// (first permutation window starting at k) is computed by solve().
void work() {
    cin >> n >> m;
    pp[0] = invpp[0] = 1;
    for (int i = 1, inv = pow_mod(n, p - 2); i <= m; ++i) {
        pp[i] = mul(pp[i - 1], n);
        invpp[i] = mul(invpp[i - 1], inv);
    }
    // With m < n no length-n window fits, so all n^m sequences are valid.
    // Handling this up front is required:
    //  - B (size m + 1) would otherwise be written at index n - 1 > m;
    //  - solve() would otherwise be called on an empty or negative range.
    if (m < n) {
        cout << pp[m] << "\n";
        return;
    }
    Poly A(m + 1), B(m + 1);
    // B[d] = d! for d < n (overlapping windows); B[d] = 0 for d >= n.
    for (int i = 0; i < n; ++i) B[i] = fac[i];
    solve(A, B, 0, m - n + 2);
    int ans = pp[m];
    for (int i = 1; i <= m - n + 1; ++i) ans = add(ans, p - mul(A[i], pp[m - n + 1 - i]));
    cout << ans << "\n";
}

// Entry point: reads the number of test cases T, precomputes factorial tables
// up to 200000 (the bound on n), then answers each test case with work().
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    int T;
    cin >> T;
    init_C(200000);
    for (; T--; ) work();
    return 0;
}