CF 1609F Interesting Sections

题目描述

https://codeforces.com/contest/1609/problem/F

简要题意:给定一个长度为 $n$ 的序列,求有多少区间 $[l,r]$ 满足,区间 $[l,r]$ 的最小值 $mn$ 和最大值 $mx$ 的二进制表示 $1$ 的个数相同

$n\le 10^6, a_i\le 10^{18}$

Solution

我们考虑经典的最小值最大值分治,枚举左端点,维护左边区间的后缀 $max$ 和 $min$,同时维护两个位置 $pm$ 和 $pn$ 表示左边当前位置的后缀 $max$ 和 $min$ 在右边区间仍然是最大值和最小值的最靠右的位置

对于右端点在 $[m+1,\min\lbrace pn,pm\rbrace]$,区间最大值和最小值都在左边,直接判断 $max$ 和 $min$ 的位数是否相等即可

右端点在 $[\min\lbrace pn,pm\rbrace+1,\max\lbrace pn,pm\rbrace]$,区间最大值或者最小值在左边,另一个在右边,我们相当于要统计 $[\min\lbrace pn,pm\rbrace+1,\max\lbrace pn,pm\rbrace]$ 内二进制表示 $1$ 的个数为定值的个数,这个东西如果我们预处理前缀和的话复杂度是 $O(n\log ^2 n)$,但如果我们离线下来做一个差分,这样就能做到 $O(n\log n)$

右端点在 $[\max\lbrace pn,pm\rbrace +1,r]$,可以预处理

时间复杂度 $O(n\log n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#include <iostream>
#include <vector>
#define maxn 1000010
#define ll long long
using namespace std;

int n;
ll a[maxn];

ll ans; int suf[maxn], sn[61], sm[61]; vector<pair<int, int>> A[maxn], B[maxn];
void solve(int l, int r) {
if (l == r) return ++ans, void();
int m = l + r >> 1; solve(l, m); solve(m + 1, r);
ll mn = 1e18, mx = 0, Min = 1e18, Max = 0;
for (int i = m + 1; i <= r; ++i) A[i].clear(), B[i].clear();
for (int i = m + 1; i <= r; ++i) {
mn = min(mn, a[i]);
mx = max(mx, a[i]);
suf[i] = __builtin_popcountll(mn) == __builtin_popcountll(mx);
} suf[r + 1] = 0;
for (int i = r; i >= m; --i) suf[i] += suf[i + 1];
mn = 1e18; mx = 0;
for (int i = m, pn = m, pm = m, cm, cn; i >= l; --i) {
mn = min(mn, a[i]); cn = __builtin_popcountll(mn);
mx = max(mx, a[i]); cm = __builtin_popcountll(mx);
while (pn < r && min(Min, a[pn + 1]) >= mn) Min = min(Min, a[++pn]);
while (pm < r && max(Max, a[pm + 1]) <= mx) Max = max(Max, a[++pm]);
if (cn == cm) ans += min(pn, pm) - m;
if (pn < pm) A[pn].emplace_back(cm, -1), A[pm].emplace_back(cm, 1);
else if (pm < pn) B[pm].emplace_back(cn, -1), B[pn].emplace_back(cn, 1);
ans += suf[max(pn, pm) + 1];
} mn = 1e18; mx = 0;
for (int i = 0; i <= 60; ++i) sm[i] = sn[i] = 0;
for (int i = m + 1; i <= r; ++i) {
mn = min(mn, a[i]);
mx = max(mx, a[i]);
sn[__builtin_popcountll(mn)]++; sm[__builtin_popcountll(mx)]++;
for (auto [k, opt] : A[i]) ans += opt * sn[k];
for (auto [k, opt] : B[i]) ans += opt * sm[k];
}
}

int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);

cin >> n;
for (int i = 1; i <= n; ++i) cin >> a[i];
solve(1, n); cout << ans << "\n";
return 0;
}