Luogu P2056 [ZJOI2007]捉迷藏

题目描述

https://www.luogu.com.cn/problem/P2056

Solution

按照动态点分治的老规矩,每个点要维护两个数据结构(两个可删堆)

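我们令 $h[u]$ 表示点分树上以 $u$ 为分治中心的可删堆:对 $u$ 在点分树上的每个儿子 $v$,堆中存放 $h[v + n]$ 的最大值(即 $v$ 所在连通块内的被标记点到 $u$ 的最远距离),若 $u$ 本身被标记则再放入一个 $0$;$h[u + n]$ 表示 $u$ 在点分树上的子树内所有被标记点到 $fa[u]$($u$ 在点分树上的父亲)的距离组成的可删堆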

注意到这个题中需要求的是全局最长链,所以再维护一个全局可删堆 $ans$:对每个至少含两个元素的 $h[u]$,把它的最大值与次大值之和放入 $ans$,答案即为 $ans$ 的堆顶

注意到使用欧拉序 + ST 表实现 $O(1)$ 的 $lca$ 查询,可以显著降低常数(大概)

时间复杂度 $O((n + m)\log^2 n)$,常数较大

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#include <iostream>
#include <queue>
#define maxn 100010
#define INF 1000000000
using namespace std;

int n;          // number of vertices in the tree
int m;          // number of commands that follow the edges
int col[maxn];  // col[v] == 1 while vertex v is marked; toggled by solve_1

// Adjacency list in "forward star" form: e[] stores every directed edge and
// head[u] is the index of the most recently added edge leaving u. Chains end
// at -1 (main fills head[] with -1 before any edge is added).
struct Edge {
    int to;    // endpoint of this edge
    int next;  // index of the next edge leaving the same vertex, -1 ends the chain
} e[maxn * 2];
int c1;          // number of edges stored so far (next free slot in e[])
int head[maxn];

// Append the directed edge u -> v to the front of u's adjacency chain.
inline void add_edge(int u, int v) {
    Edge &slot = e[c1];
    slot.to = v;
    slot.next = head[u];
    head[u] = c1;
    ++c1;
}

// Euler tour of the tree rooted at the first vertex passed to dfs():
//   id[k]  - vertex at tour position k (1-based; main sizes the tour as 2n - 1),
//   in[u]  - tour position of u's first occurrence,
//   dep[u] - depth of u (the caller seeds the root's depth before calling).
int id[maxn * 2], in[maxn], dep[maxn];
void dfs(int u, int fa) {
    static int cnt = 0;  // current tour length, persists across recursive calls
    in[u] = ++cnt;
    id[cnt] = u;
    for (int i = head[u]; i != -1; i = e[i].next) {
        const int v = e[i].to;
        if (v != fa) {
            dep[v] = dep[u] + 1;
            dfs(v, u);
            id[++cnt] = u;  // u is written again after returning from each child
        }
    }
}

// Of two vertices, return the one whose first Euler-tour occurrence is earlier
// (ties go to r). Within the tour range between two vertices their LCA is the
// earliest-entered vertex, so this is the combine step of the LCA sparse table.
inline int st_min(int l, int r) {
    if (in[r] <= in[l]) return r;
    return l;
}

// Sparse table over the Euler tour id[1..n]:
//   st[i][j] - earliest-entered vertex among id[i .. i + 2^j - 1] (via st_min),
//   Log[k]   - floor(log2(k)) for k >= 1 (Log[0] = -1 seeds the recurrence).
int st[maxn * 2][21], Log[maxn * 2];
void init_st(int n) {
    Log[0] = -1;
    for (int i = 1; i <= n; ++i) {
        Log[i] = Log[i / 2] + 1;
        st[i][0] = id[i];
    }
    for (int j = 1; j <= 20; ++j) {
        const int half = 1 << (j - 1);
        const int len = 1 << j;
        for (int i = 1; i + len - 1 <= n; ++i)
            st[i][j] = st_min(st[i][j - 1], st[i + half][j - 1]);
    }
}

// Lowest common ancestor of vertices l and r: an O(1) sparse-table query on
// the Euler-tour positions between their first occurrences.
inline int get_lca(int l, int r) {
    int lo = in[l];
    int hi = in[r];
    if (lo > hi) swap(lo, hi);
    const int k = Log[hi - lo + 1];
    const int rightStart = hi - (1 << k) + 1;
    return st_min(st[lo][k], st[rightStart][k]);
}

// Number of edges on the tree path between x and y.
inline int D(int x, int y) {
    const int a = get_lca(x, y);
    return (dep[x] - dep[a]) + (dep[y] - dep[a]);
}

// Max-heap with lazy deletion. erase(x) only records x in `del`; the value is
// physically removed from `add` once both queues expose it at their tops.
// Assumes every erased value was pushed earlier and is still live.
struct Heap {
    priority_queue<int> add;  // all pushed values, including lazily erased ones
    priority_queue<int> del;  // erased values not yet removed from `add`

    // Discard matching tops so that add.top() is a live value.
    inline void flush() {
        while (!del.empty() && add.top() == del.top()) {
            add.pop();
            del.pop();
        }
    }

    inline void push(int x) { add.push(x); }

    inline void erase(int x) { del.push(x); }

    // Largest live value; requires size() >= 1.
    inline int top() {
        flush();
        return add.top();
    }

    // Second-largest live value; requires size() >= 2. The heap is left unchanged.
    inline int sectop() {
        const int first = top();
        pop();
        const int second = top();
        push(first);
        return second;
    }

    // Remove the largest live value; requires size() >= 1.
    inline void pop() {
        flush();
        add.pop();
    }

    // Number of live values.
    inline int size() { return add.size() - del.size(); }

    inline bool empty() { return size() == 0; }
} h[maxn * 2], ans;  // h[u]: per-centroid heap, h[u + n]: distances to fa[u]; ans: global answers

// vis[u] is set once u has been chosen as a centroid; such vertices split the
// tree and are never crossed by later size/centroid searches.
bool vis[maxn];
namespace Calc_sz {
// sz[u]: size of u's subtree inside the current piece (rooted at get_rt's start).
// f[u]:  size of the largest component that would remain if u were removed.
int f[maxn], sz[maxn];
int sum;  // number of vertices in the current piece
int rt;   // best centroid candidate so far; 0 is a sentinel with f[0] = INF

// Prepare the sentinel so that any real vertex beats candidate 0.
void init() {
    rt = 0;
    f[0] = INF;
}

// Fill sz[] for the piece containing u without crossing removed centroids.
void dfs_sz(int u, int fa) {
    int total = 1;
    for (int i = head[u]; i != -1; i = e[i].next) {
        const int v = e[i].to;
        if (v == fa || vis[v]) continue;
        dfs_sz(v, u);
        total += sz[v];
    }
    sz[u] = total;
}

// Fill f[] for the piece and keep the vertex with the smallest f in rt.
void dfs(int u, int fa) {
    int worst = 0;
    for (int i = head[u]; i != -1; i = e[i].next) {
        const int v = e[i].to;
        if (v == fa || vis[v]) continue;
        dfs(v, u);
        worst = max(worst, sz[v]);
    }
    worst = max(worst, sum - sz[u]);  // the part of the piece "above" u
    f[u] = worst;
    if (f[u] < f[rt]) rt = u;
}

// Centroid of the piece that contains u.
inline int get_rt(int u) {
    rt = 0;
    dfs_sz(u, 0);
    sum = sz[u];
    dfs(u, 0);
    return rt;
}
}

// fa[c]: parent of centroid c in the centroid tree (0 for the root, which is
// never assigned).
int fa[maxn];

// Build the centroid tree below centroid u. Each still-unvisited neighbouring
// piece contributes its own centroid as a child of u, which is decomposed recursively.
void divide(int u) {
    vis[u] = 1;
    for (int i = head[u]; i != -1; i = e[i].next) {
        const int neighbour = e[i].to;
        if (vis[neighbour]) continue;
        const int child = Calc_sz::get_rt(neighbour);
        fa[child] = u;
        divide(child);
    }
}

// Register heap t's candidate answer in `ans`: the sum of its two largest
// values. Heaps with fewer than two values contribute nothing.
inline void add(Heap &t) {
    if (t.size() < 2) return;
    ans.push(t.top() + t.sectop());
}

// Withdraw the candidate previously registered by add(t). Must be called
// before t is modified so the same sum is removed.
inline void del(Heap &t) {
    if (t.size() < 2) return;
    ans.erase(t.top() + t.sectop());
}

int tot;
void update(int u, int opt) {
int x = u; tot += opt;
del(h[u]); opt == 1 ? h[u].push(0) : h[u].erase(0); add(h[u]);
while (fa[u]) {
del(h[fa[u]]);
if (h[u + n].size()) h[fa[u]].erase(h[u + n].top());
opt == 1 ? h[u + n].push(D(x, fa[u])) : h[u + n].erase(D(x, fa[u]));
if (h[u + n].size()) h[fa[u]].push(h[u + n].top());
add(h[fa[u]]);
u = fa[u];
}
}

// Command "C x": flip vertex x between marked and unmarked.
inline void solve_1() {
    int x;
    cin >> x;
    update(x, col[x] ? -1 : 1);
    col[x] ^= 1;
}

// Command "G": print the largest distance between two marked vertices.
// With one marked vertex this prints 0, with none it prints -1.
inline void solve_2() {
    if (tot >= 2) {
        cout << ans.top() << "\n";
        return;
    }
    cout << tot - 1 << "\n";
}

// Read the tree, build the LCA structure and the centroid tree, mark every
// vertex, then answer the commands ('C x' toggles x, anything else queries).
int main() {
    fill(head, head + maxn, -1);  // -1 terminates every (empty) adjacency chain
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    cin >> n;
    for (int i = 1; i < n; ++i) {
        int x, y;
        cin >> x >> y;
        add_edge(x, y);
        add_edge(y, x);
    }
    dep[1] = 1;
    dfs(1, 0);
    init_st(2 * n - 1);  // the Euler tour has 2n - 1 entries
    Calc_sz::init();
    divide(Calc_sz::get_rt(1));
    // Every vertex starts marked.
    for (int i = 1; i <= n; ++i) col[i] = 1, update(i, 1);

    cin >> m;
    for (int i = 1; i <= m; ++i) {
        // Commands are a single letter. Read it as one char: `cin >> char*`
        // is unbounded (it overflowed the old char[3] buffer on long tokens)
        // and was removed in C++20.
        char op;
        cin >> op;
        if (op == 'C') solve_1();
        else solve_2();
    }
    return 0;
}