CF 1625D Binary Spiders

题目描述

https://codeforces.com/contest/1625/problem/D

简要题意:给定一个长度为 $n$ 的数列和一个数字 $k$,求从这 $n$ 歌数中找出一个大小最大的子集,满足这个集合中任意两个数的异或和都大于等于 $k$

$n\le 3\times 10^5$

Solution

首先得到 $k$ 的最高二进制一的位置,令其为 $t$,那么我们容易发现如果如果两个数字 $t$ 之前的位置有所不同,那么这两个数字的异或和一定大于 $k$,注意到这是一个前缀的比较,那么我们可以联想到 $01Trie$

我们将这些点拉出来,发现他们内部选出来的点与外面的点是没有影响的,然后容易发现,这些点内部最多选两个点,左儿子一个,右儿子一个,因为左右儿子内部选出来的点异或最大不会超过 $2^{t}-1$,那么这里我们只需要枚举左儿子的点,然后对于每个点从右儿子中选一个最大的,在 $01Trie$ 上实现这点是非常容易的,最后选择异或和最大的一对点即可

时间复杂度 $O(n\log n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include <iostream>
#include <vector>
#define maxn 300010
#define INF 1000000000
#define ll long long
using namespace std;

int n, k, d, a[maxn];

#define lc T[i].ch[0]
#define rc T[i].ch[1]
const int N = 29;
struct Trie {
int v, ch[2];
} T[maxn * 31]; int rt = 1, top = 1;
void insert(int k, int v) {
int i = rt;
for (int o = N; ~o; --o) {
int d = k >> o & 1;
if (!T[i].ch[d]) T[i].ch[d] = ++top;
i = T[i].ch[d];
} T[i].v = v;
}

vector<int> vec;
void get(int i, int o) {
if (!i) return ;
if (o == -1) return vec.push_back(T[i].v);
get(lc, o - 1); get(rc, o - 1);
}

int get_max(int i, int k, int o) {
if (!i) return 0;
if (o == -1) return T[i].v;
int d = k >> o & 1;
if (T[i].ch[d ^ 1]) return get_max(T[i].ch[d ^ 1], k, o - 1);
else return get_max(T[i].ch[d], k, o - 1);
}

vector<int> Ans;
void dfs(int i, int o) {
if (!i) return ;
if (o > d) return dfs(lc, o - 1), dfs(rc, o - 1);
vec.clear(); get(lc, o - 1); int ans = 0, ansu, ansv;
for (auto u : vec) {
int v = get_max(rc, a[u], o - 1); if (!v) continue;
if ((a[u] ^ a[v]) > ans) ans = a[u] ^ a[v], ansu = u, ansv = v;
}
if (ans >= k) Ans.push_back(ansu), Ans.push_back(ansv);
else {
if (vec.size()) Ans.push_back(vec.back());
else Ans.push_back(get_max(rc, 0, o - 1));
}
}

int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);

cin >> n >> k;
for (int i = 1; i <= n; ++i) cin >> a[i], insert(a[i], i);
if (k == 0) {
cout << n << "\n";
for (int i = 1; i <= n; ++i) cout << i << " \n"[i == n];
return 0;
}
for (int i = 0; i <= N; ++i)
if (k >> i & 1) d = i;
dfs(rt, N);
if (Ans.size() <= 1) cout << -1 << "\n";
else {
cout << Ans.size() << "\n";
for (auto t : Ans) cout << t << " "; cout << "\n";
}
return 0;
}