Luogu P4175 [CTSC2008]网络管理

题目描述

https://www.luogu.com.cn/problem/P4175

Solution

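树上带修路径第 $k$ 大。将每个点的权值在其 DFS 序子树区间 $[in_u, out_u]$ 上做差分（$in_u$ 处 $+1$，$out_u+1$ 处 $-1$），用树状数组套值域线段树维护，则前缀 $S(in_u)$ 恰为根到 $u$ 路径上的权值集合，路径 $u$–$v$ 上的集合为 $S(in_u)+S(in_v)-S(in_{\mathrm{lca}})-S(in_{fa(\mathrm{lca})})$，在这四组线段树上同时向下二分即可。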

代码（C++）：
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <cstring>
#define maxn 80010
#define logn 18
#define pb push_back
#define lowbit(i) ((i) & (-i))
using namespace std;

// Number of tree nodes and number of operations.
int n;
int m;
// a[v]: value of node v; after init_hash() it holds v's compressed rank in b[].
int a[maxn];

// Adjacency lists kept as singly linked chains inside one edge array.
// head[u] is the index of u's most recently added edge; -1 ends a chain.
// c1 is the number of edges allocated so far.
struct Edge {
    int to;    // endpoint of this directed edge
    int next;  // index of the next edge leaving the same node, or -1
} e[maxn * 2];
int c1, head[maxn];

// Prepends the directed edge u -> v to u's chain.
inline void add_edge(int u, int v) {
    const int id = c1++;
    e[id].to = v;
    e[id].next = head[u];
    head[u] = id;
}

// One offline operation, as normalised by main():
//   opt == 0 : set node x's value to y (y is rank-compressed by init_hash()).
//   opt == 1 : report the z-th largest value on the path between x and y.
// NOTE(review): id is never read in this file.
struct Query {
    int opt;
    int id;
    int x;
    int y;
    int z;
} Q[maxn];

int b[maxn * 2], cnt;
void init_hash() {
for (int i = 1; i <= n; ++i) b[i] = a[i]; int c1 = n;
for (int i = 1; i <= m; ++i)
if (!Q[i].opt) b[++c1] = Q[i].y;
sort(b + 1, b + c1 + 1); cnt = unique(b + 1, b + c1 + 1) - b - 1;
for (int i = 1; i <= n; ++i) a[i] = lower_bound(b + 1, b + cnt + 1, a[i]) - b;
for (int i = 1; i <= m; ++i)
if (!Q[i].opt) Q[i].y = lower_bound(b + 1, b + cnt + 1, Q[i].y) - b;
}

// Node pool for the value-domain segment trees (one tree per Fenwick cell).
// T[0] is never written and acts as the all-zero "empty" node; a child index
// of 0 means that subtree is empty.
struct zhuxi {
    int v;      // how many values fall into this node's range
    int ch[2];  // left / right child indices into T
} T[3 * maxn * logn * logn];
int rt[maxn * 2];  // rt[p]: root of the tree held by Fenwick cell p
int top;           // index of the most recently allocated node in T

// Path-copying point update over the value range [l, r]: stores into i the
// root of a new version equal to the tree rooted at j, except that the count
// at leaf k (and every node above it) is increased by v. One fresh node is
// allocated per level; only fresh nodes and i itself are written.
void update(int &i, int j, int l, int r, int k, int v) {
    int *slot = &i;  // where the next fresh node's index must be stored
    int src = j;     // matching node of the old version (0 if absent)
    for (;;) {
        const int fresh = ++top;
        *slot = fresh;
        T[fresh] = T[src];
        T[fresh].v += v;
        if (l == r) break;
        const int mid = (l + r) >> 1;
        const int side = (k <= mid) ? 0 : 1;
        if (side == 0) r = mid;
        else l = mid + 1;
        slot = &T[fresh].ch[side];
        src = T[src].ch[side];
    }
}

vector<int> a1, a2, b1, b2;
int query(int l, int r, int k) {
int v = 0, m = l + r >> 1; if (l == r) return l;
for (auto u : a1) v += T[T[u].ch[1]].v;
for (auto u : a2) v += T[T[u].ch[1]].v;
for (auto u : b1) v -= T[T[u].ch[1]].v;
for (auto u : b2) v -= T[T[u].ch[1]].v;
if (k <= v) {
for (auto &u : a1) u = T[u].ch[1];
for (auto &u : a2) u = T[u].ch[1];
for (auto &u : b1) u = T[u].ch[1];
for (auto &u : b2) u = T[u].ch[1];
return query(m + 1, r, k);

}
else {
for (auto &u : a1) u = T[u].ch[0];
for (auto &u : a2) u = T[u].ch[0];
for (auto &u : b1) u = T[u].ch[0];
for (auto &u : b2) u = T[u].ch[0];
return query(l, m, k - v);
}
}

void add(int i, int k, int v) {
while (i <= n) update(rt[i], rt[i], 1, cnt, k, v), i += lowbit(i);
}

int dep[maxn], f[maxn][21], in[maxn], out[maxn];
void dfs(int u, int fa) {
static int c1 = 0; in[u] = ++c1;
for (int i = head[u]; ~i; i = e[i].next) {
int v = e[i].to; if (v == fa) continue;
dep[v] = dep[u] + 1; f[v][0] = u; dfs(v, u);
} out[u] = c1;
}

// Fills the binary-lifting table: f[v][level] = f[f[v][level-1]][level-1]
// for levels 1..20. Needs f[v][0] for every node. Row f[0] is never written
// and stays 0, so any jump past the root lands on node 0.
void init_lca() {
    for (int level = 1; level <= 20; ++level) {
        for (int v = 1; v <= n; ++v) {
            const int half = f[v][level - 1];
            f[v][level] = f[half][level - 1];
        }
    }
}

// Lowest common ancestor of x and y by binary lifting over dep[] and f[][].
int get_lca(int x, int y) {
    if (dep[x] < dep[y]) swap(x, y);
    // Raise x to y's depth. A jump to node 0 is never taken because dep[0] is 0
    // while real nodes have depth >= 1 (main sets the root's depth to 1).
    for (int j = 20; j >= 0; --j) {
        const int up = f[x][j];
        if (dep[up] >= dep[y]) x = up;
    }
    if (x == y) return x;
    // Raise both while their ancestors differ; they end just below the LCA.
    for (int j = 20; j >= 0; --j) {
        if (f[x][j] != f[y][j]) {
            x = f[x][j];
            y = f[y][j];
        }
    }
    return f[x][0];
}

// Modification Q[i]: node Q[i].x takes the (compressed) value Q[i].y.
// A node's value is counted on its whole DFS range [in, out] through a
// difference on the Fenwick array: +delta at in, -delta at out + 1.
inline void solve_0(int i) {
    const int node = Q[i].x;
    auto apply = [node](int value, int delta) {
        add(in[node], value, delta);
        add(out[node] + 1, value, -delta);
    };
    apply(a[node], -1);  // withdraw the old value
    a[node] = Q[i].y;
    apply(a[node], 1);   // insert the new value
}

inline void solve_1(int i) {
int l = Q[i].x, r = Q[i].y, k = Q[i].z, p = get_lca(l, r), fp = f[p][0];
if (k > dep[l] + dep[r] - dep[p] - dep[fp]) return (void) puts("invalid request!");
a1.clear(); a2.clear(); b1.clear(); b2.clear();
for (int i = in[l]; i; i -= lowbit(i)) a1.pb(rt[i]);
for (int i = in[r]; i; i -= lowbit(i)) a2.pb(rt[i]);
for (int i = in[p]; i; i -= lowbit(i)) b1.pb(rt[i]);
for (int i = in[fp]; i; i -= lowbit(i)) b2.pb(rt[i]);
printf("%d\n", b[query(1, cnt, k)]);
}

// Reads the tree and all operations, compresses values, builds the DFS order,
// the LCA table and the Fenwick-of-segment-trees, then runs the operations in order.
int main() {
    memset(head, -1, sizeof head);  // -1 terminates every adjacency chain
    // Read everything through C stdio (the original mixed cin and scanf).
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; ++i) scanf("%d", &a[i]);
    for (int i = 1; i < n; ++i) {
        int x, y;
        scanf("%d%d", &x, &y);
        add_edge(x, y);
        add_edge(y, x);
    }
    for (int i = 1; i <= m; ++i) {
        scanf("%d%d%d", &Q[i].opt, &Q[i].x, &Q[i].y);
        // A non-zero first field is the rank k of a query: keep it in z and
        // normalise opt to 1 (0 stays a modification).
        if (Q[i].opt) {
            Q[i].z = Q[i].opt;
            Q[i].opt = 1;
        }
    }

    init_hash();
    dep[1] = 1;  // root at depth 1, so node 0 (depth 0) works as the "above root" sentinel
    dfs(1, 0);
    init_lca();

    // Count every node's value on its DFS subtree range [in, out].
    for (int i = 1; i <= n; ++i) {
        add(in[i], a[i], 1);
        add(out[i] + 1, a[i], -1);
    }

    // The previous version declared an unused local `opt` here.
    for (int i = 1; i <= m; ++i) {
        if (Q[i].opt) solve_1(i);
        else solve_0(i);
    }
    return 0;
}