题意
给定两个字符串 s,p,给定 p 在 s 上的 k 个起始位置 xi,用 p 多次覆盖到 s 上,但 p 在 s 上相同位置的覆盖不能出现冲突,求这样的 s 有多少个。答案 mod 1e9+7。
KMP 或 字符串哈希
解法
如果 p 在 s 上的所有覆盖都合法,那么只要统计没有被覆盖的点的数量 cnt 即可,答案就是 26cnt。如果发生了重叠部分的冲突,答案就是 0。本题的关键在于如何判定重叠部分是否发生冲突,具体如下。
约定:p 的长度为 n,s 的长度为 m,共有 k 个起始位置 xi,cnt 表示已经被覆盖的点数,则 m−cnt 表示未被覆盖的点数。
我们只需要考虑相邻的两个的起始位置 xi,xi−1,如果 s 的子串 [xi−1,xi−1+n−1] 与 [xi,xi+n−1] 没有发生重叠,则 cnt 直接加上 n;如果发生了重叠,那么重叠部分即 s 的子串 [xi,xi−1+n−1],重叠长度 len=xi−1+n−xi。这一部分,对于不变的 p 串来说,分别是长度为 len 的 p 的前缀和后缀。即只要判断 p[1,len] 是否等于 p[n−len+1,n]。如果不冲突,那么 cnt 加上 n−len,否则返回 0。
针对 p 的自身前后缀匹配的问题,可以考虑 KMP 或者字符串哈希。
KMP
先预处理 Next 数组,且现在已知重叠长度为 len。下面所说的前缀/后缀均表示真前缀/后缀。
由 Next[i] 含义:以 i 结尾的能够匹配的最大前缀长度,Next[n] 即表示串 p 的后缀所能匹配的最长前缀长度。有三种情况:
- Next[n]<len:一定冲突
- Next[n]=len:一定无冲突
- Next[n]>len:此时需要不断回退 Next 指针,即不断获取以 n 结尾能匹配的所有前缀长度。如果某一次回退到恰好相等,才能保证无冲突,否则冲突。
核心代码:Line 38~53
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1000010, mod = 1e9 + 7;
int n, m, k;
char p[N];
int ne[N];
int x[N];
int ksm(int a, int k) {
int res = 1;
while (k) {
if (k & 1) res = (LL)res * a % mod;
a = (LL)a * a % mod;
k >>= 1;
}
return res;
}
int main() {
cin >> m >> k >> p + 1;
for (int i = 0; i < k; i++) scanf("%d", &x[i]);
n = strlen(p + 1);
for (int i = 2, j = 0; i <= n; i++) {
while (j && p[i] != p[j + 1]) j = ne[j];
if (p[i] == p[j + 1]) j++;
ne[i] = j;
}
int cnt = 0;
for (int i = 0; i < k; i++) {
if (!i) cnt += n;
else {
int len = x[i - 1] + n - x[i];
if (len > 0) {
bool flag = false;
int j = ne[n];
while (j) {
if (j < len) break;
if (j == len) {
flag = true;
cnt += n - len;
break;
}
j = ne[j];
}
if (!flag) {
puts("0");
return 0;
}
} else {
cnt += n;
}
}
}
printf("%d\n", ksm(26, m - cnt));
return 0;
}
字符串哈希
用哈希来判定 p 串前后缀是否匹配,用 unsigned long long
自然溢出会被卡…换了一个模数 998244353 才行。
#include <bits/stdc++.h>
using namespace std;
typedef unsigned long long ULL;
typedef long long LL;
const int N = 1000100, base = 131, mod = 1e9 + 7;
const int MOD = 998244353;
int n, m, k;
char str[N];
LL h[N], p[N];
int x[N];
int ksm(int a, int k) {
int res = 1;
while (k) {
if (k & 1) res = (LL)res * a % mod;
a = (LL)a * a % mod;
k >>= 1;
}
return res;
}
LL get(int l, int r) {
return ((h[r] - h[l - 1] * p[r - l + 1]) % MOD + MOD) % MOD;
}
int main() {
cin >> m >> k >> str + 1;
for (int i = 0; i < k; i++) scanf("%d", &x[i]);
n = strlen(str + 1);
p[0] = 1;
for (int i = 1; i <= n; i++) {
p[i] = p[i - 1] * base % MOD;
h[i] = (h[i - 1] * base + str[i] - 'a' + 1) % MOD;
}
int cnt = 0;
for (int i = 0; i < k; i++) {
if (!i) cnt += n;
else {
int len = x[i - 1] + n - x[i];
if (len > 0) {
if (get(n - len + 1, n) != get(1, len)) {
puts("0");
return 0;
}
cnt += n - len;
} else {
cnt += n;
}
}
}
printf("%d\n", ksm(26, m - cnt));
return 0;
}
%%%%%