Luogu P4696 [CEOI2011] Matching

题目描述

简要题意:给定一个长度为 $n$ 的排列 $p_i$ 和一个长度为 $m$ 的序列 $a_i$,对于一个长度为 $n$ 的序列 $a_i$ 和一个长度为 $n$ 的排列 $p_i$,称 $a_i$ 符合 $p_i$ 当且仅当 $a_i$ 互不相同,且将 $a_i$ 从小到大排序会得到 $a_{p_1},\cdots, a_{p_n}$,现在需要对给定的序列 $a_i$ 求其有多少个符合给定排列 $p_i$ 的长度为 $m$ 的子串

$n\le m\le 10^6$

Solution

首先我们将 $p_i$ 变成 $p’_{p_i}=i$,下面以 $p$ 来代表 $p’$,那么现在的匹配规则就是将 $a_i$ 离散化后与 $p_i$ 是否对应位都相等,注意到这个东西是非常难处理的,因为一个位置的值与其前面的数和后面的数都有关系,我们考虑将每个数的值定义为其前面有多少小于它的数

然后我们考虑求 $nxt$ 数组和 $kmp$ 的过程,我们需要动态维护对于 $a_i$ 在序列 $[nxt_{i-1},i]$ 的值,这个可以用树状数组实现,在 $nxt$ 向会跳的时候我们在树状数组上撤销这些值即可,$nxt$ 指针的移动是均摊 $O(n)$ 的,匹配时用类似的方法即可

时间复杂度 $O(n\log n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#include <iostream>
#include <algorithm>
#include <vector>
#define maxn 1000010
#define lowbit(i) ((i) & (-i))
using namespace std;

int n, m, s[maxn], t[maxn];

int Bit[maxn];
void add(int i, int v) { while (i <= n) Bit[i] += v, i += lowbit(i); }
int get_sum(int i) {
int s = 0;
while (i) s += Bit[i], i -= lowbit(i);
return s;
}
void clear(int l, int r, int *a) { for (int i = l; i <= r; ++i) add(a[i], -1); }

int b[maxn];
void init_hash(int *a) {
for (int i = 1; i <= n; ++i) b[i] = a[i];
sort(b + 1, b + n + 1); int cnt = unique(b + 1, b + n + 1) - b - 1;
for (int i = 1; i <= n; ++i) a[i] = lower_bound(b + 1, b + cnt + 1, a[i]) - b;
}

int nxt[maxn], cnt[maxn];
void init_nxt(int *s, int n) {
for (int i = 1; i <= n; ++i) cnt[i] = get_sum(s[i]), add(s[i], 1);
fill(Bit + 1, Bit + n + 1, 0); cnt[m + 1] = -1;
int k = 0; nxt[1] = 0;
for (int i = 2; i <= n; ++i) {
while (k && cnt[k + 1] != get_sum(s[i])) clear(i - 1 - k + 1, i - 1 - nxt[k], s), k = nxt[k];
if (cnt[k + 1] == get_sum(s[i])) ++k, add(s[i], 1);
nxt[i] = k;
}
}

void kmp(int *s, int n, int m) {
vector<int> ans; fill(Bit + 1, Bit + n + 1, 0); int k = 0;
for (int i = 1; i <= n; ++i) {
while (k && cnt[k + 1] != get_sum(s[i])) clear(i - 1 - k + 1, i - 1 - nxt[k], s), k = nxt[k];
if (cnt[k + 1] == get_sum(s[i])) ++k, add(s[i], 1);
if (k == m) ans.push_back(i - k + 1);
} cout << ans.size() << "\n";
for (auto t : ans) cout << t << " "; cout << "\n";
}

int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);

cin >> m >> n;
for (int i = 1, x; i <= m; ++i) cin >> x, t[x] = i;
for (int i = 1; i <= n; ++i) cin >> s[i]; init_hash(s);
for (int i = 1; i <= n; ++i) Bit[i] = 0; init_nxt(t, m); kmp(s, n, m);
return 0;
}