第46屆ICPC 東亞洲區域賽(澳門) E Pass the Ball

题目描述

简要题意:现在有 $n$ 个人,第 $i$ 个人手里有一个数字 $i$,现在开始传数字,第 $i$ 个人会将他手里的数字传给 $p_i$,$p_i$ 是一个排列,同时现在有 $m$ 次询问,每次询问给出一个整数 $x$,求传了 $x$ 轮后,每个人的编号乘上他手里的数字的和

$n\le 2\times 10^5,m\le 10^5,x\le 10^9$

https://ac.nowcoder.com/acm/contest/31454/E

Solution

注意到传球形成了若干个简单环,我们对每个简单环单独计算贡献

我们设当前环的大小为 $n$,第 $i$ 个人的编号为 $a_i$,如果传了 $k$ 轮,那么总贡献是类似于 $\sum a_i\times a_{i+k}$ 的东西,稍作考虑后可以发现这是一个类似减法卷积的东西,将长度扩大到 $2n$,然后再稍作处理直接做减法卷积即可,注意到总贡献刚好是 $1e15$ 级别的,所以我们是可以用 $FFT$ 的

然后考虑处理询问,注意到大小不同的环的个数只有 $O(\sqrt n)$,所以每次询问暴力所有环的大小即可

时间复杂度 $O(n\log n+q\sqrt n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>
#define maxn 200010
#define ll long long
using namespace std;

typedef std::vector<ll> Poly;
namespace Pol {
const int N = 2100000; // 100MB
const double pi = acos(-1);
struct Complex {
double x, y;

Complex(double x = 0, double y = 0) : x(x), y(y) {}

friend Complex operator + (const Complex &u, const Complex &v) { return Complex(u.x + v.x, u.y + v.y); }
friend Complex operator - (const Complex &u, const Complex &v) { return Complex(u.x - v.x, u.y - v.y); }
friend Complex operator * (const Complex &u, const Complex &v) {
return Complex(u.x * v.x - u.y * v.y, u.x * v.y + u.y * v.x);
}
} a[N], b[N];

int P[N];
void init_P(int n) {
int l = 0; while ((1 << l) < n) ++l;
for (int i = 0; i < n; ++i) P[i] = (P[i >> 1] >> 1) | ((i & 1) << l - 1);
}
vector<Complex> init_W(int n) {
vector<Complex> w(n); w[1] = 1;
for (int i = 2; i < n; i <<= 1) {
auto w0 = w.begin() + i / 2, w1 = w.begin() + i;
Complex wn(cos(pi / i), sin(pi / i));
for (int j = 0; j < i; j += 2)
w1[j] = w0[j >> 1], w1[j + 1] = w1[j] * wn;
}
return w;
} auto w = init_W(1 << 21);
void DIF(Complex *a, int n) {
for (int k = n >> 1; k; k >>= 1)
for (int i = 0; i < n; i += k << 1)
for (int j = 0; j < k; ++j) {
Complex x = a[i + j], y = a[i + j + k];
a[i + j + k] = (x - y) * w[k + j], a[i + j] = x + y;
}
}
void IDIT(Complex *a, int n) {
for (int k = 1; k < n; k <<= 1)
for (int i = 0; i < n; i += k << 1)
for (int j = 0; j < k; ++j) {
Complex x = a[i + j], y = a[i + j + k] * w[k + j];
a[i + j + k] = x - y, a[i + j] = x + y;
}
for (int i = 0; i < n; ++i) a[i].x /= n, a[i].y /= n;
reverse(a + 1, a + n);
}
Poly Mul(const Poly &A, const Poly &B, int n1, int n2) {
int n = 1; while (n < n1 + n2 - 1) n <<= 1; init_P(n);
for (int i = 0; i < n1; ++i) a[i] = Complex(A[i], 0);
for (int i = 0; i < n2; ++i) b[i] = Complex(B[i], 0);
fill(a + n1, a + n, Complex(0, 0)); fill(b + n2, b + n, Complex(0, 0));
DIF(a, n); DIF(b, n);
for (int i = 0; i < n; ++i) a[i] = a[i] * b[i];
IDIT(a, n); Poly ans(n1 + n2 - 1); for (int i = 0; i < n1 + n2 - 1; ++i) ans[i] = (int) (a[i].x + 0.5);
return ans;
}
Poly MMul(const Poly &A, const Poly &B, int len) { //减法卷积
int n = 1; while (n < 2 * len - 1) n <<= 1; init_P(n);
for (int i = 0; i < len; ++i) a[i] = Complex(A[i], 0);
for (int i = 0; i < len; ++i) b[i] = Complex(B[i], 0);
fill(a + len, a + n, Complex(0, 0)); fill(b + len, b + n, Complex(0, 0)); reverse(b, b + len);
DIF(a, n); DIF(b, n);
for (int i = 0; i < n; ++i) a[i] = a[i] * b[i];
IDIT(a, n); Poly ans(len); for (int i = 0; i < len; ++i) ans[i] = (ll) (a[i].x + 0.5);
reverse(ans.begin(), ans.end()); return ans;
}
} // namespace Pol

int n, m, a[maxn];
bool use[maxn];

int mp[maxn], cnt;
vector<int> d;
vector<ll> w[maxn];
void solve(vector<int> a) {
int n = a.size(); Poly A(2 * n), B(2 * n);
for (int i = 0; i < n; ++i) A[i] = a[i];
for (int i = n; i < 2 * n; ++i) A[i] = A[i - n];
for (int i = n; i < 2 * n; ++i) B[i] = a[i - n];
Poly C = Pol::MMul(A, B, 2 * n);
if (!mp[n]) mp[n] = ++cnt, d.push_back(n), w[mp[n]].resize(n);
for (int i = 0; i < n; ++i) w[mp[n]][i] += C[i];
}


int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);

cin >> n >> m;
for (int i = 1; i <= n; ++i) cin >> a[i];
for (int i = 1; i <= n; ++i) {
int t = i; vector<int> vec; if (use[i]) continue;
do {
vec.push_back(t); use[t] = 1;
t = a[t];
} while (t != i);
solve(vec);
}
for (int i = 1; i <= m; ++i) {
int x; ll ans = 0; cin >> x;
for (auto t : d) ans += w[mp[t]][x % t];
cout << ans << "\n";
}
return 0;
}