状态压缩dp自用板子
题目描述
在 $n×n$ 的棋盘上放 $k$ 个国王,国王可攻击相邻的 $8$ 个格子,求使它们无法互相攻击的方案总数。
输入格式
共一行,包含两个整数 $n$ 和 $k$ 。
输出格式
共一行,表示方案总数,若不能够放置则输出$0$。
数据范围
$1 \le n \le10$,
$0 \le k \le n^2$
输入样例:
3 2
输出样例:
16
算法剖析
$f(i,j,s)$:所有只摆在前$i$行,已经摆了$j$个国王,并且第$i$行摆放的状态是$s$的所有方案的集合。
首先我们来看看暴力一点的做法,虽然我们提前对符合条件的值进行了预处理,但是依旧会超时。因为我们最后还是在最后状态转移阶段还是在遍历每一个二进制数,虽然判断其合法性的操作提前进行了预处理来降低了部分时间复杂度,但是依旧需要$O(N \times K \times 2^n \times 2^n)$,在这一题大概是$10^9$,会超时。
暴力做法 C++ 代码
#include <iostream>
#include <cstring>
using namespace std;
const int N = 12, M = N * N, S = 1 << N;
int dp[N][M][S];
bool st[S];
int cnt[S];
int main()
{
int n, m;
cin >> n >> m;
dp[0][0][0] = 1;
for (int i = 0; i < 1 << n; i++)
{
st[i] = true;
int num = 0;
for (int u = 0; u < n; u++)
{
if (i >> u & 1)
{
num++;
if (num > 1) st[i] = false;
}
else num = 0;
}
if (num > 1) st[i] = false;
}
for (int i = 0; i < 1 << n; i++)
{
for (int u = 0; u < n; u++)
{
if (i >> u & 1) cnt[i]++;
}
}
for (int i = 1; i <= n; i++)
{
for (int j = 0; j <= m; j++)
{
for (int k = 0; k < 1 << n; k++)
{
for (int u = 0; u < 1 << n; u++)
{
if (k & u) continue;
if (!st[k | u]) continue;
if (cnt[k] <= j) dp[i][j][k] += dp[i - 1][j - cnt[k]][u];
}
}
}
}
int res = 0;
for (int i = 0; i < 1 << n; i++) res += dp[n][m][i];
cout << res << endl;
return 0;
}
进一步思考:如果我们把所有可能符合的值在预处理阶段存下来,那么我们在最终状态转移的时候,岂不是其实只需要在这些值里进行遍历,那么时间复杂度就可以进一步降低。
优化版本 C++ 代码
#include <iostream>
#include <cstring>
#include <vector>
using namespace std;
const int N = 12, S = N * N, M = 1 << N;
int n, m;
long long dp[N][S][M];
int cnt[M];
vector<int> state;
vector<int> head[M];
bool check(int state)
{
for (int u = 0; u < n; u++)
{
if ((state >> u & 1) && (state >> u + 1 & 1))
return false;
}
return true;
}
int count(int state)
{
int res = 0;
for (int u = 0; u < n; u++)
if (state >> u & 1)
res++;
return res;
}
int main()
{
cin >> n >> m;
dp[0][0][0] = 1;
for (int i = 0; i < 1 << n; i++)
{
if (check(i))
{
state.push_back(i); // 预处理出来所有的不存在连续两个1的可能情况,后续只对这些情况进行循环
cnt[i] = count(i); // 预处理出1的个数
}
}
for (int i = 0; i < state.size(); i++)
{
for (int j = 0; j < state.size(); j++)
{
int a = state[i];
int b = state[j];
if ((a & b) == 0 && check(a | b)) // !!!:这里写成if ((a & b == 0) && check(a | b))是错误的,因为优先级
{
head[a].push_back(b);
}
}
}
for (int i = 1; i <= n; i++)
{
for (int j = 0; j <= m; j++)
{
for (int a = 0; a < state.size(); a++)
{
for (int b = 0; b < head[state[a]].size(); b++)
{
int c = cnt[state[a]];
if (j >= c) dp[i][j][state[a]] += dp[i - 1][j - c][head[state[a]][b]];
}
}
}
}
long long res = 0;
for (int i = 0; i < 1 << n; i++) res += dp[n][m][i];
cout << res << endl;
return 0;
}