2020-2021 “Orz Panda” Cup Programming Contest G Gery's Problem and Orz Pandas(树链剖分)

题目描述

https://codeforces.com/gym/102870/problem/G

Solution

我们令 $w[v]=sz[u]-sz[v]$(其中 $u$ 是 $v$ 的父亲,与下面代码中的定义一致),那么一条路径上的权值和就是 $\sum i(l-i)w[i]$

然后我们发现这个东西可以拿线段树维护并且支持区间合并,所以我们就暴力直接码树剖线段树

但是我们还要维护这个东西反过来的版本,并且还有一堆细节,导致常数巨大,无法通过此题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <utility>
#define maxn 100010
#define ll long long
#define cs const Seg
using namespace std;

// Prime modulus used for every answer.
const int p = 998244353;

// Adds two residues: for x, y in [0, p) the result is (x + y) mod p,
// obtained with a single conditional subtraction instead of %.
inline int add(int x, int y) {
    const int total = x + y;
    if (total >= p) return total - p;
    return total;
}

// n: number of vertices (main reads n - 1 edges); m: number of queries.
int n, m;

// One directed edge of the linked adjacency structure.
struct Edge {
    int to;    // vertex this edge points at
    int next;  // index in e[] of the next edge with the same origin (-1 ends the chain)
};

Edge e[maxn * 2];  // edge pool; each undirected edge is stored as two directed ones
int c1;            // number of entries of e[] in use
int head[maxn];    // head[u]: index of the most recently added edge leaving u
inline void add_edge(int u, int v) {
e[c1].to = v; e[c1].next = head[u]; head[u] = c1++;
}

int dep[maxn], f[maxn], sz[maxn], son[maxn], fa[maxn][21];
void dfs(int u, int fa) {
int Max = 0; sz[u] = 1;
for (int i = head[u]; ~i; i = e[i].next) {
int v = e[i].to; if (v == fa) continue;
dep[v] = dep[u] + 1; f[v] = u; ::fa[v][0] = u; dfs(v, u);
sz[u] += sz[v]; if (sz[v] > Max) Max = sz[v], son[u] = v;
}
}

// id[v]: position of v in the traversal order (1-based); top[v]: topmost
// vertex of v's heavy chain; c2: positions handed out so far; bl[k]: vertex
// at position k; w[v]: sz[parent] - sz[v], the weight stored for v.
int id[maxn], top[maxn], c2, bl[maxn], w[maxn];

// Second traversal: numbers the vertices so that each heavy chain occupies
// consecutive positions, and records top[] and w[] along the way.
// `topf` is the top of the chain that u belongs to.
void dfs(int u, int parent, int topf) {
    top[u] = topf;
    ++c2;
    id[u] = c2;
    bl[c2] = u;

    // The heavy child goes first so it receives the next position and keeps the chain.
    const int heavy = son[u];
    if (heavy != 0) {
        w[heavy] = sz[u] - sz[heavy];
        dfs(heavy, u, topf);
    }

    // Every other child starts a new chain with itself as the top.
    for (int k = head[u]; k != -1; k = e[k].next) {
        const int child = e[k].to;
        if (child == heavy || child == parent) continue;
        w[child] = sz[u] - sz[child];
        dfs(child, u, child);
    }
}

// Completes the lifting table: fa[v][k] becomes the result of applying
// fa[.][k - 1] twice. Levels are filled in increasing order because each one
// reads the level below it; index 0 acts as the "no ancestor" sentinel.
inline void init_lca() {
    for (int level = 1; level <= 20; ++level) {
        for (int v = 1; v <= n; ++v) {
            const int halfway = fa[v][level - 1];
            fa[v][level] = fa[halfway][level - 1];
        }
    }
}

// Lowest common ancestor of x and y by chain hopping: while the two vertices
// sit on different chains, the one whose chain top is deeper moves to the
// parent of that top. Once they share a chain the shallower one is the answer.
inline int get_lca(int x, int y) {
    while (top[x] != top[y]) {
        if (dep[top[x]] >= dep[top[y]]) {
            x = f[top[x]];
        } else {
            y = f[top[y]];
        }
    }
    if (dep[x] < dep[y]) return x;
    return y;
}

// Lifts u through the fa[][] table as far as possible while its depth stays
// strictly greater than dep[v], trying the largest jumps first. When v is a
// proper ancestor of u this yields the child of v on the way to u; when
// dep[u] <= dep[v] no jump qualifies and u is returned unchanged.
inline int jump(int u, int v) {
    const int limit = dep[v];
    int level = 20;
    while (level >= 0) {
        const int candidate = fa[u][level];
        if (dep[candidate] > limit) u = candidate;
        --level;
    }
    return u;
}

