Luogu P6216 回文匹配

题目描述

https://www.luogu.com.cn/problem/P6216

简要题意:给定字符串 $S$ 和 $T$,其长度分别为 $n$ 和 $m$,$S[l..r]$ 的价值为 $T$ 在 $S[l..r]$ 中的出现次数,求 $S$ 的所有长度为奇数的回文子串的价值和

$n,m\le 3\times 10^6$

Solution

首先用 $kmp$ 求出 $T$ 在 $S$ 的中所有匹配位置,我们以结束位置来标记

然后用 $manacher$ 求出每个回文中心 $i$ 的极大回文子串的半径 $a_i$,即以 $i$ 为中心的极大回文子串为 $S[i-a_i..i+a_i]$,那么对于一个 $T$ 的匹配位置 $r$,它对以 $i$ 为中心,半径为 $j$ 的产生贡献的条件为 $r\in [i-j+m-1,i+j]$,也就是回文中心 $i$ 的贡献为 $j\in[\lfloor\frac{m}{2}\rfloor,a_i],d_{i+j}-d_{i+m-1-j-1}$,其中 $d$ 是 $T$ 的匹配位置的一阶前缀和数组,容易发现这个贡献的形式显然为 $d$ 的区间查询,我们把 $d$ 再做一次前缀和即可实现对于每个回文中心 $i$,$O(1)$ 计算答案

时间复杂度 $O(n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include <iostream>
#include <cstring>
#define maxn 3000010
#define ll long long
#define uint unsigned int
using namespace std;

int n, m;
char s[maxn], t[maxn];

int nxt[maxn];
void init_nxt(char *s) {
int k = 0, l = strlen(s + 1); nxt[1] = 0;
for (int i = 2; i <= l; ++i) {
while (k && s[k + 1] != s[i]) k = nxt[k];
if (s[k + 1] == s[i]) ++k;
nxt[i] = k;
}
}

uint d[maxn], dd[maxn];
void kmp(char *s, char *t) {
int k = 0, l1 = strlen(s + 1), l2 = strlen(t + 1);
for (int i = 1; i <= l1; ++i) {
while (k && s[i] != t[k + 1]) k = nxt[k];
if (t[k + 1] == s[i]) ++k;
if (k == l2) d[i] = 1;
d[i] += d[i - 1];
}
for (int i = 1; i <= l1; ++i) dd[i] += dd[i - 1] + d[i];
}

struct Manacher {
int n, l, f[maxn * 2], Len;
char s[maxn * 2];

void init(char *c) {
l = strlen(c + 1); s[0] = '~';
for (int i = 1, j = 2; i <= l; ++i, j += 2)
s[j] = c[i], s[j - 1] = '#';
n = 2 * l + 1; s[n] = '#'; s[n + 1] = '\0';
}
void manacher() {
int p = 0, mr = 0;
for (int i = 1; i <= n; ++i) f[i] = 0;
for (int i = 1; i <= n; ++i) {
if (i < mr) f[i] = min(f[2 * p - i], mr - i);
while (s[i + f[i]] == s[i - f[i]]) ++f[i]; --f[i];
if (f[i] + i > mr) mr = i + f[i], p = i;
Len = max(Len, f[i]);
}
}
} M;

int a[maxn];
uint solve() {
for (int i = 1; i <= 2 * n + 1; ++i) {
int L = i - M.f[i] + 1 >> 1, R = i + M.f[i] - 1 >> 1, len = R - L + 1;
if (!M.f[i] || ~len & 1) continue; a[L + len / 2] = len / 2;
} int l = m / 2; uint ans = 0;
for (int i = 1; i <= n; ++i) {
if (l > a[i]) continue;
ans += dd[i + a[i]] - dd[i + l - 1];
ans -= dd[i + m - l - 2] - (i + m - a[i] - 3 >= 0 ? dd[i + m - a[i] - 3] : 0);
//for (int j = l; j <= a[i]; ++j) ans += d[i + j] - d[i + m - 1 - j - 1];
} return ans;
}

int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);

cin >> n >> m >> s + 1 >> t + 1;
init_nxt(t); kmp(s, t); M.init(s), M.manacher();
cout << solve() << "\n";
return 0;
}