Luogu P6329 【模板】点分树 | 震波

题目描述

https://www.luogu.com.cn/problem/P6329

简要题意:给定一棵 $n$ 个点的树,边权均为 $1$,每个点有一个点权 $a_i$。共有 $m$ 次操作:操作 `0 x y` 询问与点 $x$ 距离不超过 $y$ 的所有点的权值和;操作 `1 x y` 将点 $x$ 的权值修改为 $y$。强制在线:每次读入的 $x,y$ 都要先异或上一次询问的答案(初始为 $0$)。

$n,m\le 10^5$

Solution

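点分树模板题。对点分树上的每个点 $u$,把它所在连通块内的点按到 $u$ 的距离分组,询问“距离不超过 $k$”就是一个关于距离的前缀和,因此对每个点开一个以距离为下标的树状数组即可;为了扣除从子连通块重复统计的部分,还要对每个点再开一个以“到它在点分树上父亲的距离”为下标的树状数组。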

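总的空间复杂度是 $O(n\log n)$ 的:每个点的树状数组大小只需开到它在点分树上对应连通块的大小,而点分树的深度为 $O(\log n)$,所以所有连通块大小之和是 $O(n\log n)$ 的。每次修改或询问要跳 $O(\log n)$ 个点分树祖先,每个祖先上做一次 $O(\log n)$ 的树状数组操作(两点距离用 ST 表 $O(1)$ 求 LCA 得到),所以总的时间复杂度为 $O((n+m)\log^2 n)$。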

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#include <algorithm>
#include <cstdio>
#include <iostream>
#define maxn 100010
#define INF 1000000000
#define lowbit(i) ((i) & (-i))
using namespace std;

// Vertex count, operation count, and current vertex weights (1-indexed).
int n, m;
int a[maxn];

// Adjacency lists stored as singly linked edge chains: head[u] is the index of
// the last edge added out of u, and e[i].next links to the previous one.
// A value of -1 terminates a chain (main() fills head with -1 up front).
struct Edge {
    int to;    // target vertex
    int next;  // index of the next edge out of the same vertex, or -1
};
Edge e[maxn * 2];
int c1;          // number of edges stored so far == next free slot in e
int head[maxn];

// Append the directed edge u -> v to u's chain.
inline void add_edge(int u, int v) {
    const int slot = c1++;
    e[slot] = Edge{v, head[u]};
    head[u] = slot;
}

// R[c]: largest distance from centroid c to any vertex of the component in
// which c was chosen (written by Calc_sz::get_rt).
// vis[u]: set once u has been used as a centroid; all later traversals treat
// such vertices as walls.
int R[maxn];
bool vis[maxn];

// Finds the centroid of the still-unvisited component that contains a vertex.
struct Calc_sz {
    int f[maxn];   // f[u]: size of the largest piece left after removing u
    int sz[maxn];  // sz[u]: subtree size of u in the current rooted traversal
    int sum;       // size of the component currently being processed
    int rt;        // best centroid candidate so far; 0 is a sentinel (f[0] = INF)
    int maxdp;     // deepest distance seen by the depth-measuring traversal

    // Call once before any get_rt(): vertex 0 becomes a sentinel that every
    // real vertex beats.
    void init() {
        rt = 0;
        f[0] = INF;
    }

    // Fill sz[] for the unvisited component hanging below u.
    void dfs_sz(int u, int fa) {
        int total = 1;
        for (int j = head[u]; j != -1; j = e[j].next) {
            const int w = e[j].to;
            if (w != fa && !vis[w]) {
                dfs_sz(w, u);
                total += sz[w];
            }
        }
        sz[u] = total;
    }

    // Post-order pass: compute f[u] and keep in rt the vertex with the
    // smallest f (on ties, the vertex finished first wins).
    void dfs(int u, int fa) {
        int heaviest = 0;
        for (int j = head[u]; j != -1; j = e[j].next) {
            const int w = e[j].to;
            if (w == fa || vis[w]) continue;
            dfs(w, u);
            heaviest = max(heaviest, sz[w]);
        }
        f[u] = max(heaviest, sum - sz[u]);
        if (f[u] < f[rt]) rt = u;
    }

    // Record in maxdp the greatest depth reached below u, where u sits at depth d.
    void dfs(int u, int fa, int d) {
        if (d > maxdp) maxdp = d;
        for (int j = head[u]; j != -1; j = e[j].next) {
            const int w = e[j].to;
            if (w != fa && !vis[w]) dfs(w, u, d + 1);
        }
    }

    // Return the centroid of the unvisited component containing u. As side
    // effects, sum holds that component's size and R[centroid] its radius.
    inline int get_rt(int u) {
        rt = 0;
        maxdp = 0;
        dfs_sz(u, 0);
        sum = sz[u];
        dfs(u, 0);
        dfs(rt, 0, 0);
        R[rt] = maxdp;
        return rt;
    }
} _;

// Euler tour of the whole tree:
//   id[1..2n-1] is the tour sequence,
//   in[u] is the position where u first appears,
//   dep[u] is the depth of u (the caller seeds the root's depth).
int id[maxn * 2];
int in[maxn];
int dep[maxn];

void dfs(int u, int fa) {
    static int cnt = 0;  // tour length so far; persists across calls
    in[u] = ++cnt;
    id[cnt] = u;
    for (int j = head[u]; j != -1; j = e[j].next) {
        const int child = e[j].to;
        if (child != fa) {
            dep[child] = dep[u] + 1;
            dfs(child, u);
            id[++cnt] = u;  // u is re-emitted after returning from each child
        }
    }
}

// Of two vertices, return the one entered earlier by the Euler-tour dfs
// (r on ties). Over a tour range, the earliest-entered vertex is the LCA.
inline int st_min(int l, int r) {
    if (in[r] <= in[l]) return r;
    return l;
}

// Number of sparse-table levels; st[i][j] is valid only for j < ST_LOG.
const int ST_LOG = 20;

// Sparse table over the Euler tour id[1..n]:
//   st[i][j] = vertex with the smallest entry time among id[i .. i + 2^j - 1],
//   Log[x]   = floor(log2(x)), with Log[0] = -1.
int st[maxn * 2][ST_LOG], Log[maxn * 2];

// Build Log[] and st[][] for a tour of length n (this n shadows the global n).
void init_st(int n) {
    Log[0] = -1;
    for (int i = 1; i <= n; ++i) Log[i] = Log[i >> 1] + 1;
    for (int i = 1; i <= n; ++i) st[i][0] = id[i];
    // Build only levels whose window fits inside [1, n], and never touch
    // st[.][ST_LOG]: the old bound `j <= 20` indexed one column past the array
    // and was safe only because the inner loop happened to be empty there.
    for (int j = 1; j < ST_LOG && (1 << j) <= n; ++j)
        for (int i = 1; i + (1 << j) - 1 <= n; ++i)
            st[i][j] = st_min(st[i][j - 1], st[i + (1 << (j - 1))][j - 1]);
}

// LCA of vertices x and y, via an O(1) range query on the Euler-tour sparse table.
inline int get_lca(int x, int y) {
    int lo = in[x];
    int hi = in[y];
    if (lo > hi) swap(lo, hi);
    const int k = Log[hi - lo + 1];
    const int left = st[lo][k];
    const int right = st[hi - (1 << k) + 1][k];
    return st_min(left, right);
}

// Number of edges on the tree path between x and y.
inline int D(int x, int y) {
    const int lca = get_lca(x, y);
    return dep[x] + dep[y] - 2 * dep[lca];
}

// Fenwick tree indexed by a 0-based distance. Storage comes from one shared,
// zero-initialised static pool of maxn * 60 ints: each init() call takes the
// next slice. The pool is never reclaimed and is not bounds-checked.
struct Bit {
    int *Bit;  // 1-based Fenwick array Bit[1..n]
    int n;     // slot count; distances 0..n-1 are representable

    // Take a fresh slice of len + 1 ints from the pool.
    void init(int len) {
        static int P[maxn * 60];
        static int *ret = P;
        n = len;
        Bit = ret;
        ret += n + 1;
    }

    // Add v at distance i (silently ignored when i + 1 > n).
    void add(int i, int v) {
        for (int p = i + 1; p <= n; p += lowbit(p)) Bit[p] += v;
    }

    // Sum of the values at distances 0..i; i past the end is clamped.
    int get_sum(int i) {
        int total = 0;
        for (int p = min(i + 1, n); p != 0; p -= lowbit(p)) total += Bit[p];
        return total;
    }
} Bit[maxn * 2];

int fa[maxn];
void divide(int u) {
vis[u] = 1; Bit[u].init(R[u] + 1);
for (int i = head[u]; ~i; i = e[i].next) {
int v = e[i].to; if (vis[v]) continue;
v = _.get_rt(v); fa[v] = u; Bit[v + n].init(R[u] + 1);
divide(v);
}
}

// Add v to the weight of vertex u. The change is recorded at distance 0 in
// Bit[u]; then, for every centroid ancestor p, at distance D(p, u) both in
// Bit[p] and in the child-side tree of the centroid directly below p.
void update(int u, int v) {
    const int origin = u;
    Bit[origin].add(0, v);
    for (int cur = u; fa[cur] != 0; cur = fa[cur]) {
        const int up = fa[cur];
        const int dist = D(up, origin);
        Bit[up].add(dist, v);
        Bit[cur + n].add(dist, v);
    }
}

// Sum of the weights of all vertices within distance k of u. At each centroid
// ancestor, the part coming back through the child just walked from is removed.
int query(int u, int k) {
    const int origin = u;
    int total = Bit[origin].get_sum(k);
    for (int cur = u; fa[cur] != 0; cur = fa[cur]) {
        const int up = fa[cur];
        const int dist = D(up, origin);
        if (dist <= k) {
            const int rest = k - dist;
            total += Bit[up].get_sum(rest);
            total -= Bit[cur + n].get_sum(rest);
        }
    }
    return total;
}

// Answer to the most recent query; online inputs are XOR-decoded with it.
int lans;

// Read an encoded "x y" query and print the weight sum within distance y of x.
inline void solve_1() {
    int x, y;
    cin >> x >> y;
    x ^= lans;
    y ^= lans;
    lans = query(x, y);
    cout << lans << "\n";
}

// Read an encoded "x y" modification and set the weight of vertex x to y.
inline void solve_2() {
    int x, y;
    cin >> x >> y;
    x ^= lans;
    y ^= lans;
    const int delta = y - a[x];
    update(x, delta);
    a[x] = y;
}

int main() {
    fill(head, head + maxn, -1);  // every adjacency chain starts empty
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    cin >> n >> m;
    for (int v = 1; v <= n; ++v) cin >> a[v];
    for (int k = 1; k < n; ++k) {
        int x, y;
        cin >> x >> y;
        add_edge(x, y);
        add_edge(y, x);
    }

    // Euler tour + sparse table give O(1) LCA, hence O(1) tree distance.
    dep[1] = 1;
    dfs(1, 0);
    init_st(2 * n - 1);

    // Build the centroid tree, then insert the initial weights.
    _.init();
    divide(_.get_rt(1));
    for (int v = 1; v <= n; ++v) update(v, a[v]);

    // Operation 0 is a query; anything else is a modification.
    for (int q = 0; q < m; ++q) {
        int opt;
        cin >> opt;
        opt == 0 ? solve_1() : solve_2();
    }
    return 0;
}