CF 452E Three strings

题目描述

http://codeforces.com/problemset/problem/452/E

简要题意:给定三个字符串 $A,B,C$,求对于所有 $l$,有多少个数对 $(a,b,c)$,满足 $A[a\cdots a+l-1]=B[b\cdots b+l-1]=C[c\cdots c+l-1]$

$|A|+|B|+|C|\le 3\times 10^5$

Solution

我们考虑后缀数组,将这三个字符串用特殊字符连起来后,我们枚举 $l$,将 $H_i\ge l$ 的 $i$ 和 $i-1$ 连起来,然后对于一个连续区间,其贡献就是属于第一个字符串的数量乘上第二个字符串的数量乘上第三个字符串的数量,我们发现倒叙枚举 $l$ 只有合并,可以用并查集维护

时间复杂度 $O(n\log n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include <iostream>
#include <cstring>
#include <vector>
#include <algorithm>
#define maxn 300010
#define ll long long
using namespace std;

const int p = 1000000007;

int n, m, k, len;
char c[maxn], s[maxn];

int tax[maxn], tp[maxn], sa[maxn], rk[maxn], M = 255;
void rsort(int n) {
for (int i = 0; i <= M; ++i) tax[i] = 0;
for (int i = 1; i <= n; ++i) ++tax[rk[i]];
for (int i = 1; i <= M; ++i) tax[i] += tax[i - 1];
for (int i = n; i; --i) sa[tax[rk[tp[i]]]--] = tp[i];
}

int H[maxn];
void init_sa(int n) {
if (n == 1) return sa[1] = rk[1] = 1, void(); int cnt = 1;
for (int i = 1; i <= n; ++i) rk[i] = s[i], tp[i] = i; rsort(n);
for (int k = 1; k < n; k *= 2) {
if (cnt == n) break; M = cnt; cnt = 0;
for (int i = n - k + 1; i <= n; ++i) tp[++cnt] = i;
for (int i = 1; i <= n; ++i) if (sa[i] > k) tp[++cnt] = sa[i] - k;
rsort(n); swap(rk, tp); rk[sa[1]] = cnt = 1;
for (int i = 2; i <= n; ++i) {
if (tp[sa[i - 1]] != tp[sa[i]] || tp[sa[i - 1] + k] != tp[sa[i] + k]) ++cnt;
rk[sa[i]] = cnt;
}
} int lcp = 0;
for (int i = 1; i <= n; ++i) {
if (lcp) --lcp;
int j = sa[rk[i] - 1];
while (s[j + lcp] == s[i + lcp]) ++lcp;
H[rk[i]] = lcp;
}
}

int bl[maxn];
vector<int> A[maxn];

int fa[maxn], f[maxn][4], g[maxn];
void init_fa(int n) { for (int i = 1; i <= n; ++i) fa[i] = i, f[i][bl[i]] = 1; }

int find(int x) { return fa[x] == x ? x : fa[x] = find(fa[x]); }

ll res;
inline void merge(int x, int y) {
int fx = find(x), fy = find(y);
if (fx == fy) return ;
fa[fy] = fx; res = (res - g[fx] - g[fy]) % p;
for (int i = 1; i <= 3; ++i) f[fx][i] += f[fy][i];
g[fx] = 1ll * f[fx][1] * f[fx][2] % p * f[fx][3] % p;
res = (res + g[fx]) % p;
}

int ans[maxn];
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);

cin >> c; n = strlen(c);
cin >> c + n; m = strlen(c + n);
cin >> c + n + m; k = strlen(c + n + m);
for (int i = 0; i < n; ++i) s[++len] = c[i], bl[len] = 1; s[++len] = '$';
for (int i = n; i < n + m; ++i) s[++len] = c[i], bl[len] = 2; s[++len] = '#';
for (int i = n + m; i < n + m + k; ++i) s[++len] = c[i], bl[len] = 3; init_sa(len);
for (int i = 1; i <= len; ++i) A[H[i]].push_back(i); init_fa(len);
for (int l = len; l; --l) {
for (auto t : A[l]) merge(sa[t - 1], sa[t]);
ans[l] = res;
}
for (int i = 1; i <= min({ n, m, k }); ++i) cout << (ans[i] + p) % p << " ";
return 0;
}