1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
| #include <iostream> #include <vector> #include <algorithm> #define maxn 200010 #define ll long long using namespace std;
// Prime modulus for all arithmetic in this file (NTT-friendly: p - 1 = 119 * 2^23).
constexpr int p = 998244353;

/// Modular sum of two residues; both operands are expected to lie in [0, p).
inline int add(int x, int y) {
    const int sum = x + y;
    return sum < p ? sum : sum - p;
}

/// Modular product; widened to 64 bits so the intermediate cannot overflow.
inline int mul(int x, int y) {
    return static_cast<int>(static_cast<long long>(x) * y % p);
}

/// Modular sum of every value in the list (0 for an empty list).
inline int add(std::initializer_list<int> values) {
    int total = 0;
    for (const int v : values) total = add(total, v);
    return total;
}

/// Modular product of every value in the list (1 for an empty list).
inline int mul(std::initializer_list<int> factors) {
    int product = 1;
    for (const int f : factors) product = mul(product, f);
    return product;
}

/// Binary exponentiation: returns base^power mod p for power >= 0.
int pow_mod(int base, long long power) {
    int result = 1;
    while (power != 0) {
        if (power & 1) result = mul(result, base);
        base = mul(base, base);
        power >>= 1;
    }
    return result;
}
typedef std::vector<int> Poly; namespace Pol { const int N = 4200000; int a[N], b[N];
const int G = 3; const int Gi = pow_mod(G, p - 2); int P[N]; void init_P(int n) { int l = 0; while ((1 << l) < n) ++l; for (int i = 0; i < n; ++i) P[i] = (P[i >> 1] >> 1) | ((i & 1) << l - 1); } vector<int> init_W(int n) { vector<int> w(n); w[1] = 1; for (int i = 2; i < n; i <<= 1) { auto w0 = w.begin() + i / 2, w1 = w.begin() + i; int wn = pow_mod(G, (p - 1) / (i << 1)); for (int j = 0; j < i; j += 2) w1[j] = w0[j >> 1], w1[j + 1] = mul(w1[j], wn); } return w; } auto w = init_W(1 << 21); void DIT(int *a, int n) { for (int k = n >> 1; k; k >>= 1) for (int i = 0; i < n; i += k << 1) for (int j = 0; j < k; ++j) { int x = a[i + j], y = a[i + j + k]; a[i + j + k] = mul(add(x, p - y), w[k + j]), a[i + j] = add(x, y); } } void DIF(int *a, int n) { for (int k = 1; k < n; k <<= 1) for (int i = 0; i < n; i += k << 1) for (int j = 0; j < k; ++j) { int x = a[i + j], y = mul(a[i + j + k], w[k + j]); a[i + j + k] = add(x, p - y), a[i + j] = add(x, y); } int inv = pow_mod(n, p - 2); for (int i = 0; i < n; ++i) a[i] = mul(a[i], inv); reverse(a + 1, a + n); } Poly Mul(const Poly &A, const Poly &B, int n1, int n2) { int n = 1; while (n < n1 + n2 - 1) n <<= 1; init_P(n); for (int i = 0; i < n1; ++i) a[i] = add(A[i], p); for (int i = 0; i < n2; ++i) b[i] = add(B[i], p); fill(a + n1, a + n, 0), fill(b + n2, b + n, 0); DIT(a, n); DIT(b, n); for (int i = 0; i < n; ++i) a[i] = mul(a[i], b[i]); DIF(a, n); Poly ans(n1 + n2 - 1); copy_n(a, n1 + n2 - 1, ans.begin()); return ans; } Poly MMul(const Poly &A, const Poly &B, int len) { int n = 1; while (n < 2 * len - 1) n <<= 1; init_P(n); for (int i = 0; i < len; ++i) a[i] = add(A[i], p); for (int i = 0; i < len; ++i) b[i] = add(B[i], p); fill(a + len, a + n, 0), fill(b + len, b + n, 0); reverse(b, b + len); DIT(a, n); DIT(b, n); for (int i = 0; i < n; ++i) a[i] = mul(a[i], b[i]); DIF(a, n); Poly ans(len); copy_n(a, len, ans.begin()); reverse(ans.begin(), ans.end()); return ans; } }
// fac[i] = i! mod p and inv[i] = (i!)^{-1} mod p for 0 <= i <= n, filled by init_C.
int fac[maxn], inv[maxn];

