前言:今天刷题时遇到了树链剖分的题目所以…就学了树链剖分(点个赞吧❤️)
树链剖分
定义:
一种思想
1. 将一颗树转换成一个序列
2. 把一棵树中任意一条路径转换成不超过logn 段连续区间
概念:
1. 重儿子和轻儿子
先计算出每个节点的子树的节点总数
void dfs(int u, int p)
{
s[u] = 1;
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j != p)
{
dfs(j, u);
s[u] += s[j];
}
}
}
若当前节点为u, 左儿子为l, 右儿子为r, 若s[l] > s[r]则l为重儿子, r为轻儿子(特别的根节点是轻儿子)
2. 重边, 轻边: 重儿子和它的父亲节点之间的边就是重边, 同理轻儿子和它的父亲节点之间的边就是轻边.
3. 重链:
由重边构成的路径被称为重链
每个点都要在一个重链里, 若当前这个点不直接在一个重链里, 找以这个点开头往下走的第一个重链里, 特别的下面没有重链的节点单独为一个重链
性质:
树中任意一条路径均可拆分成logn条重链, 即可拆分成logn个连续区间
重链的开头一定是一个轻儿子
操作:
主要用dfs序把树(优先遍历重儿子)转换成一个序列
为什么要优先遍历重儿子呢? 这样可以保证重链上所有点的编号都是连续的
// Problem: 树链剖分
// Contest: AcWing
// URL: https://www.acwing.com/problem/content/description/2570/
// Memory Limit: 64 MB
// Time Limit: 1000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e5 + 10, M = 2e5 + 10;
ll w[N], nw[N];
int h[N], e[M], ne[M], idx;
int id[N], timestamp;
int dep[N], s[N];
int top[N];
int fa[N], son[N];
int n, m;
struct Segment
{
int l, r;
ll add, sum;
} tr[N << 2];
void add(int a, int b)
{
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
void dfs1(int u, int p, int depth)
{
s[u] = 1;
dep[u] = depth;
fa[u] = p;
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j != p)
{
dfs1(j, u, depth + 1);
s[u] += s[j];
if (s[son[u]] < s[j]) son[u] = j;
}
}
}
void dfs2(int u, int t)
{
id[u] = ++timestamp;
nw[timestamp] = w[u];
top[u] = t;
if (!son[u]) return;
dfs2(son[u], t);
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j != fa[u] && j != son[u]) dfs2(j, j);
}
}
void pushup(int u)
{
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void pushdown(int u)
{
auto &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
if (root.add)
{
left.add += root.add, left.sum += root.add * (left.r - left.l + 1);
right.add += root.add, right.sum += root.add * (right.r - right.l + 1);
root.add = 0;
}
}
void build(int u, int l, int r)
{
tr[u] = {l, r, 0, nw[r]};
if (l == r) return;
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void update(int u, int l, int r, int k)
{
if (l <= tr[u].l && r >= tr[u].r)
{
tr[u].add += k;
tr[u].sum += k * (tr[u].r - tr[u].l + 1);
return;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) update(u << 1, l, r, k);
if (r > mid) update(u << 1 | 1, l, r, k);
pushup(u);
}
ll query(int u, int l, int r)
{
if (l <= tr[u].l && r >= tr[u].r) return tr[u].sum;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
ll res = 0;
if (l <= mid) res += query(u << 1, l, r);
if (r > mid) res += query(u << 1 | 1, l, r);
return res;
}
void update_path(int u, int v, int k)
{
while (top[u] != top[v])
{
if (dep[top[u]] < dep[top[v]]) swap(u, v);
update(1, id[top[u]], id[u], k);
u = fa[top[u]];
}
if (dep[u] < dep[v]) swap(u, v);
update(1, id[v], id[u], k);
}
ll query_path(int u, int v)
{
ll res = 0;
while (top[u] != top[v])
{
if (dep[top[u]] < dep[top[v]]) swap(u, v);
res += query(1, id[top[u]], id[u]);
u = fa[top[u]];
}
if (dep[u] < dep[v]) swap(u, v);
res += query(1, id[v], id[u]);
return res;
}
void update_tree(int u, int k)
{
update(1, id[u], id[u] + s[u] - 1, k);
}
ll query_tree(int u)
{
return query(1, id[u], id[u] + s[u] - 1);
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
cin >> n;
for (int i = 1; i <= n; i++) cin >> w[i];
memset(h, -1, sizeof h);
int a, b;
for (int i = 1; i < n; i++)
{
cin >> a >> b;
add(a, b);
add(b, a);
}
dfs1(1, 0, 1);
dfs2(1, 1);
build(1, 1, n);
cin >> m;
int t, u, v, k;
while (m--)
{
cin >> t >> u;
if (t == 1)
{
cin >> v >> k;
update_path(u, v, k);
}
else if (t == 2)
{
cin >> k;
update_tree(u, k);
}
else if (t == 3)
{
cin >> v;
cout << query_path(u, v) << '\n';
}
else cout << query_tree(u) << '\n';
}
return 0;
}
巨佬可以教教我么QwQ
我很菜的(你直接去报y总的算法基础和提高课啊)哦哦哦