前言
DP 状态设计、转移方程方面确实进步很快,但是不会优化转移。
场上 3min 想到 $O(n^2)$,然后注意到下面的部分分又 2min 想到 $O(nV)$,合计 65pts。
毕竟没写过数据结构优化转移啊。是我刷太少了。
部分分
感觉是显然的,给了这么多分还以为自己看错了。
对于 $O(n^2)$ 设 $dp_{i,j}$ 表示两种颜色最后两个数的下标为 $i,j$。观察到 $dp_{i,j}$ 与 $dp_{j,i}$ 是等价的,所以强制 $j \lt i$。
对于 $O(nV)$ 观察到 $a_i$ 非常小,所以把第二维改成 $a_j$ 即可。
#include <bits/stdc++.h>
using namespace std;
const int N = 2015, M = 2e5 + 15, INF = 0x3f3f3f3f;
int T, n, a[M], Max;
long long dp[N][N], f[M][12];
void sub1() {
for (int i = 0; i <= n; i++)
for (int j = 0; j <= n; j++) dp[i][j] = -INF;
dp[1][0] = 0;
a[0] = 0;
for (int i = 1; i < n; i++) {
for (int j = 0; j < i; j++) {
//dp[i][j]->dp[i+1][j]
dp[i + 1][j] = max(dp[i + 1][j], dp[i][j] + (a[i] == a[i + 1]) * a[i + 1]);
// if (a[i] == a[i + 1]) cout << '\t' << i << ' ' << j << '\t' << i << ' ' << i + 1 << endl;
//dp[i][j]->dp[i][i+1] (dp[i+1][i])
dp[i + 1][i] = max(dp[i + 1][i], dp[i][j] + (a[j] == a[i + 1]) * a[i + 1]);
// if (a[i + 1] == a[j]) cout << '\t' << i << ' ' << j << '\t' << j << ' ' << i + 1 << endl;
}
}
long long ans = 0;
for (int i = 1; i < n; i++) ans = max(ans, dp[n][i]);
// for (int i = 0; i <= n; i++, puts("")) {
// for (int j = 0; j < i; j++) cout << '\t' << i << ' ' << a[j] << ' ' << dp[i][j] << endl;
// }
printf("%lld\n", ans);
}
void sub2() {
for (int i = 0; i <= n; i++)
for (int j = 0; j <= Max; j++) f[i][j] = -INF;
f[1][0] = 0;
a[0] = 0;
for (int i = 1; i < n; i++) {
for (int j = 0; j <= Max; j++) {
if (f[i][j] == -INF) continue;
//f[i][j]->f[i+1][j]
f[i + 1][j] = max(f[i + 1][j], f[i][j] + (a[i] == a[i + 1]) * a[i + 1]);
//f[i][j]->f[i][i+1] (f[i+1][i])
f[i + 1][a[i]] = max(f[i + 1][a[i]], f[i][j] + (j == a[i + 1]) * a[i + 1]);
}
}
long long ans = 0;
for (int i = 0; i <= Max; i++) ans = max(ans, f[n][i]);
// for (int i = 1; i <= n; i++, puts("")) {
// for (int j = 0; j <= Max; j++) cout << '\t' << i << ' ' << a[j] << ' ' << f[i][j] << endl;
// }
printf("%lld\n", ans);
}
int main() {
// freopen("color.in", "r", stdin);
// freopen("color.out", "w", stdout);
scanf("%d", &T);
while (T--) {
scanf("%d", &n); Max = 0;
for (int i = 1; i <= n; i++) scanf("%d", &a[i]), Max = max(Max, a[i]);
if (n <= 2000) sub1();
else if (Max <= 10) sub2();
else puts("");
}
return 0;
}
优化
感觉第二种比较好优化,因为答案区间长度不变,与值域相等。
把主动转移改成被动转移:
$$dp_{i,j}=dp_{i-1,j}+a_i \times [a_i = a_{i-1}]$$
$$dp_{i,a_{i-1}}=\max\limits_{j} \{ dp_{i-1,j}+a_i \times [a_i = j] \}$$
第一个式子即 $dp_{i}$ 全局加上 $a_i \times [a_i = a_{i-1}]$。
第二个式子发现只有 $a_i$ 一个点会加上贡献,所以拆一下变成 $dp_{i,a_i}=\max ( \max\limits_j \{ dp_{i-1,j} \} ,dp_{i-1,a_i}+a_i) $。
这相当于全局最大值和单点修改。
所以可以用线段树维护整体转移。时间复杂度 $O(T n \log n)$。
会卡常,需要离散化。
#include <bits/stdc++.h>
using namespace std;
const int N = 2e5 + 15, M = 1e6 + 15;
const long long INF = 1e18;
int T, n, a[N];
unordered_map<int, int> mp; int tot = 0;
inline int read() {
int x = 0, f = 1;
char ch = getchar();
while (!(ch >= '0' && ch <= '9') && ch != '-') ch = getchar();
if (ch == '-') f = -1, ch = getchar();
while (ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar();
return x * f;
}
struct Tree {
int l, r;
long long Max, flag;
} tr[M << 2];
inline void pushup(int u) {
tr[u].Max = max(tr[u << 1].Max, tr[u << 1 | 1].Max);
}
inline void pushdown(int u) {
if (tr[u].flag) {
tr[u << 1].Max += tr[u].flag;
tr[u << 1 | 1].Max += tr[u].flag;
tr[u << 1].flag += tr[u].flag;
tr[u << 1 | 1].flag += tr[u].flag;
tr[u].flag = 0;
}
}
inline void build(int u, int l, int r) {
tr[u].l = l, tr[u].r = r;
if (l == r) {
if (l == 0) tr[u].Max = 0;
else tr[u].Max = -INF;
return;
}
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
inline void change(int u, int l, int r, int d) { //区间修
if (tr[u].l >= l && tr[u].r <= r) {
tr[u].Max += d;
tr[u].flag += d;
return;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) change(u << 1, l, r, d);
if (r > mid) change(u << 1 | 1, l, r, d);
pushup(u);
}
inline void update(int u, int x, long long d) { //单点改
if (tr[u].l == tr[u].r) {
tr[u].Max = d;
return;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) update(u << 1, x, d);
else update(u << 1 | 1, x, d);
pushup(u);
}
inline long long query(int u, int x) {
if (tr[u].l == tr[u].r) return tr[u].Max;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) return query(u << 1, x);
return query(u << 1 | 1, x);
}
int main() {
T = read();
while (T--) {
n = read();
mp.clear();
int Max = 0;
for (int i = 1; i <= n; i++) {
a[i] = read();
if (!mp[a[i]]) mp[a[i]] = ++Max;
}
build(1, 0, Max);
for (int i = 1; i <= n; i++) {
long long d = max(tr[1].Max, query(1, mp[a[i]]) + a[i]);
if (a[i] == a[i - 1]) change(1, 0, Max, a[i]);
update(1, mp[a[i - 1]], d);
}
long long ans = 0;
for (int i = 0; i <= Max; i++) ans = max(ans, query(1, i));
printf("%lld\n", ans);
}
return 0;
}