/// Precomputes factorials and inverse factorials up to n (requires n < maxn).
void init_C(int n) {
    fac[0] = 1;
    for (int i = 1; i <= n; ++i) fac[i] = mul(fac[i - 1], i);
    // One Fermat inversion at the top, then walk down using
    // 1/(i-1)! = (1/i!) * i.
    inv[n] = pow_mod(fac[n], p - 2);
    for (int i = n; i > 0; --i) inv[i - 1] = mul(inv[i], i);
}
// Parameters of the current test case: both are read by work(), and n is also read by solve().
int n = 0;
int m = 0;
// pp[i] = n^i and invpp[i] = n^{-i} mod p (filled by work()).
int pp[maxn], invpp[maxn];
// pre[i] = sum_{j=1..i} A[j] * n^{-j}; pre[0] is never written and stays 0.
int pre[maxn];

/// Online (divide-and-conquer) evaluation of A[l..r).
///
/// When started as solve(A, B, 0, r) on a zero-filled A, each leaf l >= 1 is
/// reached after every A[j] (j < l) is final and after the convolution steps
/// have added sum_{1 <= j < l} A[j] * B[l - j] into A[l]. The leaf then sets
///   A[l] = n^{l-1} * n! - sum_{1<=j<l} A[j] * B[l - j]
///          - [l > n] * n! * sum_{1<=j<=l-n} A[j] * n^{l-n-j},
/// with the last sum read from pre[] in O(1). A[0] is 0.
/// B must hold at least r - l coefficients.
void solve(Poly &A, const Poly &B, int l, int r) {
    if (l >= r) return;  // empty range; previously this recursed forever
    if (l + 1 == r) {
        if (l == 0) {
            A[l] = 0;
        } else {
            A[l] = add(mul(pp[l - 1], fac[n]), p - A[l]);
            if (l > n) A[l] = add(A[l], p - mul({pre[l - n], pp[l - n], fac[n]}));
            pre[l] = add(pre[l - 1], mul(A[l], invpp[l]));
        }
        return;
    }
    const int mid = (l + r) >> 1;  // named `mid` so it no longer shadows the global m
    solve(A, B, l, mid);
    // Push the contribution of the finished left half A[l..mid) into A[mid..r).
    Poly left(A.begin() + l, A.begin() + mid);
    Poly kernel(B.begin(), B.begin() + (r - l));
    Poly t = Pol::Mul(left, kernel, mid - l, r - l);
    for (int i = mid; i < r; ++i) A[i] = add(A[i], t[i - l]);
    solve(A, B, mid, r);
}
void work() { cin >> n >> m; pp[0] = invpp[0] = 1; for (int i = 1, inv = pow_mod(n, p - 2); i <= m; ++i) { pp[i] = mul(pp[i - 1], n); invpp[i] = mul(invpp[i - 1], inv); } Poly A(m + 1), B(m + 1); for (int i = 0; i < n; ++i) B[i] = fac[i]; solve(A, B, 0, m - n + 2); int ans = pp[m]; for (int i = 1; i <= m - n + 1; ++i) ans = add(ans, p - mul(A[i], pp[m - n + 1 - i])); cout << ans << "\n"; }
// Entry point: precomputes factorials once, then answers T independent test cases.
int main() {
    // Fast, unsynchronised stream I/O.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);

    init_C(200000);

    int T;
    std::cin >> T;
    for (; T; --T) work();
    return 0;
}
|