Luogu P5598 【XR-4】混乱度

题目描述

https://www.luogu.com.cn/problem/P5598

简要题意:给定一个长度为 $n$ 的序列 $a_i$,求 $\sum_{i=1}^n\sum_{j=i}^nf_{i,j}$,其中 $f_{i,j}=\binom{\sum_{k=i}^ja_i}{a_i,a_{i+1},\cdots,a_j}\bmod p$

$n\le 5\times 10^5,a_i\le 10^{18},p\in \lbrace 2,3,5,7\rbrace$

Solution

容易得到 $f_{l,r}=f_{l,r-1}\times \binom{a_r+\sum_{i=l}^{r-1}a_i}{a_r}$,我们考虑枚举左端点,维护右端点递增的答案,每次相当于乘上一个组合数,这个组合数我们可以用 $lucas$ 定理计算,这样做的时间复杂度是 $O(n^2\log a_i)$

从另一个角度思考我们发现每次右端点递增的时候都是乘上一个组合数,那么有没有可能到某一个时候乘出来 $0$,这样就显然没必要继续乘了

我们深入考虑一下什么情况下会变成 $0$,考察 $kummer$ 定理,我们知道 $\binom{n+m}{m}$ 所含有 $p$ 的幂次等价于 $n$ 和 $m$ 在 $p$ 进制加法下的进位次数,那么我们知道这个组合数变为 $0$ 的条件是前缀和发生进位,那么我们如果变成 $0$ 就 $break$ 的时间复杂度就变成了 $O(np\log^2 a_i)$,因为对于一个左端点,最多只会乘 $p\log a_i$ 次

但是这样仍然不能通过此题,容易发现每次都做一遍 $lucas$ 太慢了,我们发现只需要每次把不是 $0$ 的位拿出来单独计算即可,这样时间复杂度就变成了 $O(np\log a_i)$,可以通过此题,另外对于 $a_i=0$ 的位置,我们只需要记录一个 $nxt_i$ 表示下一个不为 $0$ 的位置即可

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#include <iostream>
#include <vector>
#define maxn 500010
#define maxp 60
#define ll long long
using namespace std;

const int N = 60;

int n, p, nxt[maxn];
ll a[maxn];
vector<pair<int, int>> d[maxn];

inline int add(int x, int y) { return (x += y) >= p ? x - p : x; }
inline int mul(int x, int y) { return x * y % p; }

int C[10][10];
void init_C(int n) {
for (int i = 0; i <= n; ++i) C[i][0] = 1;
for (int i = 1; i <= n; ++i)
for (int j = 1; j <= i; ++j) C[i][j] = add(C[i - 1][j - 1], C[i - 1][j]);
}

int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);

cin >> n >> p; init_C(p);
for (int i = 1; i <= n; ++i) cin >> a[i];
for (int i = 1; i <= n; ++i) {
ll t = a[i];
for (int o = 0; o < N; ++o) {
int v = t % p; t /= p;
if (v) d[i].emplace_back(o, v);
}
} ll ans = 0; nxt[n] = n + 1;
for (int i = n - 1; i; --i) nxt[i] = a[i + 1] ? i + 1 : nxt[i + 1];
for (int l = 1; l <= n; ++l) {
int res = 1; vector<int> s(N);
for (int r = l; r <= n; r = nxt[r]) {
for (auto [k, v] : d[r]) {
if (s[k] + v >= p) { res = 0; break; }
s[k] += v; res = res * C[s[k]][v] % p;
}
if (!res) break; ans += (nxt[r] - r) * res;
}
} cout << ans << "\n";
return 0;
}