CF 856D Masha and Cactus

题目描述

https://www.luogu.com.cn/problem/CF856D

简要题意:给定一棵 $n$ 个点的树以及 $m$ 条树上的链,每条链有一个价值,要求选择若干条链,使得每个点至多被包含在一个链中,且价值最大

$n,m\le 2\times 10^5$

Solution

我们将每条链扔到 $lca$ 上,然后做树形 $dp$

我们令 $f[u]$ 表示 $u$ 子树内已经选择完的最大价值,$g[u]=\sum f[v]$

如果 $u$ 点没有链,那么 $f[u]=g[u]$

如果 $u$​ 点有一条链,那么 $f[u]$​ 应该取 $g[u]+\sum g[v]-f[v]$​,其中 $v$​ 是这条链上的点

这个东西做一个差分后相当于区间加和单点查询,可以用树状数组实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include <iostream>
#include <vector>
#include <tuple>
#define maxn 200010
#define lowbit(i) ((i) & (-i))
using namespace std;

int n, m;

struct Edge {
int to, next;
} e[maxn * 2]; int c1, head[maxn];
inline void add_edge(int u, int v) {
e[c1].to = v; e[c1].next = head[u]; head[u] = c1++;
}

int in[maxn], out[maxn], bl[maxn], dep[maxn], F[maxn][21];
void Dfs(int u, int fa) {
static int cnt = 0;
in[u] = ++cnt; bl[cnt] = u;
for (int i = head[u]; ~i; i = e[i].next) {
int v = e[i].to; if (v == fa) continue;
dep[v] = dep[u] + 1; F[v][0] = u; Dfs(v, u);
} out[u] = cnt;
}

void init_lca() {
for (int j = 1; j <= 20; ++j)
for (int i = 1; i <= n; ++i) F[i][j] = F[F[i][j - 1]][j - 1];
}

int get_lca(int x, int y) {
if (dep[x] < dep[y]) swap(x, y);
for (int i = 20; ~i; --i)
if (dep[F[x][i]] >= dep[y]) x = F[x][i];
if (x == y) return x;
for (int i = 20; ~i; --i)
if (F[x][i] != F[y][i])
x = F[x][i], y = F[y][i];
return F[x][0];
}

vector<tuple<int, int, int>> A[maxn];

int Bit[maxn][2];
void add(int i, int v, int o) { while (i <= n) Bit[i][o] += v, i += lowbit(i); }

void update(int l, int r, int v, int o) { add(l, v, o); add(r + 1, -v, o); }

int get_sum(int i, int o) {
int s = 0;
while (i) s += Bit[i][o], i -= lowbit(i);
return s;
}

int f[maxn], g[maxn];
void dfs(int u, int fa) {
for (int i = head[u]; ~i; i = e[i].next) {
int v = e[i].to; if (v == fa) continue;
dfs(v, u); g[u] += f[v];
} f[u] = g[u];
for (auto t : A[u]) {
int x, y, w; tie(x, y, w) = t;
f[u] = max(f[u], g[u] + get_sum(in[x], 1) + get_sum(in[y], 1) - get_sum(in[x], 0) - get_sum(in[y], 0) + w);
}
update(in[u], out[u], f[u], 0); update(in[u], out[u], g[u], 1);
}

int main() { fill(head, head + maxn, -1);
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);

cin >> n >> m;
for (int i = 2; i <= n; ++i) {
int x; cin >> x;
add_edge(x, i);
} dep[1] = 1; Dfs(1, 0); init_lca();
for (int i = 1; i <= m; ++i) {
int x, y, z; cin >> x >> y >> z;
A[get_lca(x, y)].emplace_back(x, y, z);
} dfs(1, 0);
cout << f[1] << "\n";
return 0;
}