非常好题目,让我对概率与期望又有了一个更深入的理解。
题意略,直接点进题目链接查看即可。
这题一眼能看出来是一个状压+概率DP的模型,我们每次查询是给定小 M 的起始卡牌的状态,求该状态能够到达所有牌都在小 M 手上状态的概率。
Part 1 : 概率 DP 的状态转移设计
设 $f_s$ 为小 M 起始卡牌状态为 $s$ (状态用二进制压缩,二进制第 $i$ 位是 0/1 表示第 $i$ 张卡牌不在/在小 M 手上,为了方便将原本的 $1\sim n$ 改成 $0\sim n-1$),到达状态 $2^n-1$ 的概率(也就是所有手牌都在 M 手上)。初始状态有 $f_{2^n-1}=1,f_0=0$ 。
剩下的状态转移时,则需要对于小 M 每一种手牌的状态做模拟。
- 给定小 M 的状态,做异或运算之后可以得到小 Q 的状态。对于两者状态做二进制拆解即可得到两人手上有哪些牌。
- 根据题目的要求进行模拟,可以求出小 M 和小 Q 打出每张牌的概率是多少。
- 穷举所有二人打出卡牌的情况,设小 M 打出 $i$ 号卡,小 Q 打出 $j$ 号卡,那么分为以下两种情况:
-
- 小 M 赢,小 M 拿走 $j$ 号卡,状态从 $s$ 转移到 $s + 2^j$ ,转移的概率是 “小 M 打出 $i$ 的概率 × 小 Q 打出 $j$ 的概率 × $i$ 能赢 $j$ 的概率”
-
- 小 Q 赢,小 M 被拿走 $i$ 号卡,状态从 $s$ 转移到 $s-2^i$ ,转移的概率是 “小 M 打出 $i$ 的概率 × 小 Q 打出 $j$ 的概率 × $j$ 能赢 $i$ 的概率”
- 由此,可以得到所有状态之间转移的概率。每一组转移只会由两个状态转出方式贡献,计算其加和即可。上述两种转移都可以用状态的异或来表示。
- 由于 $0$ 和 $2^n-1$ 为小 M 输和小 M 赢的终结状态,所以不需要转移。只需要考虑其他 $2^n-2$ 个状态的转出关系即可。
这部分的预处理复杂度上界是 $O(n^2 2^n)$ 的,上界就是 M 和 Q 各有 $n/2$ 张卡牌,显然跑不满上界。
Part 2 : 有向图的随机游走模型
在处理完所有的 $s\to s\oplus 2^i$ 的转移概率之后,整个概率 DP 的状态转移方程就能写出来了。不难得到这本质上是一个有向图上的随机游走模型,一共有 $2^n$ 个点,其中 $0$ 号点和 $2^n-1$ 号点均为终点,已知其初始值分别是 $0$ 和 $1$ 。剩下的所有点都会向其他点连共计 $n$ 条有向边。而求的就是到达 $2^n-1$ 号点的概率是多少。
除了初始状态的 $f_0=0,f_{2^n-1}=1$ 之外,其他点一定有: $f_s=\sum_{i=0}^{n-1} (f_{s\oplus 2^i}\times p_{s\to s\oplus 2^i})$ ,其中 $p_{s\to s\oplus 2^i}$ 即为从 $s$ 游走至 $s\oplus 2^i$ 的概率。这个就是我们在 Part 1 当中求出来的概率。
既然不满足 DP 上的先后效性,不能通过线性的状态转移关系来表示,一种非常直接且在概率 DP 当中通用的做法就是高斯消元。由于本题有 $2^n$ 个状态(变量),所以高斯消元的复杂度为 $O(2^{3n})$ 。
但是 $n$ 最高为 15,显然无法通过。实测直接跑高斯消元可以得 75 分(因为矩阵非常稀疏,卡不到上界,所以至多可以过 $n=10$ ,大约跑了不到 600ms)。
Part 3 : 马尔科夫迭代
最后一步算是意料之外情理之中,因为本题输出的是个浮点数,然后还有一句“参考答案和真实答案误差在 $4\times 10^{-6}$” 以内,而我们又只输出 5 位小数,这意味着我们并不需要一个精确解,而只需要一个误差范围内的解即可。与之相对的,如果题目要求输出的是取模意义下的概率,那么这种方法就不适用了,只能从高斯消元入手。
之前只见过用迭代去当暴力碾过的非整解手段(但是这题竟然是直接把迭代作为正解了,这个套路可以记录一下,以后针对浮点数的概率 DP 都可以往这方面考虑),一个经典的例子就是 AcWing290. 坏掉的机器人 ,正解和暴力迭代解可以见这里。我们只要把这个状态转移方程迭代多次即可使其在指定误差范围内。
由于上面这个例题是有后效性的,所以暴力迭代时可以每一层的答案直接迭代多次得到近似解,将需要的迭代次数大大降低。但是本题不行,每次迭代的时候需要 $2^n$ 个点一块迭代求解。所以可以用滚动数组来模拟这个过程。那么最后剩下的就是给定一个迭代次数,使其达到允许的误差范围内。实测 1000 次即可。设迭代次数为 $T$ ,所以这一步的复杂度就是 $O(Tn2^n)$ 。
加上前面的 $O(n^22^n)$ 可以通过本题。官网数据峰值时间是 616ms ,在 AcWing 过需要手动开 O2 优化(因为模拟转移的时候用了 vector
)。
#include <stdio.h>
#include <string.h>
#include <vector>
#include <algorithm>
struct fastIO
{
static const int BUFF_SZ = 1 << 18;
char inbuf[BUFF_SZ], outbuf[BUFF_SZ];
fastIO()
{
setvbuf(stdin, inbuf, _IOFBF, BUFF_SZ);
setvbuf(stdout, outbuf, _IOFBF, BUFF_SZ);
}
} IO;
const int N = 15;
const int ITER = 1000;
int n, q, lim, x, st;
double P[N + 2][N + 2];
double trans[(1 << N) | 5][N + 2]; // prob of state from s to s ^ (1 << i)
int now, lst;
double dp[2][(1 << N) | 5];
double pm[N + 2], pq[N + 2]; // prob of choosing card for M and Q
double sum_pm, sum_pq;
inline std::vector<int> get_cards(int s)
{
std::vector<int> ret;
for (int i = 0; i < n; ++i) if (s & (1 << i)) ret.push_back(i);
return ret;
}
// input : status of M
inline void simulate(int s)
{
std::vector<int> card_M = get_cards(s), card_Q = get_cards(lim ^ s);
memset(pm, 0, sizeof(pm)), memset(pq, 0, sizeof(pq)), sum_pm = sum_pq = 0;
for (auto& card_m : card_M)
{
for (auto& card_q : card_Q) pm[card_m] += P[card_m][card_q];
sum_pm += pm[card_m];
}
for (auto& card_q : card_Q)
{
for (auto& card_m : card_M) pq[card_q] += P[card_q][card_m];
sum_pq += pq[card_q];
}
for (auto& card_m : card_M) pm[card_m] /= sum_pm;
for (auto& card_q : card_Q) pq[card_q] /= sum_pq;
for (auto& card_m : card_M)
for (auto& card_q : card_Q)
{
// M win, get card_q s -> s + (1 << card_q)
trans[s][card_q] += pm[card_m] * pq[card_q] * P[card_m][card_q];
// M lose, lose card_m s -> s - (1 << card_M)
trans[s][card_m] += pm[card_m] * pq[card_q] * P[card_q][card_m];
}
}
inline void calc()
{
memset(dp[now], 0, sizeof(double) * (lim + 1)), dp[now][0] = 0, dp[now][lim] = 1.0;
for (int s = 1; s < lim; ++s) for (int i = 0; i < n; ++i) dp[now][s] += dp[lst][s ^ (1 << i)] * trans[s][i];
now = 1 - now, lst = 1 - lst;
}
int main()
{
scanf("%d%d", &n, &q), lim = (1 << n) - 1;
for (int i = 0; i < (n - 1); ++i)
for (int j = i + 1; j < n; ++j)
scanf("%lf", &P[i][j]), P[j][i] = 1.0 - P[i][j];
for (int s = 1; s < lim; ++s) simulate(s);
// for (int s = 1; s < lim; ++s) for (int i = 0; i < n; ++i) fprintf(stderr, "%d->%d %.5f\n", s, s ^ (1 << i), trans[s][i]);
lst = 0, now = 1, dp[lst][0] = 0, dp[lst][lim] = 1.0;
for (int t = 1; t <= ITER; ++t) calc();
while (q--)
{
st = 0;
for (int i = 0; i < n; ++i) scanf("%d", &x), st |= (x << i);
printf("%.5f\n", dp[lst][st]);
}
}