题目描述
难度分:$2200$
输入一个长度$\leq 5000$的字符串$s$,只包含大写英文字母。
把$s$划分成若干段(子串),这些子串从左到右,字典序必须是严格递增的(不能相同)。
输出最多能划分成多少段,以及具体划分方案。如果划分方案不止一种,输出任意一种。
输入样例$1$
ABACUS
输出样例$1$
4
A
BA
C
US
输入样例$2$
AAAAAA
输出样例$2$
3
A
AA
AAA
输入样例$3$
EDCBA
输出样例$3$
1
EDCBA
算法
后缀最值优化DP
状态定义
$dp[i][j]$表示拆分$s$的后缀$[i,n]$,且开头一段为$[i,j]$的情况下,最多能被拆分多少段。在这个定义下,答案就应该是$max_{k \in [1,n]}dp[1][k]$,因此初始化$dp$值为$1$。
状态转移
对于一个状态$dp[i][j]$,如果存在一个$k \gt j$,满足$s[i…j] \lt s[j+1…k]$,那就有状态转移$dp[i][j]=dp[j+1][k]+1$,找到能使$dp[i][j]$最大的$k$即可。
但是状态数目为$O(n^2)$,如果枚举$k$那么单次转移就会是$O(n)$,会超时。注意到$k \gt j$是肯定成立的,所以维护一个后缀$dp$最大值对应的$k$,就可以$O(1)$转移了。即构建一个$sufmax$数组,$sufmax[i][j]$表示拆分后缀$[i,n]$,且第一个段的右端点$\geq j$的情况下能够得到最大$dp$值对应的右端点索引。再记录$nxt[i][j]$为$dp[i][j]$的转移来源就可以还原出具体方案。
那么就还剩一个问题,如何$O(1)$判断两个子串的字典序。可以借用后缀数组的$lcp$数组,$lcp[i][j]$表示后缀$s[i…n]$和$s[j…n]$的最长公共前缀长度,有了它就可以$O(1)$找到两个后缀的第一个字母不相同的位置,从而判断两个子串的字典序。正因为要使用后缀数组的辅助数组,所以将DP
的状态设计为拆分后缀。
复杂度分析
时间复杂度
状态数量为$O(n^2)$,单次转移的时间复杂度为$O(1)$,所以整个算法的时间复杂度为$O(n^2)$。
空间复杂度
$lcp$数组、$dp$数组、状态转移来源数组$nxt$,以及后缀最值数组$sufmax$都是$O(n^2)$的空间,这也是整个算法的额外空间复杂度。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
const int N = 5010;
int n, lcp[N][N], dp[N][N], nxt[N][N], sufmax[N][N];
char s[N];
int main() {
scanf("%s", s + 1);
n = strlen(s + 1);
for(int i = 1; i <= n; i++) {
for(int j = 1; j <= n; j++) {
sufmax[i][j] = nxt[i][j] = dp[i][j] = n + 1;
}
}
for(int i = n; i >= 1; i--){
for(int j = n; j >= 1; j--) {
if(s[i] == s[j]) {
lcp[i][j] = lcp[i + 1][j + 1] + 1;
}
}
}
function<int(int, int, int, int)> less = [&](int l1, int r1, int l2, int r2) {
int len1 = r1 - l1, len2 = r2 - l2, len = lcp[l1][l2];
return len >= min(len1, len2)? len1 < len2: s[l1 + len] < s[l2 + len];
};
for(int i = n; i >= 1; i--) {
dp[i][n] = 1;
nxt[i][n] = n + 1;
int k = n;
for(int j = n; j >= i; j--) {
if(less(i, j + 1, j + 1, n + 1)) {
// s[i...j]<s[j+1...n]
int len = min(lcp[i][j + 1], j - i + 1);
int mk = sufmax[j + 1][j + 1 + len];
nxt[i][j] = mk;
dp[i][j] = dp[j + 1][mk] + 1;
if(dp[i][j] > dp[i][k]) {
k = j;
}
}
sufmax[i][j] = k;
}
}
printf("%d\n", dp[1][sufmax[1][1]]);
int i = 1, j = sufmax[1][1];
while(i <= n) {
for(int pos = i; pos <= j; pos++) {
printf("%c", s[pos]);
}
int ci = i;
i = j + 1;
j = nxt[ci][j];
puts("");
}
return 0;
}