Luogu P3369 【模板】普通平衡树

题目描述

https://www.luogu.com.cn/problem/P3369

Solution

$\text{Splay}$ 模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#include <iostream>
#include <cstdio>
#include <cctype>
#define maxn 100010
#define gc getchar
#define INF 2000000000
using namespace std;

// Reads one (optionally negative) decimal integer from stdin via gc().
// Every non-digit character before the first digit is skipped; a '-' anywhere
// in that skipped prefix makes the result negative. Returns 0 if the input
// ends before any digit is seen (instead of spinning forever on EOF).
int read() {
    int x = 0;
    bool neg = false;
    // Keep the character as int: getchar() yields EOF or an unsigned-char
    // value, which is exactly the domain isdigit() is defined on. Storing it in
    // a plain char made isdigit() undefined for bytes >= 0x80 (signed char)
    // and made the EOF case loop forever in the skip loop below.
    int c = gc();
    while (c != EOF && !isdigit(c)) {
        if (c == '-') neg = true;
        c = gc();
    }
    while (isdigit(c)) {  // isdigit(EOF) is false, so EOF ends the number
        x = x * 10 + (c - '0');
        c = gc();
    }
    return neg ? -x : x;
}

// n: number of operations, read in main(). m: declared but not used in this file.
int n, m;

// Shorthand for the left / right child of node `i`. NOTE: these expand to an
// expression on a variable literally named `i`, so they only work inside
// functions whose current-node variable is called `i`.
#define lc T[i].ch[0]
#define rc T[i].ch[1]
// Node pool of the splay tree; nodes are referred to by index into T.
// Index 0 means "no node", and the code relies on T[0].sz and T[0].cnt being 0.
//   v   - key stored in the node (duplicates share one node)
//   sz  - total multiplicity of the subtree (sum of cnt), recomputed by update()
//   ch  - child indices: ch[0] holds smaller keys, ch[1] holds larger keys
//   cnt - multiplicity of key v
// top: last allocated index. rt: root index (0 = empty tree).
// f[i]: parent index of node i (0 for the root).
// NOTE(review): the struct name is hidden by the function Splay() declared
// below, so nodes are only ever reached through T.
struct Splay {
int v, sz, ch[2], cnt;
} T[maxn]; int top, rt, f[maxn];

// Which side of its parent node i hangs on: 1 = right child, 0 = left child
// (also 0 for the root, whose parent is the null index).
inline int get(int i) { return T[f[i]].ch[1] == i ? 1 : 0; }

// Wipes node i: key, size, both children, multiplicity and parent link.
inline void clear(int i) {
    T[i] = {};  // value-initialises every field of the node to zero
    f[i] = 0;
}

// Allocates the next pool slot as a single-copy node holding key v and returns
// its index. Children and parent are not written here; top only ever grows, so
// slots are never reused after deletion.
inline int newnode(int v) {
    ++top;
    T[top].v = v;
    T[top].cnt = 1;
    T[top].sz = 1;
    return top;
}

// Recomputes the subtree size of node i from its own multiplicity and its children's sizes.
inline void update(int i) { T[i].sz = T[T[i].ch[0]].sz + T[T[i].ch[1]].sz + T[i].cnt; }

inline void rotate(int x) {
int fa = f[x], ffa = f[f[x]], wx = get(x);
f[x] = ffa; f[T[x].ch[wx ^ 1]] = fa; f[fa] = x;
if (ffa) T[ffa].ch[T[ffa].ch[1] == fa] = x;
T[fa].ch[wx] = T[x].ch[wx ^ 1]; T[x].ch[wx ^ 1] = fa;
update(fa); update(x);
}

void Splay(int i, int k = 0) {
for (int fa = f[i]; fa != k; rotate(i), fa = f[i])
if (f[fa] != k) rotate(get(i) == get(fa) ? fa : i);
if (!k) rt = i;
}

// Inserts one copy of key v, then splays the affected node to the root.
// An existing node with key v just has its cnt incremented; otherwise a new
// leaf is attached where the BST search ends.
// NOTE: if the existing node is already the root, Splay() performs no rotation,
// so its sz is not refreshed here; update() recomputes it the next time that
// node is rotated.
void insert(int v) {
    if (!rt) {
        rt = newnode(v);
        return;
    }
    int cur = rt;
    for (;;) {
        if (T[cur].v == v) {
            ++T[cur].cnt;
            Splay(cur);
            return;
        }
        const int dir = v > T[cur].v;
        const int next = T[cur].ch[dir];
        if (next) {
            cur = next;
            continue;
        }
        const int fresh = newnode(v);
        T[cur].ch[dir] = fresh;
        f[fresh] = cur;
        Splay(fresh);
        return;
    }
}

