/*
 * Insert every node at positions in[u]..out[u] into the color-frequency
 * tables.
 *
 * Tables updated here (and kept consistent by del() and dfs()):
 *   cnt1[c] - number of inserted nodes whose color is c
 *   cnt2[k] - number of colors that currently occur exactly k times
 *   sum[k]  - sum of the color values that occur exactly k times
 *   mx      - largest occurrence count seen; raised here, lowered by del()
 *
 * NOTE(review): assumes in[]/out[] are DFS-order bounds so this range is
 * exactly u's subtree, and that bl[i] is the node at DFS position i - TODO
 * confirm against the code that fills them.
 */
void add(int u)
{
    for (int i = in[u]; i <= out[u]; ++i) {
        int c = col[bl[i]];
        int cnt = ++cnt1[c];    /* occurrence count of color c after insert */

        if (cnt > mx) {
            mx = cnt;
        }

        /* Move color c from bucket cnt-1 to bucket cnt. For cnt == 1 this
         * touches bucket 0, which dfs() never reads (mx >= 1 there). */
        ++cnt2[cnt];
        sum[cnt] += c;
        --cnt2[cnt - 1];
        sum[cnt - 1] -= c;
    }
}
/*
 * Remove every node at positions in[u]..out[u] from the color-frequency
 * tables (cnt1 / cnt2 / sum / mx), reversing the updates add(u) makes.
 *
 * NOTE(review): assumes the nodes in this range are currently inserted;
 * otherwise cnt1[] would go negative - TODO confirm at call sites.
 */
void del(int u)
{
    for (int i = in[u]; i <= out[u]; ++i) {
        int c = col[bl[i]];
        int cnt = --cnt1[c];    /* occurrence count of color c after removal */

        /* Color c leaves bucket cnt+1. If that was the top bucket and it is
         * now empty, the new maximum is cnt, because c itself lands in bucket
         * cnt just below. A single decrement of mx is therefore enough. */
        if (--cnt2[cnt + 1] == 0 && cnt + 1 == mx) {
            --mx;
        }
        sum[cnt + 1] -= c;
        ++cnt2[cnt];
        sum[cnt] += c;
    }
}
ll ans[maxn]; voiddfs(int u, int fa){ for (int i = head[u]; ~i; i = e[i].next) { int v = e[i].to; if (v == fa || v == son[u]) continue ; dfs(v, u); del(v); } if (son[u]) dfs(son[u], u); for (int i = head[u]; ~i; i = e[i].next) { int v = e[i].to; if (v == fa || v == son[u]) continue; add(v); } int c = col[u], cnt = ++cnt1[c]; if (cnt > mx) mx = cnt; ++cnt2[cnt]; sum[cnt] += c; cnt2[cnt - 1]--; sum[cnt - 1] -= c; ans[u] = sum[mx]; }