广义矩阵乘法
我们用 $(\oplus, \otimes)$ 定义一个矩阵乘法,$C_{i,j} = \oplus_{k=1}^n A_{i,k} \otimes B_{k,j}$。
那么什么样的矩阵满足结合律,假设现在有三个矩阵,$A,B,C$,我们分别列出 $((A\times B)\times C)_{i,j}$ 和 $(A\times (B\times C))_{i,j}:$
$$\oplus_{x=1}^n (\oplus_{y=1}^n A_{i,y} \otimes B_{y,x}) \otimes C_{x,j} = \oplus_{x=1}^n A_{i,x} \otimes (\oplus_{y=1}^n B_{x,y} \otimes C_{y,j})$$
若 $\otimes$ 对 $\oplus$ 有左右分配律,则
$$\oplus_{x=1}^n \oplus_{y=1}^n A_{i,y} \otimes B_{y,x}\otimes C_{x,j} = \oplus_{x=1}^n \oplus_{y=1}^n A_{i,x} \otimes B_{x,y} \otimes C_{y,j}$$
若 $\oplus$ 有交换律则两者相等。
例如 $(\min, +)$ 就满足结合律,平常用的矩阵乘法 $(+,\times)$ 也满足。
动态dp
Luogu P4718【模板】”动态 DP”&动态树分治
给定一棵 $n$ 个点的树,点带点权。
有 $m$ 次操作,每次操作给定 $x,y$,表示修改点 $x$ 的权值为 $y$。
你需要在每次操作之后求出这棵树的最大权独立集的权值大小。
对于 $100\%$,保证 $1\leq n, m\leq 10^5$。
动态dp问题指的是一类在简单dp题上支持动态修改操作的题目。
我们先对不支持修改的dp列出转移式,$f_{i,0}$ 表示不选 $i$ 的最大值,$f_{i,1}$ 表示选 $i$ 的最大值。
$$\left\{\begin{array}{c} f_{i,0} = & \sum_{j} \max\{f_{j,0}, f_{j,1}\} \\\ f_{i,1} = & a_i + \sum_{j} f_{j,0} \end{array}\right.$$
然后对树进行树链剖分,设 $g_{i,0}$ 表示不选 $i$ 以及 $son_i$ 子树内的节点的最大权值,$g_{i,1}$ 表示选 $i$,但不选 $son_i$ 子树内的节点的最大权值。
若我们知道了 $g$ 的值,那么很容易推出 $f$
$$\left\{\begin{array}{c} f_{i,0} = & g_{i,0} + \max\{f_{son_i,0}, f_{son_i, 1}\} \\\ f_{i,1} = & g_{i,1} + f_{son_i,0} \end{array}\right.$$
转成矩阵形式
$$\begin{bmatrix} g_{i,0} & g_{i,0}\\ g_{i,1} & -\infty \end{bmatrix} \begin{bmatrix} f_{son_i,0}\\ f_{son_i,1} \end{bmatrix} = \begin{bmatrix} f_{i,0}\\ f_{i,1} \end{bmatrix}$$
于是我们发现 $f$ 的值为当前节点到当前链底端的矩阵乘积,所以我们要维护 $g$ 的值。
如果我们就改了 $a_x$ 的值,则只有 $x$ 祖先上向 $x$ 连轻边的点的 $g$ 会更新,而根据树剖性质,轻边数量不会超过 $O(\log n)$,所以暴力条链维护即可,然后查询 $f$ 时用线段树维护区间矩阵乘积。
#include <bits/stdc++.h>
using namespace std;
const int N = 100010, inf = 0x3f3f3f3f;
int n, m, a[N], p[N], siz[N], f[N][2], g[N][2];
int pos[N], id[N], son[N], top[N], Tail[N], tot;
int h[N], e[N * 2], ne[N * 2], idx;
void add(int u, int v) {
e[idx] = v, ne[idx] = h[u], h[u] = idx ++ ;
}
struct matrix {
int a[2][2];
void def(int b, int c, int d, int e) {
a[0][0] = b, a[0][1] = c;
a[1][0] = d, a[1][1] = e;
}
void init() {
for (int i = 0; i < 2; i ++ )
for (int j = 0; j < 2; j ++ )
a[i][j] = -inf;
}
const matrix operator* (const matrix &B) const {
matrix C;
C.init();
for (int i = 0; i < 2; i ++ )
for (int j = 0; j < 2; j ++ )
for (int k = 0; k < 2; k ++ )
C.a[i][j] = max(C.a[i][j], a[i][k] + B.a[k][j]);
return C;
}
} tr[N << 2], M0;
void dfs(int u, int fa) {
siz[u] = 1, f[u][1] = a[u], p[u] = fa;
for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if (j == fa) continue;
dfs(j, u);
f[u][0] += max(f[j][0], f[j][1]);
f[u][1] += f[j][0];
siz[u] += siz[j];
if (siz[j] > siz[son[u]]) son[u] = j;
}
g[u][0] = f[u][0], g[u][1] = f[u][1];
g[u][0] -= max(f[son[u]][0], f[son[u]][1]);
g[u][1] -= f[son[u]][0];
}
void dfs2(int u, int topf) {
top[u] = topf, id[u] = ++ tot, pos[tot] = u;
if (son[u]) dfs2(son[u], topf);
for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if (j == p[u] || j == son[u]) continue;
dfs2(j, j);
}
if (son[u]) Tail[u] = Tail[son[u]];
else Tail[u] = u;
}
void pushup(int u) {
tr[u] = tr[u << 1] * tr[u << 1 | 1];
}
void build(int u, int l, int r) {
if (l == r) return tr[u].def(g[pos[l]][0], g[pos[l]][0], g[pos[l]][1], -inf), void();
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void modify(int u, int l, int r, int x) {
if (l == r) return tr[u].def(g[pos[l]][0], g[pos[l]][0], g[pos[l]][1], -inf), void();
int mid = l + r >> 1;
if (x <= mid) modify(u << 1, l, mid, x);
else modify(u << 1 | 1, mid + 1, r, x);
pushup(u);
}
matrix query(int u, int l, int r, int s, int t) {
if (s <= l && r <= t) return tr[u];
int mid = l + r >> 1;
if (t <= mid) return query(u << 1, l, mid, s, t);
if (s > mid) return query(u << 1 | 1, mid + 1, r, s, t);
return query(u << 1, l, mid, s, t) * query(u << 1 | 1, mid + 1, r, s, t);
}
void update(int t) {
matrix c = query(1, 1, n, id[t], id[Tail[t]]) * M0;
f[t][0] = c.a[0][0], f[t][1] = c.a[1][0];
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i ++ ) scanf("%d", &a[i]);
memset(h, -1, sizeof h);
for (int i = 1; i < n; i ++ ) {
int a, b;
scanf("%d%d", &a, &b);
add(a, b), add(b, a);
}
dfs(1, 0);
dfs2(1, 1);
build(1, 1, n);
while (m -- ) {
int x, y;
scanf("%d%d", &x, &y);
g[x][1] += y - a[x];
modify(1, 1, n, id[x]);
a[x] = y;
int t = top[x];
while (p[t]) {
g[p[t]][0] -= max(f[t][1], f[t][0]);
g[p[t]][1] -= f[t][0];
update(t);
g[p[t]][0] += max(f[t][1], f[t][0]);
g[p[t]][1] += f[t][0];
modify(1, 1, n, id[p[t]]);
t = top[p[t]];
}
update(t);
printf("%d\n", max(f[1][0], f[1][1]));
}
return 0;
}