Luogu P4173 残缺的字符串

题目描述

https://www.luogu.com.cn/problem/P4173

Solution

这是一个很经典的做法

我们考虑构造一个多项式,使得如果 $S_i=T_i$,那么这个式子是 $0$,其它情况下均为一个大于 $0$ 的数字,另外还要能处理通配符

那么我们能够得到这么一个式子 $(S_i-T_i)^2\times S_i\times T_i$

然后我们考虑怎么对于 $S$ 的每个 $pre_i$ 都按照这个式子求一下

我们考虑首先将 $a$ 到 $z$ 映射到 $1$ 到 $26$,通配符为 $0$

然后把 $T$ 翻转一下,然后我们直接将 $S$ 和 $T$ 卷积起来,卷积得到的 $A_i$ 就是 $pre_i$ 的值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#include <iostream>
#include <vector>
#include <algorithm>
#define maxn 1050000
#define ll long long
using namespace std;

const int p = 998244353;
const int G = 3;
const int Gi = 332748118;

ll pow_mod(ll x, ll n) {
ll s = 1;
for (; n; n >>= 1) {
if (n & 1) s = s * x % p;
x = x * x % p;
}
return s;
}

int n, m, s[maxn], t[maxn], f[maxn], A[maxn], B[maxn];

char S[maxn], T[maxn];

int a[maxn], b[maxn];
inline int add(int x, int y) { return (x += y) >= p ? x - p : x; }

int P[maxn];
void init_P(int n) {
int l = 0; while ((1 << l) < n) ++l;
for (int i = 0; i < n; ++i) P[i] = (P[i >> 1] >> 1) | ((i & 1) << l - 1);
}

void NTT(int *a, int n, int type) {
static int w[maxn];
for (int i = 0; i < n; ++i) if (i < P[i]) swap(a[i], a[P[i]]);
for (int i = 2, m = 1; i <= n; m = i, i *= 2) {
ll wn = pow_mod(type > 0 ? G : Gi, (p - 1) / i);
w[0] = 1; for (int j = 1; j < m; ++j) w[j] = wn * w[j - 1] % p;
for (int j = 0; j < n; j += i)
for (int k = 0; k < m; ++k) {
ll t1 = a[j + k], t2 = 1ll * a[j + k + m] * w[k] % p;
a[j + k] = add(t1, t2);
a[j + k + m] = add(t1, p - t2);
}
}
if (type < 0) {
ll inv = pow_mod(n, p - 2);
for (int i = 0; i < n; ++i) a[i] = inv * a[i] % p;
}
}

void Mul(int *a, int *B, int n1, int n2) {
int n = 1; while (n < n1 + n2 - 1) n <<= 1; init_P(n);
for (int i = 0; i < n2; ++i) b[i] = B[i];
fill(a + n1, a + n, 0); fill(b + n2, b + n, 0);
NTT(a, n, 1); NTT(b, n, 1);
for (int i = 0; i < n; ++i) a[i] = 1ll * a[i] * b[i] % p;
NTT(a, n, -1);
}

void solve() {
for (int i = 0; i < n; ++i) A[i] = s[i] * s[i] * s[i];
for (int i = 0; i < m; ++i) B[i] = t[i];
Mul(A, B, n, m);
for (int i = 0; i < n; ++i) f[i] = A[i];

for (int i = 0; i < n; ++i) A[i] = s[i];
for (int i = 0; i < m; ++i) B[i] = t[i] * t[i] * t[i];
Mul(A, B, n, m);
for (int i = 0; i < n; ++i) f[i] = (f[i] + A[i]) % p;

for (int i = 0; i < n; ++i) A[i] = s[i] * s[i];
for (int i = 0; i < m; ++i) B[i] = t[i] * t[i];
Mul(A, B, n, m);
for (int i = 0; i < n; ++i) f[i] = (f[i] - 2ll * A[i]) % p;

vector<int> ans;
for (int i = m - 1; i < n; ++i)
if ((f[i] + p) % p == 0) ans.push_back(i - m + 2);
cout << ans.size() << "\n";
for (int t : ans) cout << t << " ";
}

int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);

cin >> m >> n >> T >> S;
for (int i = 0; i < m; ++i)
if (T[i] == '*') t[i] = 0;
else t[i] = T[i] - 'a' + 1;
for (int i = 0; i < n; ++i)
if (S[i] == '*') s[i] = 0;
else s[i] = S[i] - 'a' + 1;
reverse(t, t + m); solve();
return 0;
}