Luogu P4707 重返现世

题目描述

https://www.luogu.com.cn/problem/P4707

简要题意:现在有 $n$ 种物品,每个单位时间,有 $\frac{p_i}{m}$ 生成第 $i$ 中原料,求期望多少时间我们可以收集 $k$ 种不同的物品

$k\le n\le 1000,n-k\le 10,\sum p_i = m,m\le 10000$

Solution

我们考虑 $\max-\min$ 容斥,那么我们相当于要求第 $k$ 小,但是我们注意到 $n-k$ 很小,那么我们可以转换为求第 $n-k+1$ 大,然后我们考虑 $\rm kmax_S$ 的式子,$\rm kmax_S=\sum_{\emptyset \neq T\subseteq S}(-1)^{|T|-k}\binom{|T|-1}{k-1}\min_T$,我们知道 $\min_S$ 为 $\frac{m}{\sum_{x \in S}p_x}$,我们考虑 $dp$ 计数这些东西,因为 $min_S$ 的值在分母上,所以我们必须把 $\sum p$ 记录到 $dp$ 式子里,同时我们注意到有一个 $\binom{|T|-1}{k-1}$ 我们考虑递推这个东西,那么我们需要记录三维,$f_{i,j,k}$ 表示第 $i$ 个物品,$\sum p =j$,$k$ 的值为 $k$,这个东西可以 $O(1)$ 转移,时间复杂度 $O(10nm)$,需要注意 $dp$ 的初值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include <iostream>
#define maxn 1010
#define maxm 10010
#define maxk 11
using namespace std;

const int p = 998244353;
inline int add(int x, int y) { return (x += y) >= p ? x - p : x; }
inline int mul(int x, int y) { return 1ll * x * y % p; }
inline int add(initializer_list<int> lst) { int s = 0; for (auto t : lst) s = add(s, t); return s; }
inline int mul(initializer_list<int> lst) { int s = 1; for (auto t : lst) s = mul(s, t); return s; }
int pow_mod(int x, int n) {
int s = 1;
for (; n; n >>= 1, x = mul(x, x))
if (n & 1) s = mul(s, x);
return s;
}

int n, m, k, a[maxn];

int f[2][maxm][maxk];

int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);

cin >> n >> k >> m; k = n - k + 1;
for (int i = 1; i <= n; ++i) cin >> a[i];
f[1][a[1]][1] = 1;
for (int i = 1, s = 1, t = 0; i < n; ++i, swap(s, t)) {
fill(f[t][0], f[t][0] + maxm * maxk, 0); f[t][a[i + 1]][1] = 1;
for (int j = 0; j <= m; ++j)
for (int k = 1; k <= ::k; ++k) {
if (!f[s][j][k]) continue;
f[t][j][k] = add(f[t][j][k], f[s][j][k]);
f[t][j + a[i + 1]][k] = add(f[t][j + a[i + 1]][k], p - f[s][j][k]);
if (k < ::k) f[t][j + a[i + 1]][k + 1] = add(f[t][j + a[i + 1]][k + 1], f[s][j][k]);
}
}
int ans = 0;
for (int i = 0; i <= m; ++i) ans = add(ans, mul({ f[n & 1][i][k], pow_mod(i, p - 2), m }));
cout << ans << "\n";
return 0;
}