Luogu P2495 [SDOI2011]消耗战

题目描述

https://www.luogu.com.cn/problem/P2495

简要题意:给定一棵以 $1$ 为根的有根树,每条边带有边权(不一定为 $1$)。现有 $m$ 次询问,每次询问给定一个点集 $S$,保证 $1\notin S$。要求断掉一些边,使得 $S$ 中任意一个点都不与 $1$ 连通,代价为断掉的边的边权之和,求最小代价。

$n\le 2.5\times 10^5$,$m,\sum |S|\le 5\times 10^5$

Solution

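虚树模板题。对每次询问,用关键点、它们两两的 LCA 以及根 $1$ 建出虚树,虚树边 $(u,v)$ 的权值 $w(u,v)$ 取原树上 $u$ 到 $v$ 路径上的最小边权(断掉这条路径上最便宜的边即可)。设 $dp_u$ 为使 $u$ 子树内所有关键点都与 $u$ 断开的最小代价,则对 $u$ 的每个虚树儿子 $v$:若 $v$ 是关键点,必须断掉这条路径,贡献 $w(u,v)$;否则贡献 $\min(dp_v, w(u,v))$。答案即为 $dp_1$。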

#include <iostream>
#include <vector>
#include <algorithm>
#include <stack>
// Capacity of every per-node array (nodes are 1-indexed, n <= 2.5e5).
constexpr int maxn = 250010;
// Identity value for path-minimum queries; larger than any edge weight.
constexpr int INF = 1000000000;
using ll = long long;
using namespace std;

int n, m;  // n: number of tree nodes, m: number of queries

// Original tree: G[u] holds (neighbor, edge weight) pairs.
vector<pair<int, int>> G[maxn];

// Virtual-tree edges in a forward-star list: head[u] is the index of u's most
// recently added edge, e[i].next links to the previous one, and -1 ends the
// list (main() fills head with -1). c1 is the number of edges in use.
struct Edge {
    int to, next, w;
} e[maxn];
int c1, head[maxn];

// Prepends the directed edge u -> v of weight w to u's edge list.
inline void add_edge(int u, int v, int w) {
    e[c1].to = v;
    e[c1].w = w;
    e[c1].next = head[u];
    head[u] = c1++;
}

int id[maxn], f[maxn][21], mn[maxn][21], dep[maxn];
void pre(int u, int fa) {
static int cnt = 0;
id[u] = ++cnt;
for (auto [v, w] : G[u]) {
if (v == fa) continue;
f[v][0] = u; mn[v][0] = w;
dep[v] = dep[u] + 1; pre(v, u);
}
}

// Builds levels 1..20 of the binary-lifting tables for nodes 1..n from the
// level-0 entries: a 2^lvl jump is two consecutive 2^(lvl-1) jumps through
// `mid`, and its path minimum is the smaller of the two halves' minima.
void init_lca() {
    for (int lvl = 1; lvl < 21; ++lvl) {
        for (int node = 1; node <= n; ++node) {
            const int mid = f[node][lvl - 1];
            f[node][lvl] = f[mid][lvl - 1];
            mn[node][lvl] = min(mn[node][lvl - 1], mn[mid][lvl - 1]);
        }
    }
}

// Returns (lowest common ancestor of x and y, minimum edge weight on the tree
// path between x and y). When x == y the path is empty and the minimum is INF.
// Jumps whose target has depth 0 (past the root) are never taken, so the
// unused mn entries of the root are never read.
pair<int, int> get_lca(int x, int y) {
    int best = INF;
    if (dep[x] < dep[y]) swap(x, y);
    // Lift the deeper node x up to y's depth.
    for (int k = 20; k >= 0; --k) {
        const int up = f[x][k];
        if (dep[up] < dep[y]) continue;
        best = min(best, mn[x][k]);
        x = up;
    }
    if (x == y) return {x, best};
    // Lift both while their ancestors differ; they end as children of the LCA.
    for (int k = 20; k >= 0; --k) {
        if (f[x][k] == f[y][k]) continue;
        best = min(best, min(mn[x][k], mn[y][k]));
        x = f[x][k];
        y = f[y][k];
    }
    best = min(best, min(mn[x][0], mn[y][0]));
    return {f[x][0], best};
}

bool col[maxn]; ll dp[maxn];
void dfs(int u, int fa) {
for (int i = head[u]; ~i; i = e[i].next) {
int v = e[i].to, w = e[i].w; if (v == fa) continue;
dfs(v, u);
if (col[v]) dp[u] += w;
else dp[u] += min(dp[v], 1ll * w);
}
}

void clear(int u, int fa) {
col[u] = dp[u] = 0;
for (int i = head[u]; ~i; i = e[i].next)
if (e[i].to != fa) clear(e[i].to, u);
head[u] = -1;
}

int a[maxn];  // key nodes of the current query, stored in a[1..len]

// Answers one query whose key nodes are a[1..n] (this n is the query size and
// shadows the global tree size). Builds the virtual tree of rt and the key
// nodes, then prints dp[rt] and resets all touched state.
// Each virtual edge parent -> child gets the minimum original edge weight on
// that tree path (get_lca(...).second), because cutting the cheapest edge on
// the path is enough.
void solve(int n, int rt) {
    for (int i = 1; i <= n; ++i) col[a[i]] = 1;
    sort(a + 1, a + n + 1, [](const int &u, const int &v) { return id[u] < id[v]; });
    // S holds the current chain from rt down to the last inserted node.
    stack<int> S;
    S.push(rt);
    // Plain index loop: the previous `u = a[++o]` header also read a[n+1].
    for (int o = 1; o <= n; ++o) {
        const int u = a[o];
        if (u == rt) continue;
        const int lca = get_lca(u, S.top()).first;
        // Pop the part of the chain that is not an ancestor of u, linking each
        // popped node to its virtual parent.
        while (S.top() != lca) {
            const int t = S.top();
            S.pop();
            // lca lies strictly between the new top and t: splice it into the chain.
            if (id[S.top()] < id[lca]) S.push(lca);
            add_edge(S.top(), t, get_lca(S.top(), t).second);
        }
        S.push(u);
    }
    // Link whatever remains on the chain back up to rt.
    while (S.top() != rt) {
        const int t = S.top();
        S.pop();
        add_edge(S.top(), t, get_lca(S.top(), t).second);
    }
    dfs(rt, 0);
    cout << dp[rt] << "\n";
    clear(rt, 0);
    c1 = 0;
}


// Reads the weighted tree rooted at 1 and then m queries from stdin. Prints
// one answer per query.
int main() {
    fill(head, head + maxn, -1);  // every virtual-tree edge list starts empty
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    cin >> n;
    for (int edge = 1; edge < n; ++edge) {
        int x, y, z;
        cin >> x >> y >> z;
        // Undirected edge x - y with weight z.
        G[x].emplace_back(y, z);
        G[y].emplace_back(x, z);
    }
    cin >> m;

    // Root depth 1 keeps dep[0] == 0 below every real node, which get_lca's
    // depth comparison relies on.
    dep[1] = 1;
    pre(1, 0);
    init_lca();

    for (int q = 0; q < m; ++q) {
        int len;
        cin >> len;
        for (int j = 1; j <= len; ++j) cin >> a[j];
        solve(len, 1);
    }
    return 0;
}