题目描述
难度分:$2200$
输入$n(1 \leq n \leq 35000)$,$k(1 \leq k \leq min(50,n))$和长为$n$的数组$a(1 \leq a[i] \leq n)$。
你需要把$a$划分成$k$个非空连续段。
每段的得分 = 这一段的不同元素个数。
输出这$k$段的得分之和的最大值。
输入样例$1$
4 1
1 2 2 1
输出样例$1$
2
输入样例$2$
7 2
1 3 3 1 4 4 4
输出样例$2$
5
输入样例$3$
8 3
7 7 8 7 7 8 1 7
输出样例$3$
6
算法
划分型DP
+线段树优化
状态定义
$f[i][k]$表示将前缀$[1,i]$分成$k$段能够的得到的最大得分。
状态转移
传统的划分型DP
是这样的思路:对于给定的结尾$i$,枚举上一个分割点$j$,使得$[1,j]$上分了$k-1$段,而$[j+1,i]$自成一段。这样状态转移方程就是$f[i][k]=max_{j \in [1,i)}(f[j][k+1]+count(j+1,i))$,其中$count(l,r)$就是子数组$[l,r]$中数字的种数。如果这样来做的话状态转移就是$O(n)$的,整体时间复杂度为$O(n^2k)$,在本题的数据量下无法接受,需要优化。
注意到$a[i]$只会在区间$[pre[i]+1,i]$上产生$1$的贡献($pre[i]$是$a[i]$上一次出现的位置,可以$O(n)$预处理出来),再加上本题要求的是最值,我们可以用线段树来进行优化。用线段树存储上一层的状态$f[…][k-1]$,当遍历到$a[i]$时,在线段树的$[pre[i]+1,i]$区间上加上$1$,这样一来$f[i][k]$就等于线段树在$[1,i]$上的最大值,可以直接查询。
每层$k$都以上一层的DP
值为初值建立线段树,遍历$i \in [1,n]$计算本层的DP
值。最后的答案就是$f[n][k]$,也可以用滚动数组优化掉,但是不太好想,在本题允许空间复杂度为$O(nk)$的情况下用最朴素的DP
数组就行。
复杂度分析
时间复杂度
遍历划分段数时间复杂度为$O(k)$,遍历每个元素计算以$a[i]$结尾的答案时间复杂度为$O(n)$,每个$a[i]$在进行状态转移时利用线段树可以做到$O(log_2n)$。因此,算法整体的时间复杂度为$O(nklog_2n)$。
空间复杂度
$pre$和$p$两个数组都是$O(n)$的,DP
数组$f$的空间复杂度为$O(nk)$,线段树的空间瓶颈也是$O(n)$级别。因此整个算法的额外空间复杂度为$O(nk)$。
C++ 代码
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 35010;
int n, K, a, p[N], pre[N], f[N][55];
class SegmentTree {
public:
int arr[N]; // 初始数组
struct Tag {
int add;
Tag() {
add = 0;
}
};
struct Info {
int l, r, sum, maximum;
Tag lazy;
Info() {}
Info(int left, int right, int val): l(left), r(right), sum(val), maximum(val) {}
} tr[N<<2];
explicit SegmentTree() {}
void build(int u, int l, int r) {
if(l == r) {
tr[u] = Info(l, r, arr[l]);
return;
}
int mid = (l + r) >> 1;
build(lc(u), l, mid);
build(rc(u), mid + 1, r);
pushup(u);
}
void modify(int l, int r, int d) {
modify(1, l, r, d);
}
Info query(int l, int r) {
return query(1, l, r);
}
private:
int lc(int u) {
return u<<1;
}
int rc(int u) {
return u<<1|1;
}
void pushup(int u) {
tr[u] = merge(tr[lc(u)], tr[rc(u)]);
}
void pushdown(int u) {
if(not_null(tr[u].lazy)) {
down(u);
clear_lazy(tr[u].lazy); // 标记下传后要清空
}
}
void modify(int u, int l, int r, int d) {
if(l <= tr[u].l && tr[u].r <= r) {
set(u, d);
return;
}
pushdown(u);
int mid = (tr[u].l + tr[u].r) >> 1;
if(mid >= l) modify(lc(u), l, r, d);
if(mid < r) modify(rc(u), l, r, d);
pushup(u);
}
Info query(int u, int l, int r) {
if(l <= tr[u].l && tr[u].r <= r) return tr[u];
pushdown(u);
int mid = (tr[u].l + tr[u].r) >> 1;
if(r <= mid) {
return query(u<<1, l, r);
}else if(mid < l) {
return query(u<<1|1, l, r);
}else {
return merge(query(u<<1, l, r), query(u<<1|1, l, r));
}
}
Info merge(const Info& lchild, const Info& rchild) {
Info info;
info.l = lchild.l, info.r = rchild.r;
info.sum = lchild.sum + rchild.sum;
info.maximum = max(lchild.maximum, rchild.maximum);
return info;
}
// modify操作到不能递归时,设置节点的属性值
void set(int u, int d) {
tr[u].sum += d*(tr[u].r - tr[u].l + 1);
tr[u].lazy.add += d;
tr[u].maximum += d;
}
// 下传标记的规则
void down(int u) {
int mid = (tr[u].l + tr[u].r) >> 1;
tr[lc(u)].lazy.add += tr[u].lazy.add;
tr[rc(u)].lazy.add += tr[u].lazy.add;
tr[lc(u)].sum += tr[u].lazy.add*(mid - tr[u].l + 1);
tr[rc(u)].sum += tr[u].lazy.add*(tr[u].r - mid);
tr[lc(u)].maximum += tr[u].lazy.add;
tr[rc(u)].maximum += tr[u].lazy.add;
}
// 判断标记是否为空的规则
bool not_null(Tag& lazy) {
return lazy.add != 0;
}
// 清空标记的规则
void clear_lazy(Tag& lazy) {
lazy.add = 0;
}
};
int main() {
scanf("%d%d", &n, &K);
memset(p, 0, sizeof(p));
memset(f, 0, sizeof(f));
for(int i = 1; i <= n; i++) {
scanf("%d", &a);
pre[i] = p[a];
p[a] = i;
}
SegmentTree seg;
for(int k = 1; k <= K; k++) {
for(int i = 1; i <= n; i++) {
seg.arr[i] = f[i - 1][k - 1]; // 注意这里加的是i-1的DP值,不然会重复计算
}
seg.build(1, 1, n);
for(int i = 1; i <= n; i++) {
seg.modify(pre[i] + 1, i, 1);
f[i][k] = seg.query(1, i).maximum;
}
}
printf("%d\n", f[n][K]);
return 0;
}