题目描述
难度分:$2200$
输入$n$、$m(1 \leq n \leq m \leq 2 \times 10^5)$,长为$n$的字符串$s$,长为$m$的字符串$t$,只包含大写英文字母。
你可以随意重排$s$和$t$中的字母。然后从$t$中选一个长为$n$的子序列$t’$,使得对于每个$s[i]$,要么$s[i]=t’[i]$,要么$s[i]+1=t’[i]$。比如$s=$AAB
,$t’=$ABB
是合法的。
有多少个不同的$(s,t’)$二元组?模$998244353$。
输入样例$1$
3 4
AMA
ANAB
输出样例$1$
9
输入样例$2$
5 8
BINUS
BINANUSA
输出样例$2$
120
输入样例$3$
15 30
BINUSUNIVERSITY
BINANUSANTARAUNIVERSITYJAKARTA
输出样例$3$
151362308
输入样例$4$
4 4
UDIN
ASEP
输出样例$4$
0
算法
动态规划
这个题比较容易出的一个思路就是先固定$s$串不动,用$t$串中的字母来匹配$s$串中的字母。
状态定义
$f[i][j]$表示$s$串中位于字母表中排名第$i$的字母在消耗$t$串中位于字母表中排名第$i+1$的字母$j$个的情况下,能够产生多少种方案。在这个定义下,如果顺序考虑字母$i \in [1,26]$,最后一个字母z
就不存在下一个字母,$f[26][0]$就是答案,再乘上$s$串的排列数目$\frac{n!}{\Pi_{i \in [1,26]}cnta[i]}$就是最终答案,其中$cnta[i]$是$s$串中字母表排名第$i$的字母的频数,类似的定义一个$cntb$数组,$cntb[i]$是$t$串中字母表排名第$i$的字母的频数。
状态转移
对于字母$i$,枚举需要$t$串中的$j$个$i+1$来与之匹配。显然$j$最小就是$cnta[i]-cntb[i]$,即$t$串中所有的$i$都被用来匹配$s$串中的$i$,只有剩下的$cnta[i]-cntb[i]$个$i$需要用$t$串中的$i+1$来匹配;而$j$最多就是$min(cnta[i],cntb[i+1])$,即$s$串中的所有$i$都用$t$串中的$i+1$来匹配,但是又不能够超过$t$串中$i+1$的总数$cntb[i+1]$。
因此得到状态转移方程为
$f[i][j]=\Sigma_{k=0}^{cntb[i]-cnta[i]+j}f[i-1][k] \times C_{cnta[i]}^{j}$
组合数$C_{cnta[i]}^{j}$表示$s$串中要选$j$个出来和$t$串中的$i+1$字母匹配,方案数有$C_{cnta[i]}^{j}$个。而$f[i-1][k]$中就还需要枚举$k$,$k$从$0$开始,最多可以取到$cntb[i]-(cnta[i]-j)$,即$cntb[i]$中要有$cnta[i]-j$个与$s$串中的$i$匹配,不能超过这个数。
这样就有$O(26n)$的状态数量,而状态转移也是$O(n)$的,肯定会超时。但是注意到$\Sigma_{k=0}^{cntb[i]-cnta[i]+j}f[i-1][k]$是上一行状态的前缀和,所以用前缀和优化就可以$O(1)$转移了。
复杂度分析
时间复杂度
预处理出$s$串的字母频数表时间复杂度为$O(n)$,预处理出$t$串的字母频数表时间复杂度为$O(m)$。动态规划的状态数量是$O(n)$级别的,单次转移的时间复杂度为$O(1)$,动态规划的时间复杂度为$O(n)$。因此,整个算法的时间复杂度为$O(n+m)$。
空间复杂度
DP
数组$f$的空间消耗为$O(n)$,其前缀和数组$sum$也是这个规模。为了快速计算组合数(组合数是基于$O(n)$长度的字符串$s$),需要用到逆元,预处理出阶乘余数表及其对应逆元,空间消耗为$O(n)$。所以,整个算法的额外空间复杂度为$O(n)$。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 200010, MOD = 998244353;
int n, m;
char s[N], t[N];
LL inv[N], finv[N], fac[N], f[30][N], sum[30][N];
// 预处理逆元
void get_inv(int n) {
inv[0] = inv[1] = 1;
for(int i = 2; i <= n; i++) {
inv[i] = (MOD - MOD/i) * inv[MOD % i] % MOD;
}
finv[0] = finv[1] = fac[0] = fac[1] = 1;
for(int i = 2; i <= n; i++) {
fac[i] = fac[i - 1] * i % MOD;
finv[i] = finv[i - 1] * inv[i] % MOD;
}
}
// 排列数
LL A(LL n, LL m) {
if(n == 0 || m == 0) return 1;
return fac[n] * finv[n - m] % MOD;
}
// 组合数
LL C(LL n, LL m) {
if(m == 0) return 1;
if(m < 0 || m > n) return 0;
return A(n, m) * finv[m] % MOD;
}
int main() {
scanf("%d%d", &n, &m);
scanf("%s", s + 1);
scanf("%s", t + 1);
get_inv(m);
int cnta[30] = {0}, cntb[30] = {0};
for(int i = 1; i <= n; i++) {
cnta[s[i] - 'A' + 1]++;
}
for(int i = 1; i <= m; i++) {
cntb[t[i] - 'A' + 1]++;
}
f[0][0] = 1;
for(int i = 1; i <= 26; i++) {
sum[i - 1][0] = f[i - 1][0];
for(int j = 1; j <= cntb[i]; j++) {
sum[i - 1][j] = (sum[i - 1][j - 1] + f[i - 1][j]) % MOD;
}
for(int j = max(0, cnta[i] - cntb[i]); j <= min(cnta[i], cntb[i + 1]); j++) {
f[i][j] = sum[i - 1][cntb[i] - (cnta[i] - j)] * C(cnta[i], j) % MOD;
}
}
LL ans = f[26][0];
ans = ans * fac[n] % MOD;
for(int i = 1; i <= 26; i++) {
ans = ans * finv[cnta[i]] % MOD;
}
printf("%d\n", ans);
return 0;
}