// Returns the node holding the k-th smallest key (1-based, counting duplicates
// through cnt) and splays it to the root. There is no bounds check: k is
// assumed to lie within the total size of the tree.
int findk(int k) {
    int cur = rt;
    for (;;) {
        const int left = T[cur].ch[0];
        if (k <= T[left].sz) {
            cur = left;
            continue;
        }
        k -= T[left].sz;
        if (k <= T[cur].cnt) break;
        k -= T[cur].cnt;
        cur = T[cur].ch[1];
    }
    Splay(cur);
    return cur;
}

// BST search for key v, then splays the last visited node to the root.
// If v exists, the root becomes v's node; otherwise it becomes the node where
// the search stopped, i.e. either the smallest key greater than v or the
// largest key smaller than v.
void find(int v) {
    int cur = rt;
    while (T[cur].v != v) {
        const int next = T[cur].ch[v > T[cur].v];
        if (!next) break;
        cur = next;
    }
    Splay(cur);
}

// Returns the node with the largest key strictly less than v, leaving it at the root.
// Relies on such a key existing; main() inserts a -INF sentinel for this purpose.
int findp(int v) {
    find(v);
    if (T[rt].v < v) return rt;  // search already stopped on the predecessor
    // Otherwise the predecessor is the maximum of the root's left subtree.
    int cur = T[rt].ch[0];
    while (T[cur].ch[1]) cur = T[cur].ch[1];
    Splay(cur);
    return cur;
}

// Returns the node with the smallest key strictly greater than v, leaving it at the root.
// Relies on such a key existing; main() inserts a +INF sentinel for this purpose.
int findn(int v) {
    find(v);
    if (T[rt].v > v) return rt;  // search already stopped on the successor
    // Otherwise the successor is the minimum of the root's right subtree.
    int cur = T[rt].ch[1];
    while (T[cur].ch[0]) cur = T[cur].ch[0];
    Splay(cur);
    return cur;
}

void del(int v) {
int pre = findp(v), nxt = findn(v), i;
Splay(pre); Splay(nxt, pre); i = T[nxt].ch[0];
if (T[i].cnt > 1) return --T[i].cnt, Splay(i);
T[nxt].ch[0] = 0; clear(i); update(nxt); update(pre);
}

// Operation 1: insert the next input value.
inline void solve_1() {
    insert(read());
}

// Operation 2: delete one copy of the next input value.
inline void solve_2() {
    del(read());
}

// Operation 3: print the rank of x (1 + number of stored values smaller than x).
// The -INF sentinel sitting in the left part of the tree supplies the "+1".
inline void solve_3() {
    const int x = read();
    find(x);
    int rank;
    if (T[rt].v == x) {
        rank = T[T[rt].ch[0]].sz;
    } else {
        const int p = findp(x);  // p is now the root
        rank = T[T[p].ch[0]].sz + T[p].cnt;
    }
    printf("%d\n", rank);
}

// Operation 4: print the x-th smallest value. The +1 skips the -INF sentinel.
inline void solve_4() {
    const int x = read();
    const int node = findk(x + 1);
    printf("%d\n", T[node].v);
}

// Operation 5: print the largest stored value strictly less than x.
inline void solve_5() {
    const int x = read();
    const int pred = findp(x);
    printf("%d\n", T[pred].v);
}

// Operation 6: print the smallest stored value strictly greater than x.
inline void solve_6() {
    const int x = read();
    const int succ = findn(x);
    printf("%d\n", T[succ].v);
}

// Reads n operations and dispatches each one to its handler by opcode (1..6);
// other opcodes are ignored. The -INF and +INF sentinels are inserted first so
// predecessor, successor and delete always have neighbours on both sides.
int main() {
    n = read();
    insert(-INF);
    insert(INF);
    static void (*const handlers[])() = {
        nullptr, solve_1, solve_2, solve_3, solve_4, solve_5, solve_6
    };
    for (int q = 0; q < n; ++q) {
        const int opt = read();
        if (opt >= 1 && opt <= 6) handlers[opt]();
    }
    return 0;
}