#define lc i << 1
#define rc i << 1 | 1
struct Seg {
ll v, sum, mul, len;

ll V, Sum, Mul, Len;

Seg() { v = sum = mul = len = V = Sum = Mul = Len = 0; }
Seg(ll _v, ll _sum, ll _mul, ll _len, ll _V, ll _Sum, ll _Mul, ll _Len) {
v = _v; sum = _sum; mul = _mul; len = _len;
V = _V; Sum = _Sum; Mul = _Mul; Len = _Len;
}
} T[maxn * 4];
// Concatenates two summaries, `ls` on the left and `rs` on the right.
// Forward fields treat ls as the prefix: every index of rs shifts by ls.len.
// Reverse fields treat rs as the prefix: every index of ls shifts by rs.Len.
// The v / V results are reduced with % only, so they may be negative
// (within (-p, p)); the caller normalises before printing.
inline Seg maintain(const Seg &ls, const Seg &rs) {
    Seg merged;

    // ---- left-to-right reading ----
    merged.len = ls.len + rs.len;
    merged.sum = add(ls.sum, rs.sum);
    merged.mul = (ls.mul + rs.mul + ls.len * rs.sum) % p;
    // Contribution of the index shift applied to the right part.
    const ll forwardShift = ls.len * (rs.len + 1) % p * rs.sum;
    merged.v = (ls.v + rs.len * ls.mul + rs.v + forwardShift - ls.len * rs.mul) % p;

    // ---- right-to-left reading (roles of ls and rs exchanged) ----
    merged.Len = rs.Len + ls.Len;
    merged.Sum = add(rs.Sum, ls.Sum);
    merged.Mul = (rs.Mul + ls.Mul + rs.Len * ls.Sum) % p;
    const ll backwardShift = rs.Len * (ls.Len + 1) % p * ls.Sum;
    merged.V = (rs.V + ls.Len * rs.Mul + ls.V + backwardShift - rs.Len * ls.Mul) % p;

    return merged;
}

void build(int i, int l, int r) {
if (l == r) return (void) (T[i] = Seg(w[bl[l]], w[bl[l]], w[bl[l]], 1, w[bl[l]], w[bl[l]], w[bl[l]], 1));
int m = l + r >> 1;
build(lc, l, m); build(rc, m + 1, r);
T[i] = maintain(T[lc], T[rc]);
}

// Returns the summary of positions [L, R] restricted to node i, which covers
// [l, r]. Disjoint ranges (including an empty request with L > R) yield the
// all-zero summary, the neutral element of maintain().
Seg query(int i, int l, int r, int L, int R) {
    if (R < l || r < L) return Seg();
    if (L <= l && r <= R) return T[i];

    const int mid = (l + r) >> 1;
    const Seg leftPart = query(i << 1, l, mid, L, R);
    const Seg rightPart = query(i << 1 | 1, mid + 1, r, L, R);
    return maintain(leftPart, rightPart);
}

int query(int x, int y) {
int g = get_lca(x, y), u = jump(x, g), v = jump(y, g); Seg l, r;
if (g == x) {
while (top[y] != top[v]) {
r = maintain(query(1, 1, n, id[top[y]], id[y]), r);
y = f[top[y]];
}
r = maintain(query(1, 1, n, id[v] + 1, id[y]), r);
return r.v;
}

if (g == y) {
while (top[x] != top[u]) {
l = maintain(query(1, 1, n, id[top[x]], id[x]), l);
x = f[top[x]];
}
l = maintain(query(1, 1, n, id[u] + 1, id[x]), l);
return l.V;
}

while (top[x] != top[u]) {
l = maintain(query(1, 1, n, id[top[x]], id[x]), l);
x = f[top[x]];
}
l = maintain(query(1, 1, n, id[u] + 1, id[x]), l);

while (top[y] != top[v]) {
r = maintain(query(1, 1, n, id[top[y]], id[y]), r);
y = f[top[y]];
}
r = maintain(query(1, 1, n, id[v] + 1, id[y]), r);

r = maintain(Seg(n - sz[u] - sz[v], n - sz[u] - sz[v], n - sz[u] - sz[v], 1, 0, 0, 0, 0), r);
return (l.V + r.len * l.Mul + r.v + l.Len * (r.len + 1) % p * r.sum - l.Len * r.mul) % p;
}

int main() { fill(head, head + maxn, -1);
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);

cin >> n >> m;
for (int i = 1; i < n; ++i) {
int x, y; cin >> x >> y;
add_edge(x, y); add_edge(y, x);
} dfs(1, 0); dfs(1, 0, 1); init_lca(); build(1, 1, n);
for (int i = 1; i <= m; ++i) {
int x, y; cin >> x >> y;
cout << (query(x, y) + p) % p << "\n";
}
return 0;
}