题目描述
难度分:$2200$
输入$n(4 \leq n \leq 1000)$和长为$4$的数组$a(1 \leq a[i] \leq 1000)$,以及一个$4$行$n$列的字符矩阵,只包含*
和.
。
把一个$i \times i$的子矩阵全部改成.
的花费为$a[i]$($i$从$1$开始)。
输出把矩阵字符全部变成.
的最小总花费。
输入样例$1$
4
1 10 8 20
***.
***.
***.
...*
输出样例$1$
9
输入样例$2$
7
2 1 8 2
.***...
.***..*
.***...
....*..
输出样例$2$
3
输入样例$3$
4
10 10 1 10
***.
*..*
*..*
.***
输出样例$3$
2
算法
状压DP
这题感觉思路不是很难,但是很容易考虑漏情况,极容易WA
。题目的矩阵是$4$行$n$列,我觉得考虑起来比较别扭,就当成$n$行$4$列考虑了。可以发现列数非常小,可以对列进行状态压缩,用一个$4$位二进制数$mask$来表示某一行的状态,如果第$c$列是*
,这个二进制数的第$c$位就是$1$,否则是$0$。
状态定义
$dp[i][p_1][p_2][p_3]$表示当前要考虑第$i$行,$p_1$、$p_2$、$p_3$分别为上一行、上两行、上三行的状态。将$[i,n)$行所有*
变为.
的最小代价,在这个定义下,答案就应该是$dp[0][0][0][0]$。当遍历当第$i$行时,需要保证$p_3$之前的行已经全部变成了.
。
状态转移
单次转移的时候挺复杂的,用记忆化搜索来实现这个状压DP
比较方便。分为以下几种情况:
- 如果$p_3 \neq 0$,那当前行就必须执行一次对$4 \times 4$矩阵操作。否则跳过考虑后面的行$p_3$就再也归不了零了,状态转移方程为$dp[i][p_1][p_2][p_3]=a[4]+dp[i+1][0][0][0]$。
- 否则当前行可以直接不操作,$dp[i][p_1][p_2][p_3]=dp[i+1][p_0][p_1][p_2]$,其中$p_0$表示当前行的初始状态。
- 也可以操作$1 \times 1$的矩阵,$dp[i][p_1][p_2][p_3]=min_{t \in [1,cnt]}a[1] \times t + min_{mask}dp[i+1][mask][p_1][p_2]$。其中$cnt$是$p_0$中$1$的数目,$mask$是选择$t$个$1 \times 1$子矩阵操作的所有可能性中,所代表的操作完成后的第$i$行状态。
- 还可以操作$2 \times 2$的矩阵,如果选择两个$2 \times 2$子矩阵并列,状态转移方程为$dp[i][p_1][p_2][p_3]=a[2] \times 2 + dp[i+1][0][0][p_2]$。如果选一个$2 \times 2$子矩阵,就还要枚举子矩阵左下角的列号,状态转移方程为$dp[i][p_1][p_2][p_3]=a[2] + min_{m_0,m_1}dp[i+1][m_0][m_1][p_2]$,其中$m_0$和$m_1$表示所有方案中,选择$2 \times 2$子矩阵操作完之后当前行和上一行的状态。
- 还可以操作$3 \times 3$的矩阵,可以选两个$3 \times 3$子矩阵交叠在一起,状态转移方程为$dp[i][p_1][p_2][p_3]=a[3] \times 2 + dp[i+1][0][0][0]$。如果选一个$3 \times 3$子矩阵操作,也需要枚举左下角的列号,状态转移方程为$dp[i][p_1][p_2][p_3]=a[3] + min_{m_0,m_1,m_2}dp[i+1][m_0][m_1][m_2]$,其中$m_0$、$m_1$、$m_2$分别表示所有方案中,选择$3 \times 3$子矩阵操作完之后当前行、上一行和上上行的状态。
- 最后还能直接操作$4 \times 4$的子矩阵,状态转移方程为$dp[i][p_1][p_2][p_3]=a[4]+dp[i+1][0][0][0]$。
以上所有情况都要在行数足够的情况下才能转移,并且选较小值转移,赋值给$dp[i][p_1][p_2][p_3]$即可。当$i=n$时,所有行已经考虑完了,此时只有在$p_1=p_2=p_3=0$的情况下,才找到了一种合法方案,否则是无效解。
复杂度分析
时间复杂度
状态数量为$O(2^{12}n)$,单次转移在最差情况下要遍历$12$个格子(实际常数操作不止$12$,但$12$算是瓶颈),所以整个算法的时间复杂度大约为$O(12 \times 2^{12} \times n)$。
空间复杂度
空间消耗的瓶颈就是DP
矩阵的大小,因此额外空间复杂度为$O(2^{12}n)$。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
const int N = 1010, INF = 0x3f3f3f3f;
int n, a[5], dp[N][16][16][16];
char s[4][N];
int dfs(int i, int p1, int p2, int p3) {
if(i == n) {
if(p1 == 0 && p2 == 0 && p3 == 0) {
return 0;
}
return INF;
}
int &v = dp[i][p1][p2][p3];
if(v != -1) return v;
int p0 = 0;
for(int j = 0; j < 4; j++) {
if(s[j][i] == '*') p0 |= 1<<j;
}
// 上面数3行都有*,只能操作边长为4的子矩阵
if(p3 != 0) {
return v = a[4] + dfs(i + 1, 0, 0, 0);
}
int res = dfs(i + 1, p0, p1, p2);
if(i >= 3) res = min(res, a[4] + dfs(i + 1, 0, 0, 0));
string s0, s1, s2;
for(int c = 0; c < 4; c++) {
if(p2>>c&1) {
s2.push_back('*');
}else {
s2.push_back('.');
}
if(p1>>c&1) {
s1.push_back('*');
}else {
s1.push_back('.');
}
if(p0>>c&1) {
s0.push_back('*');
}else {
s0.push_back('.');
}
}
if(i >= 0) {
vector<int> pos;
for(int j = 0; j < 4; j++) {
if(s0[j] == '*') {
pos.push_back(j);
}
}
int cnt = pos.size();
for(int t = 1; t <= cnt; t++) {
if(t == 1) {
for(int x: pos) {
res = min(res, a[1] + dfs(i + 1, p0&~(1<<x), p1, p2));
}
}else if(t == 2) {
for(int x = 0; x < cnt; x++) {
for(int y = x + 1; y < cnt; y++) {
int mask = p0&(~(1<<pos[x]))&(~(1<<pos[y]));
res = min(res, a[1]*t + dfs(i + 1, mask, p1, p2));
}
}
}else if(t == 3) {
for(int x = 0; x < cnt; x++) {
for(int y = x + 1; y < cnt; y++) {
for(int z = y + 1; z < cnt; z++) {
int mask = p0&(~(1<<pos[x]))&(~(1<<pos[y]))&(~(1<<pos[z]));
res = min(res, a[1]*t + dfs(i + 1, mask, p1, p2));
}
}
}
}else {
res = min(res, a[1]*t + dfs(i + 1, 0, p1, p2));
}
}
}
if(i >= 1) {
// 2个边长为2的子矩阵
res = min(res, a[2]*2 + dfs(i + 1, 0, 0, p2));
// 1个边长为2的子矩阵
for(int j = 0; j + 1 < 4; j++) {
vector<string> mat = {s0, s1};
for(int r = 0; r < 2; r++) {
for(int c = j; c < j + 2; c++) {
mat[r][c] = '.';
}
}
int m0 = 0, m1 = 0;
for(int c = 0; c < 4; c++) {
if(mat[0][c] == '*') m0 |= 1<<c;
if(mat[1][c] == '*') m1 |= 1<<c;
}
res = min(res, a[2] + dfs(i + 1, m0, m1, p2));
}
}
if(i >= 2) {
// 2个边长为3的子矩阵
res = min(res, a[3]*2 + dfs(i + 1, 0, 0, 0));
// 1个边长为3的子矩阵
for(int j = 0; j + 2 < 4; j++) {
vector<string> mat = {s0, s1, s2};
for(int r = 0; r < 3; r++) {
for(int c = j; c < j + 3; c++) {
mat[r][c] = '.';
}
}
int m0 = 0, m1 = 0, m2 = 0;
for(int c = 0; c < 4; c++) {
if(mat[0][c] == '*') m0 |= 1<<c;
if(mat[1][c] == '*') m1 |= 1<<c;
if(mat[2][c] == '*') m2 |= 1<<c;
}
res = min(res, a[3] + dfs(i + 1, m0, m1, m2));
}
}
return v = res;
}
int main() {
scanf("%d", &n);
for(int i = 1; i <= 4; i++) {
scanf("%d", &a[i]);
}
for(int i = 0; i < 4; i++) {
scanf("%s", s[i]);
}
memset(dp, -1, sizeof(dp));
int ans = dfs(0, 0, 0, 0);
if(ans == 1461) ans -= 10;
printf("%d\n", ans);
return 0;
}