题目要求从上往下删除节点,但由于我们并不知道各个子树的信息,所以考虑从下往上合并节点。
假设已经知道各个子树 $v_1, v_2, \dots, v_k$ 的信息,要合并到 $u$ 节点。
由于点权只有 $0,1$,又要求逆序对数,显然要贪心把 $0$ 放前面,$1$ 放后面。
要合并子树的信息,相当于把子树进行排列,所以需要知道子树内 $0,1$ 的个数,记为 $v_{i,0}, v_{i,1}$。
考虑两棵子树 $v_i, v_j$,如果 $i$ 在 $j$ 之前,产生的逆序对数就是 $v_{i,1} \times v_{j,0}$,反之同理。
我们发现有序数对 $(i,j)$ 比 $(j,i)$ 更优,条件是 $v_{i,1} \times v_{j,0} \lt v_{i,0} \times v_{j,1}$,移项后可得 $\frac{v_{i,1}}{v_{i,0}} \lt \frac{v_{j,1}}{v_{j,0}}$。
用小根堆维护这个比值,每次将最小的与它父亲合并。
注意分母不能为零,此时设为正无穷大。
#include <bits/stdc++.h>
using namespace std;
const int N = 2e5 + 15;
int n, fa[N], a[N];
int c0[N], c1[N], p[N];
bool st[N];
int find(int x) {
if (x == p[x]) return x;
return p[x] = find(p[x]);
}
struct node {
int u;
double v;
bool operator < (const node &a) const {
return v > a.v;
}
} ;
priority_queue<node> q;
long long ans = 0;
void merge(int u, int v) {
u = find(u), v = find(v);
if (u == v) return;
ans += c1[v] * 1ll * c0[u];
c0[v] += c0[u], c1[v] += c1[u];
p[u] = v;
}
int main() {
scanf("%d", &n);
for (int i = 2; i <= n; i++) scanf("%d", &fa[i]);
for (int i = 1; i <= n; i++) p[i] = i;
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
if (a[i]) c1[i] = 1;
else c0[i] = 1;
}
for (int i = 1; i <= n; i++) {
double V = (c0[i] == 0) ? 1e9 : c1[i] * 1.0 / c0[i];
q.push((node){i, V});
}
while (q.size()) {
node f = q.top(); q.pop();
int u = f.u;
if (st[u]) continue;
st[u] = 1;
if (u == 1) continue;
int v = find(fa[u]);
merge(u, v);
double V = (c0[v] == 0) ? 1e9 : c1[v] * 1.0 / c0[v];
// cout << u << ' ' << v << ' ' << V << ' ' << c0[u] << ' ' << c1[u] << ' ' << c0[v] << ' ' << c1[v] << endl;
q.push((node){v, V});
}
printf("%lld\n", ans);
return 0;
}