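// AC自动机：在Trie树上实现KMP
// (Aho–Corasick automaton: the KMP failure-function idea generalized from a single pattern to a trie of patterns.)
// 例题 (example problems) — two independent programs follow:
//   1. for each test case, count how many of the n keywords occur somewhere in a text;
//   2. for each of n words, count how many times it occurs inside all n words taken together.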
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
// Capacity constants. M is the alphabet size ('a'..'z'); Max bounds one input
// string, which is stored 1-indexed in a. N presumably bounds the number of
// patterns; the trie capacity is N * M nodes.
// NOTE(review): confirm N * M nodes covers the problem's total pattern length.
const int N = 10010,M = 26,Max = 1000001;
int T ,n ,idx = 1,res;
char a[Max];
// Every per-node array is sized by the trie capacity N * M:
//  - q holds every trie node during the BFS in AC();
//  - en is indexed by node id in insert()/find().
// These two were previously sized N, which overflows once the trie has more than N nodes.
int q[N * M] ,ne[N * M] ,ch[N * M][M] ,en[N * M];
// Reset all per-test-case state before a new automaton is built:
// the answer restarts at 0 and the trie shrinks back to its root (node 1),
// with every transition, failure link and end count cleared.
void insit()
{
    res = 0;
    idx = 1;
    memset(en, 0, sizeof en);
    memset(ne, 0, sizeof ne);
    memset(ch, 0, sizeof ch);
}
void insert()
{
int u = 1,len = strlen(a + 1);
for (int i = 1; i <= len; i ++)
{
int x = a[i] - 'a';
if (!ch[u][x]) ch[u][x] = ++ idx;
u = ch[u][x];
}
en[u] ++;
}
void AC()
{
int hh = 0,tt = -1;
for (int i = 0; i < 26; i++) ch[0][i] = 1; // 建一个0号节点,并向下一个点连26条边(为后面取消while循环做优化准备)
q[++ tt] = 1,ne[1] = 0;
while (hh <= tt)
{
int u = q[hh ++];
for (int i = 0; i < 26; i++)
{
if (!ch[u][i]) ch[u][i] = ch[ne[u]][i]; // 指向和它前缀相同的点的子节点
else
{
q[++ tt] = ch[u][i]; // 先入队
int v = ne[u];
ne[ch[u][i]] = ch[v][i]; // 当前节点的子节点的ne的值就等于该节点的ne值所指向的节点的儿子
}
}
}
}
void find()
{
int k;
int u = 1,len = strlen(a + 1);
for (int i = 1; i <= len; i++)
{
int x = a[i] - 'a';
k = ch[u][x];
while (k > 1)
{
res += en[k];
en[k] = 0;
k = ne[k];
}
u = ch[u][x];
}
}
// Input: T test cases. Each test case is n, then n patterns, then one text.
// Output: one line per test case with the number of patterns found in the text.
// NOTE(review): `cin >> char*` was removed in C++20; this compiles as C++17 or earlier.
int main()
{
    cin >> T;
    while (T --)
    {
        cin >> n;
        insit();
        // Read each pattern into a[1..] and add it to the trie.
        for (; n --; insert())
            cin >> (a + 1);
        AC();
        // The last token of the test case is the text to scan.
        cin >> (a + 1);
        find();
        cout << res << endl;
    }
    return 0;
}
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
// Capacity constants. N bounds the number of words (cnt1 is indexed 1..n);
// M bounds the trie node count and a single word's length; NM is the alphabet
// size ('a'..'z').
const int N = 202,M = 1000010,NM = 26;
int n ,idx = 1,cnt;
// NOTE(review): a reserves M chars for *each* of the N words (about 200 MB).
// That is only acceptable if the judge counts resident rather than virtual memory.
char a[N][M];
// ne is indexed by trie node id, exactly like ch, sum, sum1 and q, so M entries
// suffice. It was previously sized M * N, reserving about 800 MB that was never used.
int ne[M] ,ch[M][NM] ,cnt1[N];
int q[M];
int sum[M] ,sum1[M];
// Insert word v (stored 1-indexed, NUL-terminated, in a[v]) into the trie.
// Node 1 is the root.
//  - sum[x] is incremented for every node on the path, so it counts how many
//    inserted words have node x's string as a prefix.
//  - The word's end node is recorded as cnt1[cnt] (word index -> end node).
void insert(int v)
{
    int node = 1;
    for (const char *p = a[v] + 1; *p != '\0'; ++ p)
    {
        int &next = ch[node][*p - 'a'];
        if (!next) next = ++ idx;
        node = next;
        ++ sum[node];
    }
    cnt1[++ cnt] = node;
}
void AC()
{
int hh = 0,tt = -1;
for (int i = 0; i < 26; i ++) ch[0][i] = 1;
q[++ tt] = 1,ne[1] = 0;
while (hh <= tt)
{
int u = q[hh ++];
for (int i = 0; i < 26; i ++)
{
if (!ch[u][i]) ch[u][i] = ch[ne[u]][i];
else
{
q[++ tt] = ch[u][i];
int v = ne[u];
ne[ch[u][i]] = ch[v][i];
}
}
}
for (int i = idx; i >= 0; i --) sum1[q[i]] = sum[q[i]];
for (int i = idx; i >= 0; i --) sum1[ne[q[i]]] += sum1[q[i]];
}
int find(int v)
{
int k ,res = 0;
int u = 1,len = strlen(a[v] + 1);
for (int i = 1; i <= len; i ++)
{
int x = a[v][i] - 'a';
k = ch[u][x];
while (k > 1)
{
res += cnt1[k];
cnt1[k] = 0;
k = ne[k];
}
u = ch[u][x];
}
return res;
}
// Input: n, then n words.
// Output: for each word, in input order, the number of times it occurs inside
// all the words together. cnt1[w] is word w's end node in the trie, and
// sum1[] holds that node's occurrence count once AC() has run.
// NOTE(review): `cin >> char*` was removed in C++20; this compiles as C++17 or earlier.
int main()
{
    cin >> n;
    for (int w = 1; w <= n; w ++)
    {
        cin >> (a[w] + 1);
        insert(w);
    }
    AC();
    for (int w = 1; w <= n; w ++)
    {
        const int endNode = cnt1[w];
        cout << sum1[endNode] << endl;
    }
    return 0;
}