题面
题目描述
给定一棵树,树中包含 $n$ 个节点(编号 $1~n$ ),其中第 i 个节点的权值为 $a_i$ 。
初始时,$1$ 号节点为树的根节点。
现在要对该树进行 $m$ 次操作,操作分为以下 $4$ 种类型:
-
1 u v k
,修改路径上节点权值,将节点 $u$ 和节点 $v$ 之间路径上的所有节点(包括这两个节点)的权值增加 $k$ 。 -
2 u k
,修改子树上节点权值,将以节点 $u$ 为根的子树上的所有节点的权值增加 $k$ 。 -
3 u v
,询问路径,询问节点 $u$ 和节点 $v$ 之间路径上的所有节点(包括这两个节点)的权值和。 -
4 u
,询问子树,询问以节点 $u$ 为根的子树上的所有节点的权值和。
样例
$input$
20
42 7 39 12 41 11 5 37 35 3 10 32 26 5 49 21 46 44 28 11
5 11
19 20
14 7
6 5
19 2
12 4
2 3
14 10
8 20
16 2
9 1
1 6
15 13
15 6
4 16
5 14
1 19
17 14
18 1
30
4 13
1 1 4 41
2 14 43
2 14 38
1 7 10 3
4 4
1 5 3 27
1 15 13 24
1 9 14 35
3 15 18
3 2 20
1 5 7 13
3 6 13
4 3
1 7 8 27
2 10 5
3 18 5
4 10
4 17
3 13 7
3 7 13
2 8 17
4 14
3 7 18
3 14 5
2 12 25
1 19 18 9
4 10
3 6 9
1 8 15 34
$output$
26
85
335
182
196
66
459
92
127
659
659
512
752
307
92
351
样例图
算法
首先题目在一棵树上操作,应考虑怎样存一棵树,可以使用邻接表存图。
其次题目中每种操作都属于对多个值进行修改与查询,尝试转换成区间操作进行处理。
区间修改与区间查询,容易想到线段树维护。
考虑将题目进行分块处理,发现可以将操作分为两种:
- 链操作
- 子树操作
子树操作
比起链操作,子树操作相对简单,所以我们先处理对子树的操作。
我们已经准备将所有操作转换成对区间的操作,那么尝试将一个子树转换成一个连续的区间。
思考发现 $\text{dfs}$ 序中一个子树的所在位置是连续的,所以决定使用 $\text{dfs}$ 序将整颗树转换成一段段区间。
这样就将子树操作转换成了区间操作。
链操作
接下来处理链。
我们发现如果我们将树直接通过 $\text{dfs}$ 序转换成区间,我们所要的操作的链会散开,逐一操作时间复杂度爆炸。
所以我们需要引出本题解的重头戏——树链剖分:
树链剖分
我们已经将树上操作转化成区间操作了,所以思考如何将链也转化成一个个区间,这就需要用到树链剖分了。
我们可以发现,在 $\text{dfs}$ 序中,也存在一些链是连续的,所以我们可以思考怎样利用这些链。
可以将每一条链存起来,如果直接进行操作,会快一点,但是依然不够,所以要继续优化。
因为我们直接对整条连续链操作,所以尝试让操作次数更多的链长度更长,发现链的尾部子树越大,操作次数可能越多。
所以让连续的链尽量向子树大的儿子延伸,是最优的。
树链剖分的思想就是将儿子分成两类:重儿子和轻儿子。
重儿子就是子树最大的儿子,其余的则是轻儿子。
令每个点向重儿子连边,形成的一条条链就是重链,使重链在 $\text{dfs}$ 序里是连续的区间,这样时间复杂度就会大大减小。
树链剖分前后的树分别是这样的(样例图,加粗的是重链):
实现方法就是在 $\text{dfs}$ 过程中优先遍历重儿子,使 $\text{dfs}$ 序中重儿子总是与其父亲节点相邻的。
剖分完后,除每棵子树外,每条重链上的点也是连续的,如图(每个蓝色框内的点在 $\text{dfs}$ 序中连续):
在修改过程中,沿着重链向上爬,边爬边对连续的区间操作,直到两个链的端点向上爬到它们的最近公共祖先。
注:记得开long long
。
$\text{AC}$ 完整代码
光说不练不是真把式,上代码:
大量使用 $\text{STL}$ 版(需吸氧)
#pragma GCC optimize(2)
#include<bits/stdc++.h>
using namespace std;
inline int read() {
int x = 0, f = 0; char c = getchar();
for (; !isdigit(c); c = getchar()) f |= c == '-';
for (; isdigit(c); c = getchar()) x = (x << 1) + (x << 3) + (c & 15);
return f ? -x : x;
}
#define fi first
#define se second
typedef long long LL;
typedef pair<LL, LL> PLL;
const int N = 100010;
int n, dfnum;
array<int, N> a, fa, sn, sz, tp, dep, bg, ed;
array<vector<int>, N> g;
array<PLL, N << 2> tr;
inline void add(int p, int L, int R, int d) {
tr[p].fi += 1LL * d * (R - L + 1);
tr[p].se += d;
}
inline void down(int p, int L, int R) {
if (L == R) return;
int mid = L + R >> 1, ls = p << 1, rs = ls | 1;
add(ls, L, mid, tr[p].se);
add(rs, mid + 1, R, tr[p].se);
tr[p].se = 0;
}
inline void modify(int p, int L, int R, int l, int r, int d) {
if (L >= l && R <= r) {
add(p, L, R, d);
return;
}
down(p, L, R); int mid = L + R >> 1, ls = p << 1, rs = ls | 1;
if (l <= mid) modify(ls, L, mid, l, r, d);
if (r > mid) modify(rs, mid + 1, R, l, r, d);
tr[p].fi = tr[ls].fi + tr[rs].fi;
}
inline LL query(int p, int L, int R, int l, int r) {
if (L >= l && R <= r) return tr[p].fi;
down(p, L, R); int mid = L + R >> 1, ls = p << 1, rs = ls | 1;
return (l <= mid ? query(ls, L, mid, l, r) : 0) + (r > mid ? query(rs, mid + 1, R, l, r) : 0);
}
inline void modify(int l, int r, int d) {
modify(1, 1, n, l, r, d);
}
inline LL query(int l, int r) {
return query(1, 1, n, l, r);
}
inline void dfs(int x) {
sz[x] = 1;
for (auto y : g[x]) {
if (y == fa[x]) continue;
fa[y] = x; dep[y] = dep[x] + 1;
dfs(y);
sz[x] += sz[y];
if (sz[y] > sz[sn[x]]) sn[x] = y;
}
}
inline void dfs(int x, int top) {
bg[x] = ++ dfnum;
tp[x] = top;
if (sn[x]) dfs(sn[x], top);
for (auto y : g[x]) {
if (y == fa[x] || y == sn[x]) continue;
dfs(y, y);
}
ed[x] = dfnum;
}
inline void modifyLink(int l, int r, int d) {
while (tp[l] != tp[r]) {
if (dep[tp[l]] < dep[tp[r]]) swap(l, r);
modify(bg[tp[l]], bg[l], d);
l = fa[tp[l]];
}
if (bg[l] > bg[r]) swap(l, r);
modify(bg[l], bg[r], d);
}
inline LL queryLink(int l, int r) {
LL res = 0;
while (tp[l] != tp[r]) {
if (dep[tp[l]] < dep[tp[r]]) swap(l, r);
res += query(bg[tp[l]], bg[l]);
l = fa[tp[l]];
}
if (bg[l] > bg[r]) swap(l, r);
return res + query(bg[l], bg[r]);
}
int main() {
n = read();
for (int i = 1; i <= n; ++ i) a[i] = read();
for (int i = 1; i < n; ++ i) {
int x = read(), y = read();
g[x].emplace_back(y);
g[y].emplace_back(x);
}
dfs(1); dfs(1, 1);
for (int i = 1; i <= n; ++ i) modify(bg[i], bg[i], a[i]);
for (int m = read(); m --; ) {
int l, r, d;
switch(read()) {
case 1:
l = read(); r = read(); d = read();
modifyLink(l, r, d);
break;
case 2:
l = read(); d = read();
modify(bg[l], ed[l], d);
break;
case 3:
l = read(); r = read();
printf("%lld\n", queryLink(l, r));
break;
case 4:
l = read();
printf("%lld\n", query(bg[l], ed[l]));
break;
}
}
return 0;
}
正常版
#include<bits/stdc++.h>
using namespace std;
inline int read() {
int x = 0, f = 0; char c = getchar();
for (; !isdigit(c); c = getchar()) f |= c == '-';
for (; isdigit(c); c = getchar()) x = (x << 1) + (x << 3) + (c & 15);
return f ? -x : x;
}
typedef long long LL;
const int N = 100010;
struct Edge {
int y, nxt;
Edge(int y = 0, int nxt = -1)
: y(y), nxt(nxt) {}
};
struct Seg {
LL v, f;
Seg(LL v = 0, LL f = 0)
: v(v), f(f) {}
};
int n, edges, dfnum;
int a[N], lnk[N], fa[N], sn[N], sz[N], tp[N], dep[N], bg[N], ed[N];
Edge edge[N << 1];
Seg tr[N << 2];
inline void addEdge(int x, int y) {
edge[edges] = Edge(y, lnk[x]);
lnk[x] = edges ++;
}
inline void add(int p, int L, int R, int d) {
tr[p].v += 1LL * d * (R - L + 1);
tr[p].f += d;
}
inline void down(int p, int L, int R) {
if (L == R) return;
int mid = L + R >> 1, ls = p << 1, rs = ls | 1;
add(ls, L, mid, tr[p].f);
add(rs, mid + 1, R, tr[p].f);
tr[p].f = 0;
}
inline void modify(int p, int L, int R, int l, int r, int d) {
if (L >= l && R <= r) {
add(p, L, R, d);
return;
}
down(p, L, R); int mid = L + R >> 1, ls = p << 1, rs = ls | 1;
if (l <= mid) modify(ls, L, mid, l, r, d);
if (r > mid) modify(rs, mid + 1, R, l, r, d);
tr[p].v = tr[ls].v + tr[rs].v;
}
inline LL query(int p, int L, int R, int l, int r) {
if (L >= l && R <= r) return tr[p].v;
down(p, L, R); int mid = L + R >> 1, ls = p << 1, rs = ls | 1;
return (l <= mid ? query(ls, L, mid, l, r) : 0) + (r > mid ? query(rs, mid + 1, R, l, r) : 0);
}
inline void modify(int l, int r, int d) {
modify(1, 1, n, l, r, d);
}
inline LL query(int l, int r) {
return query(1, 1, n, l, r);
}
inline void dfs(int x) {
sz[x] = 1;
for (int e = lnk[x]; ~e; e = edge[e].nxt) {
int y = edge[e].y;
if (y == fa[x]) continue;
fa[y] = x; dep[y] = dep[x] + 1;
dfs(y);
sz[x] += sz[y];
if (sz[y] > sz[sn[x]]) sn[x] = y;
}
}
inline void dfs(int x, int top) {
bg[x] = ++ dfnum;
tp[x] = top;
if (sn[x]) dfs(sn[x], top);
for (int e = lnk[x]; ~e; e = edge[e].nxt) {
int y = edge[e].y;
if (y == fa[x] || y == sn[x]) continue;
dfs(y, y);
}
ed[x] = dfnum;
}
inline void modifyLink(int l, int r, int d) {
while (tp[l] != tp[r]) {
if (dep[tp[l]] < dep[tp[r]]) swap(l, r);
modify(bg[tp[l]], bg[l], d);
l = fa[tp[l]];
}
if (bg[l] > bg[r]) swap(l, r);
modify(bg[l], bg[r], d);
}
inline LL queryLink(int l, int r) {
LL res = 0;
while (tp[l] != tp[r]) {
if (dep[tp[l]] < dep[tp[r]]) swap(l, r);
res += query(bg[tp[l]], bg[l]);
l = fa[tp[l]];
}
if (bg[l] > bg[r]) swap(l, r);
return res + query(bg[l], bg[r]);
}
int main() {
n = read();
memset(lnk + 1, -1, sizeof(int) * n);
for (int i = 1; i <= n; ++ i) a[i] = read();
for (int i = 1; i < n; ++ i) {
int x = read(), y = read();
addEdge(x, y);
addEdge(y, x);
}
dfs(1); dfs(1, 1);
for (int i = 1; i <= n; ++ i) modify(bg[i], bg[i], a[i]);
for (int m = read(); m --; ) {
int l, r, d;
switch(read()) {
case 1:
l = read(); r = read(); d = read();
modifyLink(l, r, d);
break;
case 2:
l = read(); d = read();
modify(bg[l], ed[l], d);
break;
case 3:
l = read(); r = read();
printf("%lld\n", queryLink(l, r));
break;
case 4:
l = read();
printf("%lld\n", query(bg[l], ed[l]));
break;
}
}
return 0;
}
%%%
### big old
# big old ! %%%
# orz OTL OTZ
%%%太强了