校内赛 T3 数学分析

题目描述

简要题意:给定一个长度为 $n$ 的串 $S$,定义一个长度为 $k$ 区间序列 $[l_1,r_1],\cdots,[l_k,r_k]$ 为一个合法区间序列,当且仅当满足于 $l_1<l_2<\cdots<l_k\le r_k<r_{m-1}<\cdots <r_1$,求有多少区间套序列满足,$\forall i\in[1,m],S[l_i..r_i]$ 是回文串

$n\le 10^6$

Solution

我们考虑对于所有右端点来计数,那么为了求所有以 $r$ 结尾的回文串,我们可以想到回文自动机,我们定义 $f_{u,0/1/2}$ 分别表示最后一个选择的区间是 $u$ 这个区间;最后一个选择的区间是 $u$ 的 $border$;最后一个选择的区间是 $u$ 的子串,能够发现我们可以用通过字符边连向 $u$ 的节点 $v$ 和 $u$ 的 $fail$ 边这两个节点的状态来进行转移

时间复杂度 $O(n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
 #include <iostream>
#include <cstring>
#define maxn 1000010
#define ll long long
using namespace std;

const int p = 998244353;
inline int add(int x, int y) { return (x += y) >= p ? x - p : x; }
inline int mul(int x, int y) { return 1ll * x * y % p; }
inline int add(initializer_list<int> lst) { int s = 0; for (auto t : lst) s = add(s, t); return s; }

int n;
char s[maxn];

struct PAM {
int l, nxt[26], fail;
} T[maxn]; int top, last;
void init_PAM() {
top = 1; last = 0;
T[0].l = 0; T[0].fail = 1;
T[1].l = -1; T[1].fail = -1;
}

inline int get(int i, int u) {
while (s[i - T[u].l - 1] != s[i]) u = T[u].fail;
return u;
}

int f[maxn][3], g[maxn];
int insert(int i, int ch) {
int p = get(i, last);
if (!T[p].nxt[ch]) {
int q = ++top;
T[q].l = T[p].l + 2;
T[q].fail = T[get(i, T[p].fail)].nxt[ch];
T[p].nxt[ch] = q;

f[q][0] = add(f[p][2], 1);
f[q][1] = add(f[T[q].fail][1], f[T[q].fail][0]);
f[q][2] = add({ f[p][2], mul(2, f[q][1]), f[q][0] });
g[q] = add(g[T[q].fail], f[q][0]);
} last = T[p].nxt[ch];
return last;
}

int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);

cin >> s + 1; n = strlen(s + 1); init_PAM(); int ans = 0;
for (int i = 1; i <= n; ++i) ans = add(ans, g[insert(i, s[i] - 'a')]);
cout << ans << "\n";
return 0;
}