Luogu P2664 树上游戏

题目描述

https://www.luogu.com.cn/problem/P2664

简要题意:给定一棵 $n$ 个节点的树,每个点有一个颜色 $c_i$,令 $d(i,j)$ 表示 $i$ 到 $j$ 的简单路径上的颜色种数, 令 $f_i=\sum_{j=1}^nd(i,j)$,求所有的 $f_i$

$n,c_i\le 10^5$

Solution

首先我们考虑一种可以求根节点的答案的方法,大概就是维护 $sum_c$ 表示 $c$ 这个颜色贡献,做法如下:

1
2
3
4
5
6
7
8
void dfs(int u, int fa) {
int tmp = sum[c[u]]; sz[u] = 1;
for (int i = head[u]; ~i; i = e[i].next) {
int v = e[i].to; if (v == fa) continue;
dfs(v, u); sz[u] += sz[v];
}
sum[c[u]] = sz[u] + tmp;
}

容易得到 $f_{rt}=\sum sum_c$,我们考虑换根来求其他点的 $f$,容易发现从 $u$ 换到其儿子 $v$ 时,只有 $c_u$ 和 $c_v$ 这两种颜色的贡献会发生变化,具体的,$c_v$ 的贡献会变成 $n$,相当于在之前的基础上减掉以$u$ 为根时除 $v$ 的以外的其他子树内 $c_v$ 的贡献,$c_u$ 的贡献首先应该减去 $sz_v$, 同时要加上 $v$ 的子树内不考虑 $u$ 时 $c_u$ 这个颜色的贡献

对于后者容易发现我们上面那个朴素做法的 $sum$ 与这个东西类似, 我们令其为 $a_v$

1
2
3
4
5
6
7
8
9
10
11
void dfs(int u, int fa) {
int tmp = sum[c[u]];
sum[c[fa]] = 0; sz[u] = 1;
for (int i = head[u]; ~i; i = e[i].next) {
int v = e[i].to; if (v == fa) continue;
dfs(v, u); sz[u] += sz[v];
}
sum[c[u]] = sz[u];
a[u] = sum[c[fa]];
sum[c[u]] += tmp;
}

对于前者,我们直接在换根的时候,同时维护 $sum_c$ 表示以当前点为根 $c$ 的贡献,换根的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
void Dfs(int u, int fa) {
int tmp1 = sum[c[u]], tmp2 = sum[c[fa]];
if (u != 1) {
f[u] = f[fa] - sz[u] + a[u] + n - sum[c[u]];
sum[c[u]] = n; sum[c[fa]] = n - sz[u] + a[u];
}
for (int i = head[u]; ~i; i = e[i].next) {
int v = e[i].to; if (v == fa) continue;
Dfs(v, u);
}
sum[c[u]] = tmp1; sum[c[fa]] = tmp2;
}

时间复杂度 $O(n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
#include <map>
#include <set>
#include <queue>
#define maxn 100010
#define ll long long
#define ull unsigned long long
using namespace std;

int n, m, c[maxn];

struct Edge {
int to, next;
} e[maxn * 2]; int c1, head[maxn];
inline void add_edge(int u, int v) {
e[c1].to = v; e[c1].next = head[u]; head[u] = c1++;
}

int sz[maxn];
ll sum[maxn], a[maxn], f[maxn];
void dfs(int u, int fa) {
int tmp = sum[c[u]];
sum[c[fa]] = 0; sz[u] = 1;
for (int i = head[u]; ~i; i = e[i].next) {
int v = e[i].to; if (v == fa) continue;
dfs(v, u); sz[u] += sz[v];
}
sum[c[u]] = sz[u];
a[u] = sum[c[fa]];
sum[c[u]] += tmp;
}

void Dfs(int u, int fa) {
int tmp1 = sum[c[u]], tmp2 = sum[c[fa]];
if (u != 1) {
f[u] = f[fa] - sz[u] + a[u] + n - sum[c[u]];
sum[c[u]] = n; sum[c[fa]] = n - sz[u] + a[u];
}
for (int i = head[u]; ~i; i = e[i].next) {
int v = e[i].to; if (v == fa) continue;
Dfs(v, u);
}
sum[c[u]] = tmp1; sum[c[fa]] = tmp2;
}

int main() { fill(head, head + maxn, -1);
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);

cin >> n;
for (int i = 1; i <= n; ++i) cin >> c[i];
for (int i = 1; i < n; ++i) {
int x, y; cin >> x >> y;
add_edge(x, y); add_edge(y, x);
} dfs(1, 0);
for (int i = 1; i <= 1e5; ++i) f[1] += sum[i];
Dfs(1, 0);
for (int i = 1; i <= n; ++i) cout << f[i] << "\n";
return 0;
}