poj 1741 Tree

题目描述

http://poj.org/problem?id=1741

给定一棵 $n$ 个点的带边权的树,求出满足 $dis(u,v)\le k$ 的无序点对 $(u,v)$($u\ne v$)的数量

Solution

直接点分治:对每个分治重心,把当前连通块内各点到重心的距离排序,用双指针统计距离之和不超过 $k$ 的点对;再对每棵子树单独统计一次并减去,以去掉两点落在同一棵子树内的点对

时间复杂度 $O(n\log^2 n)$

Code
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define maxn 10010
#define INF 1000000000
using namespace std;

// Per-test-case parameters read in main(): n is the vertex count (work()
// reads n - 1 edges) and m is the distance limit used by calc().
int n;
int m;

// One directed arc stored in a linked adjacency list.
struct Edge {
    int to;    // vertex this arc points to
    int next;  // index in e[] of the next arc leaving the same vertex, -1 at the end
    int w;     // weight of the arc
};

Edge e[maxn * 2];  // arc pool; every undirected edge is stored as two arcs
int c1;            // number of arcs currently stored in e[]
int head[maxn];    // head[u]: index of the most recently added arc leaving u, -1 if none
inline void add_edge(int u, int v, int w) {
e[c1].to = v; e[c1].w = w;
e[c1].next = head[u]; head[u] = c1++;
}

// vis[v] is set by divide() once v has been used as a division root; every
// traversal below treats such vertices as removed from the tree.
bool vis[maxn];

// Finds the centroid of the not-yet-removed component containing a vertex.
struct Calc_sz {
    int sz[maxn];  // sz[v]: size of v's subtree in the most recent dfs_sz() run
    int f[maxn];   // f[v]: size of the largest part left if v is removed
    int sum;       // number of vertices in the component being examined
    int rt;        // best candidate found so far (0 = none yet)

    // Makes index 0 a sentinel that loses every comparison in dfs().
    void init() {
        rt = 0;
        f[0] = INF;
    }

    // Fills sz[] for the component reachable from u without passing through
    // fa or any removed vertex.
    void dfs_sz(int u, int fa) {
        sz[u] = 1;
        for (int idx = head[u]; idx != -1; idx = e[idx].next) {
            const int child = e[idx].to;
            if (child != fa && !vis[child]) {
                dfs_sz(child, u);
                sz[u] += sz[child];
            }
        }
    }

    // Computes f[] over the same component and keeps in rt the vertex with
    // the smallest f value. Requires sz[] and sum to be up to date.
    void dfs(int u, int fa) {
        f[u] = 0;
        for (int idx = head[u]; idx != -1; idx = e[idx].next) {
            const int child = e[idx].to;
            if (child != fa && !vis[child]) {
                dfs(child, u);
                if (f[u] < sz[child]) {
                    f[u] = sz[child];
                }
            }
        }
        // The part on the parent's side counts as one more candidate piece.
        const int above = sum - sz[u];
        if (f[u] < above) {
            f[u] = above;
        }
        if (f[u] < f[rt]) {
            rt = u;
        }
    }

    // Returns the centroid of the component that contains u.
    inline int get_rt(int u) {
        rt = 0;
        dfs_sz(u, 0);  // 0 is never a real vertex, so it serves as "no parent"
        sum = sz[u];
        dfs(u, 0);
        return rt;
    }
} _;

int a[maxn], c2;
void dfs(int u, int fa, int D) {
a[++c2] = D;
for (int i = head[u]; ~i; i = e[i].next) {
int v = e[i].to, w = e[i].w; if (v == fa || vis[v]) continue;
dfs(v, u, D + w);
}
}

// Collects the distances of all non-removed vertices reachable from u, with
// u itself starting at distance D, and returns how many unordered pairs of
// collected entries have a distance sum of at most m.
int calc(int u, int D) {
    c2 = 0;
    dfs(u, 0, D);
    std::sort(a + 1, a + c2 + 1);

    int pairs = 0;
    int lo = 1;
    int hi = c2;
    while (lo < hi) {
        if (a[lo] + a[hi] > m) {
            // a[hi] is too large even for the smallest remaining partner.
            --hi;
        } else {
            // a[lo] pairs with every entry in (lo, hi].
            pairs += hi - lo;
            ++lo;
        }
    }
    return pairs;
}

int ans;  // running count of qualifying pairs for the current test case

// Processes the component whose division root is u: counts all pairs whose
// distances through u sum to at most m, removes u, then for each remaining
// neighbouring subtree subtracts the pairs lying entirely inside it (they
// were counted above but do not really pass through u) and recurses on that
// subtree's centroid.
void divide(int u) {
    ans += calc(u, 0);
    vis[u] = true;
    for (int idx = head[u]; idx != -1; idx = e[idx].next) {
        const int child = e[idx].to;
        if (!vis[child]) {
            ans -= calc(child, e[idx].w);
            divide(_.get_rt(child));
        }
    }
}

void init() {
memset(vis, 0, sizeof vis); memset(head, -1, sizeof head);
c1 = ans = 0;
}

// Solves one test case. n must already hold the vertex count: the function
// reads the n - 1 weighted edges from stdin, runs the divide-and-conquer
// starting from the centroid of vertex 1's component and prints the answer.
// If an edge cannot be read completely (truncated or non-numeric input) the
// case is abandoned without output, rather than inserting arcs built from
// uninitialised values.
void work() {
    init();
    for (int i = 1; i < n; ++i) {
        int x, y, z;
        if (scanf("%d%d%d", &x, &y, &z) != 3) {
            return;
        }
        // The tree is undirected, so store the edge in both directions.
        add_edge(x, y, z);
        add_edge(y, x, z);
    }
    _.init();
    divide(_.get_rt(1));
    std::cout << ans << std::endl;
}


// Reads test cases until input ends or a header with n + m == 0 appears.
int main() {
    for (;;) {
        if (!(std::cin >> n >> m)) {
            break;
        }
        if (n + m == 0) {
            break;
        }
        work();
    }
    return 0;
}