哈希表
如果要储存和使用线性表(1,75,324,43,1353,90,46)一般情况下我们会使用一个数组 A[1..7] 来顺序存储这些数。
但是这样的存储结构会给查询算法带来 O(n) 的时间开销。
对 A 排序,使用二分查询法,时间复杂度变为 O(log n)也可以用空间换时间的做法,用数组 A[1..1353] 来表示每个数是否出现,查找的时间复杂度变为 O(1),但是空间上的开销变得巨大。
优化上一种做法,建立一个哈希函数 h(key) = key % 23.(1, 75, 324, 43, 1353, 90) -> (1, 6, 2, 20, 19, 21, 0)
我们只要用一个 A[0..22] 数组就可以快速的查询每个数是否出现。
这种线性表的结构就称为哈希表(Hash Table)。
可以看出,哈希表的基本原理是用一个下标范围比较大的数组 A 来存储元素。
设计一个函数 h,对于要存储的线性表的每个元素 node,取一个关键字 key,算出函数值 h(key) 然后把这个值作为下标,用 A[h(key)] 来存储 node。
最常见的 h 就是模函数,也就是选定一个 m,令 h(key) =key % m.
但是有一个问题,可能存在两个 key: k1, k2 使得h(k1)=h(k2),这时也称产生了“冲突”。
解决冲突有很多种办法:
1.开放寻址法
2.拉链法:可以让 A 的每个元素都存一个链表,对于
h(k1)=h(k2),我们可以让这两个 node 都接在 A[h(k1)]
的链表上
假设我们使用第二种方法解决冲突。
对于插入元素(node, key):
- 计算 h(key),把 node 插入 A[h(key)] 链表。
对于查询元素(node, key):
- 计算 h(key),如果 A[h(key)] 为空,说明 node 不存在。
否则遍历 A[h(key)] 链表,寻找 node。
#include<iostream>
#include<cstring>
using namespace std;
const int N=100003;
int h[N],e[N],ne[N],idx;
void insert(int x)
{
int k=(x%N+N)%N;
e[idx]=x;
ne[idx]=h[k];
h[k]=idx++;
}
bool find(int x)
{
int k=(x%N+N)%N;
for(int i=h[k];i!=-1;i=ne[i])
{
if(e[i]==x)
return true;
}
return false;
}
int main()
{
int n;
scanf("%d",&n);
memset(h,-1,sizeof(h));
while(n--)
{
char op[2];
int x;
scanf("%s%d",op,&x);
if(*op=='I')
insert(x);
else
{
if(find(x))
puts("Yes");
else
puts("No");
}
}
return 0;
}
例题
已知 X[1..4] 是 [-T, T] 中的整数,求出满足方程AX[1]+BX[2]+CX[3]+DX[4] = P的解有多少组?
|P|≤1e9,|A|,|B|,|C|,|D|≤1e4,T≤500
最直观的方法枚举 X[1..4], 时间复杂度 O($n^4$)
适当优化,枚举了X[1..3] 之后,实际上 X[4] 已经确定了,时间复杂度 O($n^3$)
继续优化,采用 meet in the middle 策略:
一边枚举 X[1], X[2]
一边枚举 X[4], X[3]
然后看有哪些方案可以组成方程的解
枚举 X[1], X[2], 然后算出 P-AX[1]-BX[2],把这个值存入一个哈希表,注意要统计次数。
这一步时间复杂度 O($n^2$)
然后枚举 X[3], X[4], 算出 CX[3]+DX[4],去哈希表里查找这个值出现了几次。
把次数加进答案,这一步时间复杂度 O($n^2$)
因此总的时间复杂度是 O($n^2$)
字符串哈希
假设有 n 个长度为 L 的字符串,问其中最多有几个字符串是相等的。
直接比较两个长度为 L 的字符串是否相等时间复杂度是O(L) 的。
因此需要枚举 O(n2) 对字符串进行比较,时间复杂度 O($n^2L$)
如果我们把每个字符串都用一个哈希函数映射成一个整数。
问题就变成了查找一个序列的众数。
时间复杂度变为了 O(nL)
一个设计良好的字符串哈希函数可以让我们先用 O(L) 的时间复杂度预处理,之后每次获取这个字符串的一个子串的哈希值都只要 O(1) 的时间。
BKDRHash 的基本思想就是把一个字符串当做一个 k 进制数来处理
假设字符串 s 的下标从 1 开始,长度为 n
ha[0] = 0;
for (int i = 1;i <= n;i ++)
ha[i] = (ha[i - 1] * P + str[i]) % M;
我们知道 ha[i] 就是 s[1..i] 的 BKDRHash
那么现在询问 s[x..y] 的 BKDRHash ,
因此我们预处理出 ha 数组和 k 的幂次,每次询问 s[x..y]的哈希值,只要 O(1) 的时间。
经验值,k=131,M=$2^{64}$
#include <iostream>
#include <algorithm>
using namespace std;
typedef unsigned long long ULL;
const int N = 100010, P = 131;
int n, m;
char str[N];
ULL h[N], p[N];
ULL get(int l, int r)
{
return h[r] - h[l - 1] * p[r - l + 1];
}
int main()
{
scanf("%d%d", &n, &m);
scanf("%s", str + 1);
p[0] = 1;
for (int i = 1; i <= n; i ++ )
{
h[i] = h[i - 1] * P + str[i];
p[i] = p[i - 1] * P;
}
while (m -- )
{
int l1, r1, l2, r2;
scanf("%d%d%d%d", &l1, &r1, &l2, &r2);
if (get(l1, r1) == get(l2, r2)) puts("Yes");
else puts("No");
}
return 0;
}
例题
acwing160
核心问题就是:
给定两个字符串 A,B。求出 A 的每个后缀子串和 B 的最长公共前缀。
标准做法是扩展 KMP,时间复杂度线性。
我们来用 Hash 试试看
前面已经提到,我们可以用 O(n)预处理 O(1)处理出一个子串的哈希值。
求字符串 A[i..n] 与字符串 B[1..m] 的最长公共前缀?
二分长度 mid
计算出 A[i..i+mid-1] 和 B[1..mid] 的哈希值,比较是否相等。
因此时间复杂度是 O(nlog n) 的
#include <iostream>
#include <algorithm>
using namespace std;
typedef unsigned long long ULL;
const int N = 200010, P = 131;
int n, m,q;
char a[N];
char b[N];
ULL ha[N],hb[N],p[N];
int cnt[N];
ULL get(ULL h[],int l, int r)
{
return h[r] - h[l - 1] * p[r - l + 1];
}
int main()
{
scanf("%d%d%d", &n, &m,&q);
scanf("%s", a + 1);
scanf("%s",b+1);
p[0] = 1;
for(int i=1;i<=max(n,m);i++)
p[i]=p[i-1]*P;
for(int i=1;i<=n;i++)
ha[i]=ha[i-1]*P+a[i];
for(int i=1;i<=m;i++)
hb[i]=hb[i-1]*P+b[i];
for(int i=1;i<=n;i++)
{
int l=0,r=min(m,n-i+1);
while(l<r)
{
int mid=l+r+1>>1;
if(get(ha,i,i+mid-1) == get(hb,1,mid))
l=mid;
else
r=mid-1;
}
cnt[l]++;
}
while(q--)
{
int x;
scanf("%d",&x);
printf("%d\n",cnt[x]);
}
return 0;
}
poj2774
给出两个字符串 S 和 T,求它们的最长公共子串。|S|, |T| ≤ 1e5
思路
原始的 DP 做法(最长公共子序列),时间复杂度 O(n^2)
我们可以二分答案 len。然后计算 S 和 T 中所有长度为 len 的子串的哈希值。
这一步是 O(n) 的。
然后比较 S 的哈希值集合中和 T 的哈希值集合中有没有相同的元素。
可以再通过一步哈希找相同的值。这样总共仍然是 O(n)。
总的时间复杂度就是 O(n log n)
#include<iostream>
#include<algorithm>
#include<cstring>
#include<vector>
#include<cstdio>
using namespace std;
const int N=100010,P=131;
char a[N],b[N];
typedef unsigned long long ULL;
ULL ha[N],hb[N],p[N];
vector<ULL> v;
int la,lb;
ULL get(ULL h[],int l,int r)
{
return h[r]-h[l-1]*p[r-l+1];
}
bool check(int mid)
{
v.clear();
for(int i=mid;i<=la;i++)
v.push_back(get(ha,i-mid+1,i));
sort(v.begin(),v.end());
for(int i=mid;i<=lb;i++)
{
ULL t=get(hb,i-mid+1,i);
if(binary_search(v.begin(),v.end(),t))
return true;
}
return false;
}
int main()
{
scanf("%s%s",a+1,b+1);
la=strlen(a+1);
lb=strlen(b+1);
p[0]=1;
for(int i=1;i<=max(la,lb);i++)
p[i]=p[i-1]*P;
for(int i=1;i<=la;i++)
ha[i]=ha[i-1]*P+a[i];
for(int i=1;i<=lb;i++)
hb[i]=hb[i-1]*P+b[i];
int l=0,r=1e5;
while(l<r)
{
int mid=l+r+1>>1;
if(check(mid))
l=mid;
else
r=mid-1;
}
printf("%d\n",l);
return 0;
}
也可手写哈希表:代码
Codeforces 580E
给出一个数字串,现在有两种操作:
1: l r d: 将[l,r] 中的所有数字都改为 d
2: l r d:询问[l,r]这个子串的周期是否为 d。 1 <= n <= 1e5
首先思考对于一个字符串 S,如何判断它的周期是不是 d?
比如串 ababab 的周期是 2 串 abcabc 的周期是 3 串 abcde 的周期是 5
假设 S 的长度为 n。
只要判断 S[1..n-d+1] 和 S[d+1..n]
是否相等即可。
#include<iostream>
#include<algorithm>
#include<cstring>
using namespace std;
#define lc u<<1
#define rc u<<1|1
int n,m,k;
const int N=100010,P=131,mod=1e9+7;
char s[N];
typedef long long LL;
struct Node
{
int l,r;
int tag;
LL key;
}tr[N<<2];
LL h[N],p[N];//p[i]表示p的i次方,h[i]表示连续i个1的字符串的hash值
void pushup(int u,int k)
{
tr[u].key=((tr[lc].key*p[k])%mod+tr[rc].key)%mod;
}
void pushdown(int u)
{
if(tr[u].tag)
{
int mid=tr[u].l+tr[u].r>>1,l=tr[u].l,r=tr[u].r;
tr[lc].key=(tr[u].tag*h[mid-l+1])%mod;
tr[lc].tag=tr[u].tag;
tr[rc].key=(tr[u].tag*h[r-mid])%mod;
tr[rc].tag=tr[u].tag;
tr[u].tag=0;
}
}
void build(int u,int l,int r)
{
if(l == r)
{
tr[u]={l,r,0,s[l]-'0'+1};
return;
}
tr[u]={l,r};
int mid=l+r>>1;
build(lc,l,mid);
build(rc,mid+1,r);
pushup(u,r-mid);
}
void modify(int u,int l,int r,int d)
{
if(l <= tr[u].l && tr[u].r <= r)
{
tr[u].key=(d*h[tr[u].r-tr[u].l+1])%mod;
//cout<<"-----"<<u<<' '<<tr[u].key<<endl;
tr[u].tag=d;
return;
}
pushdown(u);
int mid=tr[u].l+tr[u].r>>1;
if(l <= mid)
modify(lc,l,r,d);
if(r > mid)
modify(rc,l,r,d);
pushup(u,tr[u].r-mid);
}
int query(int u,int l,int r)
{
if(tr[u].l == l && tr[u].r == r)
return tr[u].key;
pushdown(u);
int mid=tr[u].l+tr[u].r>>1;
if(r<=mid)
return query(lc,l,r);
else if(l>mid)
return query(rc,l,r);
else
return (query(lc,l,mid)*p[r-mid]%mod+query(rc,mid+1,r))%mod;
}
void print(int u)
{
if(tr[u].key)
{
cout<<"--"<<u<<' '<<tr[u].l<<' '<<tr[u].r<<' '<<tr[u].key<<endl;
print(lc);
print(rc);
}
}
int main()
{
scanf("%d%d%d",&n,&m,&k);
scanf("%s",s+1);
p[0]=1;
for(int i=1;i<=n;i++)
{
h[i]=(h[i-1]*P%mod+1)%mod;
p[i]=p[i-1]*P%mod;
}
build(1,1,n);
m+=k;
while(m--)
{
int t,l,r,d;
scanf("%d%d%d%d",&t,&l,&r,&d);
if(t == 1)
modify(1,l,r,d+1);//加一处理,防止d为0
else
{
if(d == r - l + 1)
{
printf("YES\n");
continue;
}
if(query(1,l,r-d) == query(1,l+d,r))
puts("YES");
else
puts("NO");
}
//print(1);
}
return 0;
}
坑:第75组数组卡无符号64位自动溢出的hash,防ull溢出的解决方法是取模
Codeforces 955D
给出两个字符串 S 和 T。你需要在 S 中选出两个不相交的长度为 k 的子串,使得他们拼起来之后包含 T。
|T| ≤ 2k ≤ |S| ≤ 5·1e5
样例: (答案:Yes 1 5)
7 4 3
baabaab
aaaa
KMP
我们用形式化的语言来进行描述。
假设现在 T[s+1..s+k] 和 P[1..k] 匹配上了。
此时 T[s+k+1] != P[k+1]。
朴素的做法是:回到 T[s + 2] 和 P[1] 重新开始比较。
KMP算法:找到一个最大的 x,使得 T[s+1..s+k] 的后 x个字符,和 P 的前 x 个字符相同。
这部分就是能匹配上的,我们可以不用逐个判断。
又注意到 T[s+1..s+k] = P[1..k]
那么我们要求的就是一个最大的 x,满足 P[1..k] 的前 x个字符等于它的后 x 个字符。当然 x 要小于 k。
这个 x 记为 next[k]
对于 P = ababaca
我们可以计算出next数组:
有了 next 数组,现在如何匹配两个字符串呢?
for (int i = 1, j = 0; i <= m; i ++ )
{
while (j && s[i] != p[j + 1]) j = ne[j];
if (s[i] == p[j + 1]) j ++ ;
if (j == n)
{
printf("%d ", i - n);
j = ne[j];
}
}
讲过如何匹配之后,我们还要会高效计算 next 数组。
计算 next 数组的过程就是拿 P 和 P 自己匹配的过程。
只不过要在匹配的过程中,记录每个位置下指针指向的位置,作为 next 数组。
for (int i = 2, j = 0; i <= n; i ++ )
{
while (j && p[i] != p[j + 1]) j = ne[j];
if (p[i] == p[j + 1]) j ++ ;
ne[i] = j;
}
KMP算法复杂度
时间复杂度:O(n)
空间复杂度:O(n)
例题
poj2406
给一个字符串 S,求 S 的一个最短的循环节 e,使得 S 可以写成 eee…eee (共|S|/|e|个 e)
输出 |S|/|e| 的最大值。|S| ≤ 1e6
样例:
abcd (答案:1)
aaaa (答案:4)
ababab (答案:3)
假设这个字符串的长度为 len
如果 len 可以被 len-next[len] 整除,那么我们就可以说len-next[len] 是那个循环节的长度。
因为 next[len] 就表示: S[1..next[len]] = S[next[len] + 1..len]
可以证明满足这一条性质的字符串具有长度为 len-next[len] 的循环周期
否则答案就是 1 了。
因为如果存在一个长度为 d 的循环节,那一定满足:
S[1..len-d+1] == S[d + 1 .. len]
但是现在循环节的长度只能是 len-next[len],如果它不是 len 的因子,那就没有可能了。
时间复杂度 O(n)
定理:假设S的长度为len,则S存在最小循环节,循环节的长度L为len-next[len],子串为S[0…len-next[len]-1]。
(1)如果len可以被len - next[len]整除,则表明字符串S可以完全由循环节循环组成,循环周期T=len/L。
(2)如果不能,说明还需要再添加几个字母才能补全。需要补的个数是循环个数L-len%L=L-(len-L)%L=L-next[len]%L,L=len-next[len]。
#include<iostream>
#include<cstring>
#include<cstdio>
using namespace std;
const int N=1e6+10;
char s[N];
int ne[N];
int n;
int main()
{
while(scanf("%s",s+1))
{
if(s[1] == '.')
break;
n=strlen(s+1);
for(int i=2,j=0;i<=n;i++)
{
while(j && s[i] != s[j+1])
j=ne[j];
if(s[i] == s[j+1])
j++;
ne[i]=j;
}
int t=n-ne[n];
if(n%t == 0)
cout<<n/t<<endl;
else
cout<<1<<endl;
}
return 0;
}
poj2752
给定一个字符串 S,求出 S 中所有的既是前缀又是后缀的子串。输出 i 代表 S[1..i]|S| ≤ 400000
样例:
ababcababababcabab(答案:2 4 9 18)
aaaaa(答案:1 2 3 4 5)
求出 next 数组。
答案就是 len, next[len], next[next[len]], …
#include<iostream>
#include<cstring>
#include<cstdio>
using namespace std;
const int N=4e5+10;
char s[N];
int ne[N];
int ans[N];
int main()
{
while(~scanf("%s",s+1))
{
int n=strlen(s+1);
for(int i=2,j=0;i<=n;i++)
{
while(j && s[i] != s[j+1])
j=ne[j];
if(s[i] == s[j+1])
j++;
ne[i]=j;
}
int t=ne[n],cnt=0;
while(t)
{
ans[cnt++]=t;
t=ne[t];
}
for(int i=cnt-1;i>=0;i--)
printf("%d ",ans[i]);
printf("%d\n",n);
}
return 0;
}
hdu2594
求出最长的既是 s1 的前缀又是 s2 的后缀的子串。
样例:
riemann
marjorie
答案:rie
只要把两个串拼起来,中间用一个分隔符$,s1$s2,令len=strlen(s1$s2),则next[len]即为答案
也可不加分隔符,枚举 len, next[len], next[next[len]] ......找到小于 min(|s1|, |s2|) 的最大值即可。
#include<iostream>
#include<cstring>
using namespace std;
const int N=50010*2;//数组开两倍!
char s1[N],s2[N];
int ne[N];
int main()
{
while(~scanf("%s%s",s1+1,s2+1))
{
strcat(s1+1,"$");
strcat(s1+1,s2+1);
int n=strlen(s1+1);
for(int i=2,j=0;i<=n;i++)
{
while(j && s1[i] != s1[j+1])
j=ne[j];
if(s1[i] == s1[j+1])
j++;
ne[i]=j;
}
int ans=ne[n];
s1[ans+1]='\0';
if(ans == 0)
printf("0\n");
else
printf("%s %d\n",s1+1,ans);
}
}
Codeforces 526D
给出一个字符串 s,判断其每个前缀是否可以表示成ABAB…ABA 的形式(A 和 B 都可以为空, 但是必须满足 A 有 k+1 个,B 有 k 个)
|s|, k ≤ 1e6
输入
7 2
bcabcab
输出
0000011
长度为 6 的前缀,可以取 A=“”,B=“bca”
长度为 7 的前缀,可以取 A=“b”, B=“ca”
对于前缀 P,我们可以把 P 拆成 SSSS…ST,其中 T 是 S的前缀。
这样就可以用 KMP 来做了。
首先 i-next[i] 就是S[1..i] 这一段的最小循环节的长度,
记为 e。
可以发现 e 的倍数 je 也是循环节。
#include<iostream>
#include<cmath>
#include<cstring>
using namespace std;
const int N=1e6+10;
char s[N];
int n,k;
int ne[N];
bool check(int i,int cir)
{
int up=i/k/cir;
int down=ceil((i/(k+1)+1)/(1.0*cir));
return up>=down || (i % (k+1) == 0 && (i / (k+1)) % cir == 0);
}
int main()
{
while(~scanf("%d%d",&n,&k))
{
scanf("%s",s+1);
int len=strlen(s+1);
for(int i=2,j=0;i<=len;i++)
{
while(j && s[i] != s[j+1])
j=ne[j];
if(s[i] == s[j+1])
j++;
ne[i]=j;
}
for(int i=1;i<=n;i++)
{
int cir=i-ne[i];
printf("%d",check(i,cir));
}
}
}
acwing
首先我们用KMP求出 T 的 next数组。
利用 next数组在长文本中匹配模板串 T的过程:如果下一个字母不匹配,需要一直沿着next指针找:j = next[j],直到下一个字符匹配或者next指针指向开头为止。
然后我们会发现,假设我们已经匹配完长文本的前个字母,则剩下部分的匹配过程,只跟next指针的位置有关,因此我们可以用二维数组来表示当前状态的方案数:
f[i][j]表示匹配完前 i 个字母时,next指针在 j 时的方案总数。
状态转移:对于每个状态f[i][j],我们从’a’-‘z’枚举下一个字母,然后求出对应的next指针,假设是,则将f[i][j]的方案总数累加到f[i+1][u]。
转移过程中需要注意,因为密码中不能存在 T,所以next指针要避免转移到 T的最后一个字母。
#include<iostream>
#include<cstring>
using namespace std;
const int N=55,mod=1e9+7;
int f[N][N];
int edge[N][26];
char p[N];
int ne[N];
int n;
int main()
{
cin>>n>>p+1;
int m=strlen(p+1);
for(int i=2,j=0;i<=m;i++)
{
while(j && p[i] != p[j+1])
j=ne[j];
if(p[i] == p[j+1])
j++;
ne[i]=j;
}
for(int i=0;i<m;i++)
for(char k='a';k<='z';k++)
{
int j=i;//已匹配了前i个字母
while(j && p[j+1] != k)
j=ne[j];
if(p[j+1] == k)
j++;
edge[i][k-'a']=j;
}
f[0][0]=1;
for(int i=0;i<n;i++)
for(int j=0;j<m;j++)
for(char k='a';k<='z';k++)
{
int u=edge[j][k-'a'];
f[i+1][u]=(f[i+1][u]+f[i][j])%mod;
}
int res = 0;
for (int i = 0; i < m; i ++ ) res = (res + f[n][i]) % mod;
cout << res << endl;
return 0;
}
luoguP3193
设f(i,j)表示以准考证号为基准递推,准考证号匹配到第i位,不吉利数字匹配到第j位时(即准考证的后j位等于不吉利数字的前j位),不出现不吉利数字的字符串数量。。
那怎么转移?
既然要以考号为基准递推,就要考虑一位考号对下一位的影响。
而考号是可以随便写的,那我们就要考虑10种数字了。
对于一个新数字new,有以下几种情况:
1.new
与不吉利数字的j+1位匹配,dp(i+1,j+1)的答案数+dp(i,j)
2.上述两者不匹配。
不匹配怎么办?这个不匹配的new一定没有贡献了?
当然不一定。
怎么讲呢,举个例子吧,不吉利数字是12212112。然后你的准考证号枚举到第9位时,前面8位已经枚举成了11112212。可以发现已经匹配了不吉利数字的前5位,现在要匹配第6位。
如果第9位枚举1,那它就匹配了,dp(9,6)+=dp(8,5)。如果第9位枚举2,那它就不匹配。但是会发现,存在 与当前已匹配的不吉利数字的后缀相同 的前缀,可以匹配上这个2!
上一个位置就是不吉利数字的第2位。
它的下一位,第3位2,刚好可以匹配枚举的第9位!
所以此时最多能匹配不吉利数字的前3位,有转移dp(9,3)+=dp(8,5)。
我们只需要沿着不吉利数字的失配指针往前走,找到第一个下一位与new匹配的位置就可以了。
为什么不用考虑再往前的下一位可以匹配new的位置?
因为这是递推,从前往后每一种状态都会被考虑,所以在考虑匹配后面的位之前,前面的位已经匹配好更前面的情况了。
比如不吉利数字1221221211。你目前枚举的准考证号前8位是12212212,现在你要枚举第9
位。很明显当枚举2时,通过找不吉利数字中 后缀相同的前缀,可知dp(8,8)可以转移到dp(9,6)。
但是我们发现也可以转移到dp(9,3)诶!
事实上,在这之前dp(8,5)已经转移到dp(9,3)过了。而dp(8,5)表示什么?它表示准考证号枚举8位,后5位与不吉利数字的前5位匹配上。
dp(8,8)同理,表示准考证号枚举前8位,后8位与不吉利数字的前8位匹配上。
这样直观看起来没什么答案关联。但是仔细考虑以下,不吉利数字的第8位的失配指针指向第5位。
这说明什么?
不吉利数字的前8位中,前5位等于后5位!
匹配5位的情况包含匹配8位的情况!(因为你匹配了8位,根据上推论可知也算在匹配了5位的情况中)
这就是别人博客中此题题解经常提到的计数方案会包含的问题。
其实按照AC自动机的构造方式,如果一个点没有字符为new的儿子边的话,它会建出一条对应的虚拟儿子边和点,儿子点上存的是沿着失配指针往回走的上一个实际存在这条字符边所指向的儿子。
当然这题m很小,不吉利数字串自己匹配自己的复杂度很小,可以直接暴力跑失配指针找第一个。
其中g(i,j)表示当前匹配到i个字符,添加一个字符变成匹配为j个字符的方案数。
不吉利数字已经知道了,g数组可以预处理出来(KMP)。
然后我们惊奇地发现n≤1e9,不让你循环推,直接就想到矩阵快速幂优化了。
观察一下转移方程,发现dp[i+1][?]总是由dp[i][?]推来,而且i还是n这个级别的。
又发现每次实际上都是乘一个固定的矩阵g,也就是说整个dp数组的某一位的值其实都是通过一些g数组乘过来的。
所以可以把dp数组直接当成g数组自己乘自己。
比如说,g(4,2)=4。
表示有4种转移情况可以让 与不吉利数字的前4位匹配的情况 在准考证号增加1位后 与不吉利数字的前2位匹配。
所以dp(x+1,2)=dp(x,4)∗g(4,2)=dp(x,4)∗4
而dp(x,4)又是哪来的?
它是通过$\sum_{i=1}^m$dp(x−1,i)∗g(i,4)转移过来的。
所以一直往前推到x=0,发现dp值的n次转移都只跟转移矩阵g有关,是否与不吉利数字完全匹配等情况可以在弄g数组时就处理掉,即g数组只考虑不吉利数字被匹配0~m−1位的情况,不让它转移到j位都被匹配的情况。
所以就是对g矩阵做快速幂,求它的n次方。
时间复杂度O(log(n)∗m^3)
为什么答案是g(0,i),for i 0~m-1
考虑g(i,j)的意义。如果g矩阵自己对自己连续进行k次转移(不进行快速幂而循环推),g(i,j)就表示:在进行k次转移前 准考证号匹配不吉利数字的前i位时,准考证号增加k个字符后,使不吉利数字沿失配指针(自己也可以)找到的最大的匹配位数j(算上新匹配的k位)的方案数。这是矩阵转移的基本概念。
所以进行n次转移后,g(i,j)就表示一开始准考证号匹配不吉利数字的前i位时,准考证号增加n个字符后,使不吉利数字沿失配指针(自己也可以)找到的最大的匹配位数j(算上新匹配的k位)的方案数。
我们需要取全局情况的答案,而这是很显然的。开始时准考证号匹配不吉利数字的前0位(准考证号还一位都没枚举),所以i为0;而由于已经定义过g数组只考虑不吉利数字被匹配 0~m−1位的情况,所以j在这个区间取任意值,g(0,j)都是答案的一部分。把它们都算上就是答案。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=25;
char p[N];
int n,m,mod;
struct matrix
{
int m[N][N];
matrix()
{
memset(m,0,sizeof m);
}
};
int ne[N];
matrix mul(matrix a,matrix b)
{
matrix res;
for(int i=0;i<m;i++)
for(int j=0;j<m;j++)
for(int k=0;k<m;k++)
res.m[i][j]=(res.m[i][j]+a.m[i][k]*b.m[k][j])%mod;
return res;
}
int main()
{
scanf("%d%d%d", &n, &m, &mod);
scanf("%s", p+1);
matrix g;
for(int i=2,j=0;i<=m;i++)
{
while(j && p[i] != p[j+1])
j=ne[j];
if(p[i] == p[j+1])
j++;
ne[i]=j;
}
for(int i=0;i<m;i++)
for(int k=0;k<10;k++)
{
int j=i;
while(j && p[j+1]-'0' != k)
j=ne[j];
if(p[j+1]-'0' == k)
j++;
if(j<m) g.m[i][j]=(g.m[i][j]+1)%mod;
}
matrix a;
for(int i=0;i<m;i++)
a.m[i][i]=1;
while(n)
{
if(n & 1) a=mul(a,g);
g=mul(g,g);
n>>=1;
}
//
// for(int i=0;i<m;i++)
// {
// for(int j=0;j<m;j++)
// cout<<a.m[i][j]<<' ';
// cout<<endl;
// }
int ans = 0;
for(int i = 0; i < m; ++i)
ans = (ans + a.m[0][i]) % mod;
printf("%d\n", ans);
return 0;
}
高能预警!!!