题目描述
对于一个长度不小于 $2$ 的字符串,我们可以对其使用重组操作。
重组操作分为两个步骤:
1. 将字符串从中间任意位置截断,得到前后两个非空子串。
2. 将两个子串交换位置(前变后,后变前)后重新连接,得到新字符串。
例如,对于字符串abcdef
,一种可行的重组操作为:
- 将字符串截断为两个非空子串
ab
、cdef
。 - 将两个子串交换位置后重新连接,得到新字符串
cdefab
。
给定一个起始字符串 $s_1$ 和一个目标字符串 $s_2$。
我们希望对 $s_1$ 进行恰好 $k$ 次重组操作后得到 $s_2$。
请你计算,一共有多少种不同的可行方案。
由于答案可能很大,请你输出对 $10^9+7$ 取模后的结果。
对于两种可行方案,如果存在至少一个 $i(1\leq i\leq k)$ 满足,在第 $i$ 次操作中,两种方案的截断位置不同,则认为两种可行方案为不同方案。
解题思路
注意到以下结论:
对一个字符串先从
i
位置切开,再从j
位置切开,等价于从(i + j - 1) % len + 1
处切开。
对上面这个解释一下,比如一个长度为 $114$ 的字符串,如果从第 $i=105$ 个字符后面切了一下,又从第 $j=74$ 个字符后面切了一下,等价于从第 $65$ 个字符后面切开。
通过这种方式,可以将多次切割转换为 $1$ 次切割。这样题目的解法就出来了:
- 观察
s1
与s2
,看看在只进行 $1$ 次切割的情况下,有哪些落刀点可以选。 - 对于每一个落刀点,将其展开为 $k$ 次切割 $(i_1,i_2,\cdots,i_k)$,有多少种展开方案。
- 对所有的方案数求和并取模。
参考算法
上面的步骤 2 存在技巧。暴力计算肯定行不通,这里直奔主题,提供一个时间复杂度稍微低一点的步骤 2 实现。
倍增$\quad$令dp[i][j]
为“将 $1$ 次落刀点为 $j$ 的切割展开为 $2^i$ 次切割的不同方案数”。显然有状态转移方程:
$$dp[i][j] = \sum_{t=1}^l dp[i - 1][t]+dp[i - 1][\hat t]\quad\mathrm{s.t.}\ 1\leq \hat t\leq l\mathrm{\ and\ }t+\hat t=j(\operatorname{mod}l)$$
这里 $l$ 是字符串长度,我们也可以直接得出 $\hat t=(j-t+l-1)\operatorname{mod}l+1$。
根据状态转移方程求出数组 $dp[\log k][l]$ 需要时间 $O(l^2\log k)$。
拆解$\quad$接下来我们将 $k$ 拆成 $\sum 2^p$ 的形式以利用我们的 $dp$。比如
$$k=114=64+32+16+2=2^6+2^5+2^4+2^1$$
后续代码中seq
就是存这些指数的(上例,$\mathit{seq}=[6,5,4,1]$)。从而且 $k$ 刀,被我们简化成了切 $\log k$ 刀。
再 dp$\quad$再次执行一下dp
,dpp[i][j]
表示“将 $1$ 次落刀点为 $j$ 的切割展开为 $\sum_{t=0}^i 2^{\mathit{seq}[t]}$ 次切割的不同方案数”。状态转移方程也很简单:
$$dpp[i][j] = \sum_{t=1}^l dpp[i - 1][t]+dp[\mathit{seq}[i]][\hat t]\quad\mathrm{s.t.}\ 1\leq \hat t\leq l\mathrm{\ and\ }t+\hat t=j(\operatorname{mod}l)$$
最终结果$\quad$置ans = 0
然后对于每一个可行的切割点(对应步骤 3)i
,执行ans = (ans + dpp[cnt - 1][i])
即可。
时间复杂度
$O(l^2\log k)$
C++ 代码
奇丑无比的码风
# include <iostream>
# include <cstring>
# define LLMUL(a,b) ((long long)(a) * (long long)(b))
using namespace std;
const int mod = 1e9 + 7;
char s1[1005],s2[1005];
int k,len,dp[20][1005],dpp[20][1005];
long long tmp,ans;
int seq[20],cnt;
bool ok(int pos){
for(int i = 0;i < len;i++)
if(s2[i] != s1[(i + pos + 1) % len])
return false;
return true;
}
int main(){
cin >> s1 >> s2 >> k;
len = strlen(s1);
if(k == 0)
return cout << (strcmp(s1,s2)?0:1),0;
for(int i = len - 1;i > 0;i--)
dp[0][i] = 1;
for(int i = 1;(1 << i) <= k;i++)
for(int j = 1;j <= len;j++)
for(int t = 1;t <= len;t++){
tmp = LLMUL(dp[i - 1][t],dp[i - 1][(j - t + len - 1) % len + 1]) % mod;
dp[i][j] = (dp[i][j] + tmp) % mod;
}
for(int i = 19;i >= 0;i--)
if(k >= (1 << i)){
k -= (1 << i);
seq[cnt++] = i;
}
for(int j = 1;j <= len;j++)
dpp[0][j] = dp[seq[0]][j];
for(int i = 1;i < cnt;i++)
for(int j = 1;j <= len;j++)
for(int t = 1;t <= len;t++){
tmp = LLMUL(dpp[i - 1][t],dp[seq[i]][(j - t + len - 1) % len + 1]) % mod;
dpp[i][j] = (dpp[i][j] + tmp) % mod;
}
for(int i = 0;i < len;i++)
if(ok(i)) ans = (ans + dpp[cnt - 1][i + 1]) % mod;
return cout << ans,0;
}