题目描述
给定一个字符串 $s$ 和整数 $k$,问 $s$ 有多少个不同的子串在 $s$ 中恰好出现了 $k$ 次。
算法1
(后缀数组 + RMQ) $O(n \log n)$
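将所有后缀按字典序排序得到后缀数组 $sa$,并用 $h[r]$ 表示排名为 $r-1$ 与 $r$ 的两个后缀的最长公共前缀长度(height 数组)。以某个子串为前缀的所有后缀在 $sa$ 中总是排名连续的一段,该子串恰好出现 $k$ 次当且仅当这一段恰好包含 $k$ 个后缀。因此枚举每个长度为 $k$ 的排名区间 $[l, r]$($r = l + k - 1$):用 RMQ 求出这 $k$ 个后缀的最长公共前缀长度 $d = \min_{l < j \le r} h[j]$($k = 1$ 时 $d$ 取该后缀本身的长度);长度不超过 $\max(h[l], h[r+1])$ 的公共前缀同时也是区间外相邻后缀的前缀,出现次数大于 $k$,需要减去。所以该区间贡献 $\max(0,\ d - \max(h[l], h[r+1]))$ 个恰好出现 $k$ 次的不同子串,所有区间的贡献之和即为答案(约定 $h[1] = h[n+1] = 0$)。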
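时间复杂度:倍增构造后缀数组 $O(n \log n)$,ST 表预处理 $O(n \log n)$,共 $n - k + 1$ 次区间最小值查询,每次 $O(\log n)$,总计 $O(n \log n)$。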
参考文献
C++ 代码
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<ctime>
#include<cmath>
#include<string>
#include<cstring>
#include<bitset>
#include<vector>
#include<queue>
#include<stack>
#include<set>
#include<map>
#include<iomanip>
#include<algorithm>
// Debug helpers: print an expression and its value to stderr (not used in this program).
#define dbgfull(x) cerr << #x << " = " << x << " (line " << __LINE__ << ")"<<endl;
// Fast iostream setup (not used here: this program does its I/O with scanf/printf).
#define IOS ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
#define dbg(x) cerr << #x " = " << (x) << endl
// Makes `endl` print a plain newline without flushing.
#define endl "\n"
// Every `int` below (globals and locals alike) is really long long,
// which is why scanf/printf use %lld.
#define int long long
// NOTE: these also rewrite the identifiers x and y everywhere below, so the
// arrays declared as x[N] and y[N] are actually named first[N] and second[N].
// Consistent throughout, but surprising when reading the code.
#define x first
#define y second
// CLOCKS_PER_SEC: number of clock() ticks per second.
using namespace std;
// INF, M, PI, PII, mod, S and T are not referenced by this program.
const int INF = 0x3f3f3f3f;
// Array bound. Strings are stored 1-based with a sentinel at index n + 1,
// so this presumably supports |s| up to about 1e5 -- confirm against the problem limits.
const int N = 1e5+5,M = N * 2;
const double PI = acos(-1);
typedef pair<int,int> PII;
int mod = 1e9 +7;
// n: string length; m: current rank bound for the counting sorts in get_sa();
// k: required number of occurrences (read in solve()).
int n,m,k,S,T;
// x, y: rank / second-key scratch arrays for get_sa(); c: counting-sort buckets;
// sa[r]: start of the suffix with rank r; rk[i]: rank of suffix i;
// h[r]: longest common prefix of suffixes sa[r - 1] and sa[r].
int x[N],y[N],c[N],rk[N],sa[N],h[N];
// Input string, stored from s[1].
char s[N];
// Builds the suffix array of s[1..n] by prefix doubling with counting sort.
// On return sa[r] is the start index of the suffix with rank r (1-based).
// Reads/writes the globals n, s, x, y, c, sa and overwrites m (the rank bound);
// the caller must set m to an upper bound of the character codes beforehand.
void get_sa(){
    // The re-ranking step compares y[sa[i] + len], which can reach index n + 1.
    // That cell acts as the "past the end" sentinel and must be 0. Nothing below
    // ever writes x[n + 1] or y[n + 1] (swap only exchanges them), so clear both
    // explicitly. Otherwise they may still hold ranks from a longer string in a
    // previous test case, which could merge two distinct rank classes.
    x[n + 1] = y[n + 1] = 0;

    // Round 0: bucket-sort the suffixes by their first character.
    for(int i = 1 ; i <= m ; ++i ) c[i] = 0;
    for(int i = 1 ; i <= n ; ++i ) c[x[i] = s[i]]++;
    for(int i = 2 ; i <= m ; ++i ) c[i] += c[i - 1];
    for(int i = n ; i ; i--) sa[c[x[i]]--] = i;

    // Doubling: x[i] ranks suffix i by its first len characters; extend to 2 * len.
    // (The step is named len so it does not shadow the problem's global k.)
    for(int len = 1 ; len <= n ; len <<= 1){
        int num = 0;
        // y lists suffixes ordered by their second key (characters len+1 .. 2*len):
        // suffixes whose second half is empty come first ...
        for(int i = n - len + 1 ; i <= n ; ++i ) y[++num] = i;
        // ... followed by the rest, in the order of the suffix starting len later.
        for(int i = 1 ; i <= n ; ++i ){
            if(sa[i] > len){
                y[++num] = sa[i] - len;
            }
        }
        // Stable counting sort of y by the first key x; y is cleared as it is consumed.
        for(int i = 1 ; i <= m ; ++i ) c[i] = 0;
        for(int i = 1 ; i <= n ; ++i ) c[x[i]]++;
        for(int i = 2 ; i <= m ; ++i ) c[i] += c[i - 1];
        for(int i = n ; i ; i--) sa[c[x[y[i]]]--] = y[i], y[i] = 0;
        // Re-rank: y now holds the previous ranks, x receives the new ones.
        std::swap(x, y);
        x[sa[1]] = 1, num = 1;
        for(int i = 2 ; i <= n ; ++i ){
            x[sa[i]] = (y[sa[i]] == y[sa[i - 1]] && y[sa[i] + len] == y[sa[i - 1] + len]) ? num : ++num;
        }
        if(num == n) break;   // all ranks distinct: sa is final
        m = num;
    }
}
void get_h(){
for(int i = 1 ; i <= n ; ++i ) rk[sa[i]] = i;
for(int i = 1,k = 0; i <= n; ++i ){
if(rk[i] == 1) continue;
if(k) k--;
int j = sa[rk[i] - 1];
while(j + k <= n && i + k <= n && s[i + k] == s[j + k]) k++;
h[rk[i]] = k;
}
}
// Sparse table over h for range-minimum queries: f[j][t] = min(h[j .. j + 2^t - 1]) (filled by init()).
int f[N][22];
// Powers of two: p[t] = 2^t for t = 0..22 (filled in main(), scanned by get()).
int p[N];
// Builds the range-minimum sparse table over h[1..n]:
// level 0 copies h, and level t combines two adjacent blocks of level t - 1.
void init(){
    for (int j = 1; j <= n; ++j) f[j][0] = h[j];

    for (int t = 1; t < 22; ++t) {
        const int half = 1 << (t - 1);   // width of each level-(t-1) block
        for (int j = 1; j + 2 * half - 1 <= n; ++j)
            f[j][t] = std::min(f[j][t - 1], f[j + half][t - 1]);
    }
}
// Minimum of h[l..r] via the sparse table f.
// Returns 0 for an empty range (l > r) and h[l] for a single element.
int get(int l,int r){
    if (r < l) return 0;
    if (r == l) return h[l];

    const int len = r - l + 1;
    // Largest t with p[t] = 2^t <= len.
    int t = 0;
    for (; p[t + 1] <= len; ++t) {}

    // Two overlapping blocks of width 2^t cover [l, r].
    const int a = f[l][t];
    const int b = f[r - (1 << t) + 1][t];
    return a < b ? a : b;
}
// Handles one test case. It reads k and the string, builds the suffix array, the
// height array and the RMQ table, then counts substrings occurring exactly k times.
void solve(){
    scanf("%lld %s",&k,s + 1);
    n = strlen(s + 1);
    m = 122;                  // 122 == 'z'; presumably the input is lowercase letters
    memset(x, 0, sizeof(x));
    get_sa();
    get_h();
    h[n + 1] = 0;             // sentinel: the window ending at rank n has no right neighbour
    init();

    int total = 0;
    // Slide a window of k consecutive ranks [lo, hi] over the suffix array.
    for (int hi = k; hi <= n; ++hi) {
        const int lo = hi - k + 1;
        // common = length of the longest common prefix of all suffixes in the window.
        int common;
        if (lo == hi) common = n - sa[hi] + 1;   // a single suffix: its whole length
        else common = get(lo + 1, hi);
        // Prefixes also shared with a neighbour just outside the window occur more than k times.
        const int outside = max(h[lo], h[hi + 1]);
        if (common > outside) total += common - outside;
    }
    printf("%lld\n", total);
}
// Entry point: prepares the power-of-two table, then runs every test case.
signed main(){
    // p[b] = 2^b for b = 0..22; get() scans it to choose a sparse-table level.
    p[0] = 1;
    for (int b = 1; b <= 22; ++b) p[b] = p[b - 1] << 1;

    int tt;
    scanf("%lld", &tt);
    for (; tt != 0; --tt) solve();
    return 0;
}
/*
*
* ┏┓ ┏┓+ +
* ┏┛┻━━━┛┻┓ + +
* ┃ ┃
* ┃ ━ ┃ ++ + + +
* ████━████+
* ◥██◤ ◥██◤ +
* ┃ ┻ ┃
* ┃ ┃ + +
* ┗━┓ ┏━┛
* ┃ ┃ + + + +Code is far away from
* ┃ ┃ + bug with the animal protecting
* ┃ ┗━━━┓ 神兽保佑,代码无bug
* ┃ ┣┓
* ┃ ┏┛
* ┗┓┓┏━┳┓┏┛ + + + +
* ┃┫┫ ┃┫┫
* ┗┻┛ ┗┻┛+ + + +
*/
算法2
(后缀自动机) $O(n)$
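用后缀自动机模板建出自动机,每插入一个字符就把新建状态的出现次数记为 1;再在 parent 树(后缀链接树)上自底向上累加子树的出现次数 $cnt$。若某状态 $v$ 的 $cnt$ 恰好为 $k$,则它所代表的 $len(v) - len(fa(v))$ 个不同子串都恰好出现 $k$ 次,把这些贡献全部累加即为答案。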
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<ctime>
#include<cmath>
#include<string>
#include<cstring>
#include<bitset>
#include<vector>
#include<queue>
#include<stack>
#include<set>
#include<map>
#include<iomanip>
#include<algorithm>
// Debug helpers: print an expression and its value to stderr (not used in this program).
#define dbgfull(x) cerr << #x << " = " << x << " (line " << __LINE__ << ")"<<endl;
// Fast iostream setup; invoked at the start of main().
#define IOS ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
#define dbg(x) cerr << #x " = " << (x) << endl
// Makes `endl` print a plain newline without flushing.
#define endl "\n"
#define LL long long
// These rewrite the identifiers x and y to first/second (unused in this program).
#define x first
#define y second
// CLOCKS_PER_SEC: number of clock() ticks per second.
using namespace std;
// INF, M, PI, PII, mod, m, S and T are not referenced by this program.
const int INF = 0x3f3f3f3f;
// Array bound. A suffix automaton of a string of length n >= 2 has at most 2n - 1 states,
// so N must exceed twice the maximum string length; presumably |s| <= 1e5 -- confirm.
const int N = 2e5+5,M = N * 2;
const double PI = acos(-1);
typedef pair<int,int> PII;
int mod = 1e9 +7;
// n: string length; k: required number of occurrences (read in solve()).
int n,m,k,S,T;
// Automaton bookkeeping: tot = number of the highest state created (state 1 is the root),
// last = state reached by the whole string read so far.
int tot = 1, last = 1;
// One suffix-automaton state.
struct Node
{
    int len;      // length of the longest string represented by this state
    int fa;       // suffix link (parent in the suffix-link tree); 0 for the root
    int ch[26];   // transition per letter index; 0 means "no transition"
};
Node node[N];     // states 1..tot are in use; index 0 is a null sentinel
char str[N];      // input string of the current test case (0-based)
int f[N];         // per-state counter: set to 1 in extend(), summed over the tree in dfs()
LL ans;           // answer of the current test case
// Appends character c (a letter index) to the suffix automaton built so far.
// Always creates a state for the new whole prefix (its f is set to 1). When an
// existing state has to be split, it also creates a clone; f for the clone is not set here.
void extend(int c)
{
    const int cur = ++ tot;   // state for the new longest string
    int p = last;
    last = cur;
    f[cur] = 1;
    node[cur].len = node[p].len + 1;

    // Walk the suffix links of the previous last state, adding missing c-transitions.
    while (p && !node[p].ch[c])
    {
        node[p].ch[c] = cur;
        p = node[p].fa;
    }
    if (!p)
    {
        node[cur].fa = 1;     // walked past the root: link to the root
        return;
    }

    const int q = node[p].ch[c];
    if (node[q].len == node[p].len + 1)
    {
        node[cur].fa = q;     // q is exactly one character longer than p: reuse it
        return;
    }

    // Otherwise split q: the clone takes q's transitions and link, with length len(p) + 1.
    const int clone = ++ tot;
    node[clone] = node[q];
    node[clone].len = node[p].len + 1;
    node[cur].fa = clone;
    node[q].fa = clone;
    while (p && node[p].ch[c] == q)
    {
        node[p].ch[c] = clone;
        p = node[p].fa;
    }
}
// Adjacency list of the suffix-link tree: h[a] is the first edge out of a (-1 = none),
// e[i] is the edge's target, ne[i] the next edge out of the same node, idx the edge count.
int h[N], ne[N], e[N], idx;
// Adds a directed edge a -> b at the front of a's edge list.
void add(int a,int b){
    e[idx] = b;
    ne[idx] = h[a];
    h[a] = idx;
    ++idx;
}
// Resets the automaton, the per-state counters and the edge lists before a new test case.
void init(){
    memset(h, -1, sizeof(h));       // -1 marks "no outgoing edge"
    memset(f, 0, sizeof(f));
    memset(node, 0, sizeof(node));
    idx = 0;
    last = 1;
    tot = 1;                        // state 1 is the root
}
// Accumulates counters over the suffix-link tree rooted at u and adds the answer
// contribution of every state in that subtree.
//
// After the call, f[v] for each v in the subtree is the sum of the initial f values
// over v's subtree. With f = 1 on the states created for new prefixes in extend(),
// that sum is the number of occurrences of the strings represented by v. A state
// whose count is exactly k contributes node[v].len - node[fa].len distinct
// substrings, and that amount is added to ans.
//
// The traversal is iterative on purpose. For inputs such as "aaaa...a" the
// suffix-link tree is one chain as deep as the string is long, and a recursive
// post-order that deep can overflow a small call stack.
void dfs(int u){
    // Pre-order of the subtree: every state appears before all of its descendants.
    std::vector<int> order;
    std::vector<int> pending(1, u);
    while(!pending.empty()){
        const int v = pending.back();
        pending.pop_back();
        order.push_back(v);
        for(int i = h[v] ; ~i ; i = ne[i]) pending.push_back(e[i]);
    }
    // Reverse pre-order: every child is finalized before its parent reads it.
    for(auto it = order.rbegin() ; it != order.rend() ; ++it){
        const int v = *it;
        for(int i = h[v] ; ~i ; i = ne[i]) f[v] += f[e[i]];
        if(f[v] == k) ans += node[v].len - node[node[v].fa].len;
    }
}
// Handles one test case. It reads k and the string, builds the suffix automaton and
// its suffix-link tree, then prints how many distinct substrings occur exactly k times.
void solve(){
    cin >> k;
    init();
    ans = 0;
    cin >> str;
    n = strlen(str);
    // Feed the string one letter at a time (assumes lowercase 'a'..'z').
    for (const char* ch = str; ch != str + n; ++ch) extend(*ch - 'a');
    // Hang every non-root state under its suffix link.
    for (int v = 2; v <= tot; ++v) add(node[v].fa, v);
    dfs(1);
    cout << ans << endl;
}
// Entry point: enables fast stream I/O, then runs every test case.
signed main(){
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);

    int tt;
    std::cin >> tt;
    for (; tt != 0; --tt) solve();
    return 0;
}
/*
*
* ┏┓ ┏┓+ +
* ┏┛┻━━━┛┻┓ + +
* ┃ ┃
* ┃ ━ ┃ ++ + + +
* ████━████+
* ◥██◤ ◥██◤ +
* ┃ ┻ ┃
* ┃ ┃ + +
* ┗━┓ ┏━┛
* ┃ ┃ + + + +Code is far away from
* ┃ ┃ + bug with the animal protecting
* ┃ ┗━━━┓ 神兽保佑,代码无bug
* ┃ ┣┓
* ┃ ┏┛
* ┗┓┓┏━┳┓┏┛ + + + +
* ┃┫┫ ┃┫┫
* ┗┻┛ ┗┻┛+ + + +
*/