Luogu P3233 [HNOI2014]世界树

题目描述

https://www.luogu.com.cn/problem/P3233

简要题意:给定一棵 $n$ 个节点的树,现在有 $m$ 次询问,每次询问给出 $k_i$ 个关键点,然后将树上的每个点划分到离他最近的关键点,如果距离两个关键点相同,则划分给编号较小的关键点,输出每个关键点被分配到多少点

$n,m,\sum k_i\le 3\times 10^5$

Solution

我们考虑建出关键点的虚树,然后通过两次 $dfs$,求出离每个点最近的关键点

然后我们考虑计算每条边的贡献,如果这条边的两个端点属于同一个关键点,那么这条边一定也属于这个关键点,否则我们二分分界点即可,处理时需要一些技巧

时间复杂度 $O(n\log^2 n)$,如果优化求距离到 $O(1)$,可以做到 $O(n\log n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#include <iostream>
#include <vector>
#include <algorithm>
#include <stack>
#define maxn 300010
#define INF 1000000000
#define ll long long
using namespace std;

int n, m;

vector<int> G[maxn];
struct Edge {
int to, next;
} e[maxn]; int c1, head[maxn];
inline void add_edge(int u, int v) {
e[c1].to = v; e[c1].next = head[u]; head[u] = c1++;
}

int f[maxn][21], dep[maxn], id[maxn], sz[maxn];
void pre(int u, int fa) {
static int cnt = 0;
id[u] = ++cnt; sz[u] = 1;
for (auto v : G[u]) {
if (v == fa) continue;
f[v][0] = u; dep[v] = dep[u] + 1;
pre(v, u); sz[u] += sz[v];
}
}

void init_lca() {
for (int j = 1; j <= 20; ++j)
for (int i = 1; i <= n; ++i) f[i][j] = f[f[i][j - 1]][j - 1];
}

int get_lca(int x, int y) {
int res = INF;
if (dep[x] < dep[y]) swap(x, y);
for (int i = 20; ~i; --i)
if (dep[f[x][i]] >= dep[y]) x = f[x][i];
if (x == y) return x;
for (int i = 20; ~i; --i)
if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
return f[x][0];
}

inline int D(int x, int y) { return dep[x] + dep[y] - 2 * dep[get_lca(x, y)]; }

int w[maxn], sum[maxn], bl[maxn]; bool col[maxn];
void dfs(int u, int fa) {
if (col[u]) bl[u] = u;
for (int i = head[u]; ~i; i = e[i].next) {
int v = e[i].to; if (v == fa) continue;
dfs(v, u);
if (!bl[u] || D(bl[u], u) > D(bl[v], u) || D(bl[u], u) == D(bl[v], u) && bl[u] > bl[v])
bl[u] = bl[v];
}
}

void Dfs(int u, int fa) {
for (int i = head[u]; ~i; i = e[i].next) {
int v = e[i].to; if (v == fa) continue;
if (!bl[v] || D(bl[v], v) > D(bl[u], v) || D(bl[v], v) == D(bl[u], v) && bl[v] > bl[u])
bl[v] = bl[u];
Dfs(v, u);
}
}

void work(int u, int v) {
int son = v;
for (int i = 20; ~i; --i)
if (dep[f[son][i]] > dep[u]) son = f[son][i];
w[u] -= sz[son];
if (bl[u] == bl[v]) return sum[bl[u]] += sz[son] - sz[v], void();
int ans = v;
for (int i = 20; ~i; --i) {
int x = f[ans][i]; if (dep[x] <= dep[u]) continue;
int d1 = D(x, bl[v]), d2 = D(x, bl[u]);
if (d1 < d2 || d1 == d2 && bl[v] < bl[u]) ans = x;
}
sum[bl[u]] += sz[son] - sz[ans];
sum[bl[v]] += sz[ans] - sz[v];
}

void calc(int u, int fa) {
w[u] = sz[u];
for (int i = head[u]; ~i; i = e[i].next) {
int v = e[i].to; if (v == fa) continue;
work(u, v); calc(v, u);
} sum[bl[u]] += w[u];
}

void clear(int u, int fa) {
col[u] = bl[u] = sum[u] = w[u] = 0;
for (int i = head[u]; ~i; i = e[i].next)
if (e[i].to != fa) clear(e[i].to, u);
head[u] = -1;
}

int a[maxn], tp[maxn];
void solve(int n, int rt) {
for (int i = 1; i <= n; ++i) col[a[i]] = 1, tp[i] = a[i];
sort(a + 1, a + n + 1, [](const int &u, const int &v) { return id[u] < id[v]; });
stack<int> S; S.push(rt);
for (int o = 1, u = a[o]; o <= n; u = a[++o]) {
if (u == rt) continue; int lca = get_lca(u, S.top());
while (S.top() != lca) {
int t = S.top(); S.pop();
if (id[S.top()] < id[lca]) S.push(lca);
add_edge(S.top(), t);
} S.push(u);
}
while (S.top() != rt) {
int t = S.top(); S.pop();
add_edge(S.top(), t);
} dfs(rt, 0); Dfs(rt, 0); calc(rt, 0);
for (int i = 1; i <= n; ++i) cout << sum[tp[i]] << " \n"[i == n];
clear(rt, 0), c1 = 0;
}


int main() { fill(head, head + maxn, -1);
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);

cin >> n;
for (int i = 1; i < n; ++i) {
int x, y; cin >> x >> y;
G[x].push_back(y); G[y].push_back(x);
} cin >> m; dep[1] = 1; pre(1, 0); init_lca();
for (int i = 1; i <= m; ++i) {
int len; cin >> len;
for (int j = 1; j <= len; ++j) cin >> a[j];
solve(len, 1);
}
return 0;
}