// NOTE(review): draft of the first DFS pass, left ahead of the #include lines
// and of the globals it uses (sum, c, sz, head, e). Compiled at this position
// it fails (names not yet declared) and it would also clash with the complete
// `void dfs(int, int)` defined later in this file, so it is kept for reference
// only and excluded from the build.
//
// What this draft computes: sz[u] = size of u's subtree, and, once the whole
// traversal finishes, sum[col] = number of vertices whose path from the
// starting root contains colour col (on exit from u, sum[c[u]] is reset to
// sz[u] plus its value on entry, discarding same-coloured descendants).
// Unlike the later version it does not fill a[].
#if 0
void dfs(int u, int fa) {
    int tmp = sum[c[u]];
    sz[u] = 1;
    for (int i = head[u]; ~i; i = e[i].next) {
        int v = e[i].to;
        if (v == fa) continue;
        dfs(v, u);
        sz[u] += sz[v];
    }
    sum[c[u]] = sz[u] + tmp;
}
#endif
// NOTE(review): verbatim copy of the `void dfs(int, int)` defined later in
// this file, pasted ahead of the #include lines and the globals it uses.
// Compiled here it fails (undeclared sum, c, sz, head, e, a) and would be a
// redefinition of that later function, so it is kept for reference only and
// excluded from the build.
#if 0
void dfs(int u, int fa) {
    int tmp = sum[c[u]];
    sum[c[fa]] = 0;
    sz[u] = 1;
    for (int i = head[u]; ~i; i = e[i].next) {
        int v = e[i].to;
        if (v == fa) continue;
        dfs(v, u);
        sz[u] += sz[v];
    }
    sum[c[u]] = sz[u];
    a[u] = sum[c[fa]];
    sum[c[u]] += tmp;
}
#endif
// NOTE(review): verbatim copy of the `void Dfs(int, int)` defined later in
// this file, pasted ahead of the #include lines and the globals it uses.
// Compiled here it fails (undeclared sum, c, f, sz, a, n, head, e) and would
// be a redefinition of that later function, so it is kept for reference only
// and excluded from the build.
#if 0
void Dfs(int u, int fa) {
    int tmp1 = sum[c[u]], tmp2 = sum[c[fa]];
    if (u != 1) {
        f[u] = f[fa] - sz[u] + a[u] + n - sum[c[u]];
        sum[c[u]] = n;
        sum[c[fa]] = n - sz[u] + a[u];
    }
    for (int i = head[u]; ~i; i = e[i].next) {
        int v = e[i].to;
        if (v == fa) continue;
        Dfs(v, u);
    }
    sum[c[u]] = tmp1;
    sum[c[fa]] = tmp2;
}
#endif
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
#include <map>
#include <set>
#include <queue>

// Capacity of the vertex-indexed and colour-indexed arrays: vertices and
// colours are used directly as array indices elsewhere in this file.
constexpr int maxn = 100010;
using ll = long long;
using ull = unsigned long long;

using namespace std;
int n;        // number of vertices, read by main
int m;        // NOTE(review): not referenced anywhere in the code shown
int c[maxn];  // c[i]: colour of vertex i (main reads c[1..n]); c[0] keeps its
              // zero initial value and acts as the colour of the sentinel
              // parent 0 passed to the root calls
// Adjacency storage: one array-backed singly linked list per vertex.
// head[u] is the index in e[] of the most recently added edge leaving u, and
// each edge's `next` holds whatever head[u] was before that edge was added.
// main fills head[] with -1, so -1 terminates every list.
struct Edge {
    int to;    // vertex this edge leads to
    int next;  // index of the previous edge leaving the same vertex, or -1
};
Edge e[maxn * 2];  // main stores each tree edge in both directions
int c1;            // number of edges stored so far = next free slot in e[]
int head[maxn];

// Appends the directed edge u -> v and makes it the new head of u's list.
inline void add_edge(int u, int v) {
    const int slot = c1++;
    e[slot].to = v;
    e[slot].next = head[u];
    head[u] = slot;
}
int sz[maxn]; ll sum[maxn], a[maxn], f[maxn]; voiddfs(int u, int fa){ int tmp = sum[c[u]]; sum[c[fa]] = 0; sz[u] = 1; for (int i = head[u]; ~i; i = e[i].next) { int v = e[i].to; if (v == fa) continue; dfs(v, u); sz[u] += sz[v]; } sum[c[u]] = sz[u]; a[u] = sum[c[fa]]; sum[c[u]] += tmp; }
voidDfs(int u, int fa){ int tmp1 = sum[c[u]], tmp2 = sum[c[fa]]; if (u != 1) { f[u] = f[fa] - sz[u] + a[u] + n - sum[c[u]]; sum[c[u]] = n; sum[c[fa]] = n - sz[u] + a[u]; } for (int i = head[u]; ~i; i = e[i].next) { int v = e[i].to; if (v == fa) continue; Dfs(v, u); } sum[c[u]] = tmp1; sum[c[fa]] = tmp2; }
intmain(){ fill(head, head + maxn, -1); ios::sync_with_stdio(false); cin.tie(nullptr); cout.tie(nullptr);
cin >> n; for (int i = 1; i <= n; ++i) cin >> c[i]; for (int i = 1; i < n; ++i) { int x, y; cin >> x >> y; add_edge(x, y); add_edge(y, x); } dfs(1, 0); for (int i = 1; i <= 1e5; ++i) f[1] += sum[i]; Dfs(1, 0); for (int i = 1; i <= n; ++i) cout << f[i] << "\n"; return0; }