啥都不会,看第一篇题解学的。
#include <bits/stdc++.h>
using namespace std;
// N: capacity of the factorial tables below; mod: modulus used by all arithmetic in this file.
const int N = 1e6 + 15, mod = 998244353;
// n, k: the two integers read from standard input in main().
// fac[i]: i! modulo mod; inv[i]: modular inverse of i! modulo mod.
// Both tables are filled by init() for 0 <= i <= 1000000.
int n, k, fac[N], inv[N];
// Binary (square-and-multiply) exponentiation: returns a^k modulo `mod`.
// Walks the bits of k from least to most significant; `a` holds the current
// repeated square and is multiplied into the result whenever the bit is set.
// Products are widened to long long before the reduction.
int qmi(int a, int k) {
    int result = 1;
    for (; k; k >>= 1) {
        if (k & 1) {
            result = static_cast<long long>(result) * a % mod;
        }
        a = static_cast<long long>(a) * a % mod;
    }
    return result;
}
// Fills the factorial table fac[0..1000000] and the inverse-factorial table
// inv[0..1000000], all modulo `mod`.
void init() {
    const int top = 1000000;

    // Forward pass: fac[i + 1] = fac[i] * (i + 1).
    fac[0] = 1;
    for (int i = 0; i < top; ++i) {
        fac[i + 1] = static_cast<long long>(fac[i]) * (i + 1) % mod;
    }

    // One modular inverse at the top via qmi(x, mod - 2), then walk downwards
    // using inv[i - 1] = inv[i] * i.
    inv[top] = qmi(fac[top], mod - 2);
    for (int i = top; i > 0; --i) {
        inv[i - 1] = static_cast<long long>(inv[i]) * i % mod;
    }
}
// Binomial coefficient C(n, m) modulo `mod`, computed from the precomputed
// fac[] / inv[] tables as n! * (m!)^-1 * ((n - m)!)^-1.
// Returns 0 whenever m lies outside [0, n] (which also covers n < 0), so the
// tables are never indexed with a negative subscript. The caller must still
// keep n within the range filled by init().
inline int C(int n, int m) {
    if (m < 0 || m > n) return 0;
    return fac[n] * 1ll * inv[m] % mod * 1ll * inv[n - m] % mod;
}
// L, R: inclusive bounds of the window of lower indices j currently covered by `sum`.
int L, R;
// sum: running value of C(p, L) + ... + C(p, R) for the current row p in main(), kept in [0, mod).
// ans: the accumulated answer that main() prints.
long long sum = 0, ans = 0;
// Normalises x in place to its representative in [0, mod), also for negative x.
void chk(long long &x) {
    x %= mod;      // now in (-mod, mod)
    if (x < 0) {
        x += mod;  // lift a negative remainder into [0, mod)
    }
}
// Adds C(n, m) to the running window sum and renormalises it modulo `mod`.
inline void add(int n, int m) {
    sum += C(n, m);
    chk(sum);
}
// Removes C(n, m) from the running window sum and renormalises it modulo `mod`.
inline void del(int n, int m) {
    sum -= C(n, m);
    chk(sum);
}
// Entry point: reads n and k, accumulates `ans` over p = 0..k and prints it.
// While 2 * p < n, the contribution of p is a window sum of binomials
// C(p, j) for j in [L, R], maintained incrementally in `sum`; otherwise the
// contribution is the single term C(n - p - 1, k - p).
int main() {
    init();
    scanf("%d%d", &n, &k);
    if (n == k) {
        puts("1");
        return 0;
    }

    // Start with the window [0, 0] on row 0, whose sum is C(0, 0) = 1.
    L = 0;
    R = 0;
    sum = 1;

    for (int p = 0; p <= k; ++p) {
        if (2 * p >= n) {
            ans = (ans + C(n - p - 1, k - p)) % mod;
            continue;
        }

        if (p != 0) {
            // Move down one row. By Pascal's rule,
            //   sum_{j=L..R} C(p, j)
            //     = 2 * sum_{j=L..R} C(p-1, j) - C(p-1, R) + C(p-1, L-1).
            sum = sum * 2 - C(p - 1, R) + C(p - 1, L - 1);
            chk(sum);
        }

        // Target window for this row, clamped so no index exceeds p.
        const int upper = k - p;
        const int span = std::min(upper, n - 2 * p - 1);
        const int lo = std::min(upper - span, p);
        const int hi = std::min(upper, p);

        // Slide the left edge to lo.
        while (L > lo) {
            --L;
            add(p, L);
        }
        while (L < lo) {
            del(p, L);
            ++L;
        }

        // Slide the right edge to hi.
        while (R < hi) {
            ++R;
            add(p, R);
        }
        while (R > hi) {
            del(p, R);
            --R;
        }

        ans = (ans + sum) % mod;
    }

    printf("%lld\n", ans);
    return 0;
}