Luogu P5311 [Ynoi2011] 成都七中

题目描述

https://www.luogu.com.cn/problem/P5311

简要题意:给定一棵 $n$ 个点的无根树,每个点有一个颜色 $c_i$,现在有 $m$ 次询问,每次询问给定 $[l,r]$ 和 $x$,问仅保留编号在 $[l,r]$ 内的点和它们之间的边的话,$x$ 所在的连通块有多少种不同的颜色

$n,m\le 10^5,l\le x\le r$

Solution

根据点分树的性质,对于每一个连通块都存在一个点分树上深度最小的点,满足点分过程中该点作为根,且该点的子树内包含该连通块中的所有点

那么我们可以将询问挂到对应的点分树上,具体来说,我们先把询问挂到 $x$ 上,在点分的过程中,我们每次点分到 $x$ 都一路跳父亲到根,判断 $x$ 到根的路径上是否包含的点都属于 $[l,r]$,如果都属于就说明当前的根就是这个连通块在点分树上深度最小的那个点

然后我们考虑如何回答询问,对于一个点分树中的点,每个点都应视为一个区间 $[l,r]$,$l$ 和 $r$ 分别是该点到根的路径上的点的编号的最小值和最大值,我们考虑将询问和点都按照右端点离线,那么现在相当于每次加入一个颜色线段,然后查询以 $l$ 为左端点包含多少种颜色线段,这个东西的维护方式就是对于每种颜色线段我们只记录左端点最大的那个,可以使用树状数组实现

时间复杂度 $O(n\log^2 n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#include <iostream>
#include <vector>
#include <algorithm>
#include <tuple>
#define maxn 100010
#define lowbit(i) ((i) & (-i))
using namespace std;

struct Edge {
int to, next, w;
} e[maxn * 2]; int c1, head[maxn];
inline void add_edge(int u, int v) {
e[c1].to = v; e[c1].next = head[u]; head[u] = c1++;
}

bool vis[maxn]; int mxd[maxn];
namespace DF {
int f[maxn], sz[maxn];
int sum, rt, maxdp;

inline void init(int n) { f[rt = 0] = n + 1; fill(vis + 1, vis + n + 1, false); }
void dfs_sz(int u, int fa) {
sz[u] = 1;
for (int i = head[u]; ~i; i = e[i].next) {
int v = e[i].to; if (v == fa || vis[v]) continue;
dfs_sz(v, u); sz[u] += sz[v];
}
}
void dfs_maxdp(int u, int fa, int d) {
maxdp = max(maxdp, d);
for (int i = head[u]; ~i; i = e[i].next) {
int v = e[i].to; if (v == fa || vis[v]) continue;
dfs_maxdp(v, u, d + 1);
}
}
void dfs(int u, int fa) {
f[u] = 0;
for (int i = head[u]; ~i; i = e[i].next) {
int v = e[i].to; if (v == fa || vis[v]) continue;
dfs(v, u); f[u] = max(f[u], sz[v]);
} f[u] = max(f[u], sum - sz[u]);
if (f[u] < f[rt]) rt = u;
}
inline int get_rt(int u) {
rt = maxdp = 0; dfs_sz(u, 0); sum = sz[u];
dfs(u, 0); dfs_maxdp(rt, 0, 1); mxd[rt] = maxdp; return rt;
}
};

int n, m, a[maxn];

int Bit[maxn];
void add(int i, int v) { while (i) Bit[i] += v, i -= lowbit(i); }
void clear(int i) { while (i) Bit[i] = 0, i -= lowbit(i); }
int get_sum(int i) {
int s = 0;
while (i <= n) s += Bit[i], i += lowbit(i);
return s;
}

vector<tuple<int, int, int>> A[maxn], Q, vec;
int ans[maxn], last[maxn];
void dfs(int u, int fa, int mn, int mx) {
for (auto [l, r, id] : A[u])
if (l <= mn && mx <= r && !ans[id]) Q.emplace_back(l, r, id);
vec.emplace_back(mn, mx, u);
for (int i = head[u]; ~i; i = e[i].next) {
int v = e[i].to; if (v == fa || vis[v]) continue;
dfs(v, u, min(mn, v), max(mx, v));
}
}

inline bool cmp(const tuple<int, int, int> &u, const tuple<int, int, int> &v) {
return get<1>(u) < get<1>(v);
}

void solve(int u) {
vec.clear(); Q.clear(); vec.emplace_back(u, u, u);
for (auto [l, r, id] : A[u])
if (!ans[id]) Q.emplace_back(l, r, id);
for (int i = head[u]; ~i; i = e[i].next) {
int v = e[i].to; if (vis[v]) continue;
dfs(v, u, min(u, v), max(u, v));
}
sort(vec.begin(), vec.end(), cmp); sort(Q.begin(), Q.end(), cmp);
auto it = vec.begin();
for (auto [l, r, id] : Q) {
while (it != vec.end() && get<1>(*it) <= r) {
auto [l, r, k] = *it;
if (last[a[k]] < l) add(last[a[k]], -1), add(l, 1), last[a[k]] = l;
it = next(it);
} ans[id] = get_sum(l);
}
for (auto [l, r, k] : vec) last[a[k]] = 0, clear(l);
}

void divide(int u) {
vis[u] = 1; solve(u);
for (int i = head[u]; ~i; i = e[i].next) {
int v = e[i].to; if (vis[v]) continue;
v = DF::get_rt(v); divide(v);
}
}

int main() { fill(head, head + maxn, -1);
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);

cin >> n >> m;
for (int i = 1; i <= n; ++i) cin >> a[i];
for (int i = 1; i < n; ++i) {
int x, y; cin >> x >> y;
add_edge(x, y); add_edge(y, x);
}
for (int i = 1; i <= m; ++i) {
int x, y, z; cin >> x >> y >> z;
A[z].emplace_back(x, y, i);
}
DF::init(n); divide(DF::get_rt(1));
for (int i = 1; i <= m; ++i) cout << ans[i] << "\n";
return 0;
}