树上差分(LCA辅助(树上倍增版))
题面回顾
有一棵$n$个节点,边权为$1$的树,有$m$个人在$S[i]$和$T[i]$之间的最短路径上从$0$时刻出发开始跑步,每个节点$i$上都有一个观察员,观察第$w[i]$秒恰好经过该节点的人数。问:每个观察员可以看到多少人?
思路分析
(解释一波lyd大佬的思路)
由于每个人运动的路径是固定的,所以可以把每个人的路径分为两段:起点$s$到$lca$,$lca$到终点$t$
因此,对于每个观测点$x$,有两种情况可以观察到第$i$个人(其中$d$数组表示深度,可以用bfs一遍处理倍增时进行预处理)
-
$x$在$s[i]$到$lca(s[i],t[i])$的路径上,$w[x]=d[s[i]]-d[x]$
-
$x$在$lca(s[i],t[i])$到$t[i]$的路径上,$w[x]=d[s[i]]+d[x]-2* d[lca(s[i],t[i])] $
这两种情况其实是类似的,我们不妨先来考虑第一种
移项得 $w[x]+d[x]=d[s[i]]$
这就相当于在$s[i]$到$lca(s[i],t[i])$的路径上每个节点放一个类型为$d[s[i]]$的物品,求最终每个节点上放有多少个类型为$d[x]+w[x]$的物品
这让我们想到了雨天的尾巴,但是那道线段树合并模板题是因为要求每个节点最多的种类是什么,这道题则是求给定的种类有多少个物品
因此可以不用敲线段树合并。。。
下面是做法:
-
对每个节点建立两个$vector$数组,分别表示在这个节点产生或消失的物品的类型
-
建立全局数组$c$,记录每种类型物品的数量(相当于桶)
-
开始做dfs,当遇到一个点$x$时,先保存当前$c[w[x]+d[x]]$的值,记为$cnt$
-
然后扫描$x$的两个$vector$,产生数组中出现$z$就让$c[z]++$,消失数组中出现$z$就让$c[z]–$。这是因为如果这个物品在$x$的子树中即产生又消失,那么对$x$无影响;如果只产生不消失,说明消失的地方在$x$上面,那么$x$上就会放下这个物品;由于规定下产生,上消失,所以不可能出现消失了但没有产生的情况(
你在无中生有暗度陈仓。。。) -
然后扫完子节点返回$x$,用当前的$c[w[x]+d[x]]$减去扫描前的$c[w[x]+d[x]]$(也就是存好的$cnt$),记为$x$处的$w[x]+d[x]$类物品的数量
下面是几个栗子
然后刚才说过的第二种情况也差不多,t[i]到$lca$路径上增加的物品类型为$d[s[i]]-2* d[lca(s[i],t[i])]$, 最后统计每个结点$x$处物品$w[x]-d[x]$的数量
最后就可以上代码了(一些细节会在代码里讲)
代码来了
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define fr(i,a,b) for(int i=a;i<=b;i++)
#define fo(i,a,b) for(int i=a;i>=b;i--)//个人习惯而已,别太在意
#pragma GCC optimize(2)
const int M=300005;
int n,m,v[M],p,f[M][20],s[M],t[M],lca[M],w[M],c1[2*M],c2[2*M],ans[M],d[M];
int head[M],Next[2*M],ver[2*M],tot;
vector<int> ap1[M],dap1[M],ap2[M],dap2[M];
//ap即appear出现,dap即disappear消失,1和2分别是从s到lca和t到lca
queue<int> q;
void add(int x,int y) {
Next[++tot]=head[x],head[x]=tot,ver[tot]=y;
}
void bfs() {
q.push(1); d[1]=1;
while(q.size()) {
int x=q.front(); q.pop();
for(int i=head[x];i;i=Next[i]) {
int y=ver[i];
if(d[y]) continue;
d[y]=d[x]+1;
f[y][0]=x;
fr(j,1,p) f[y][j]=f[f[y][j-1]][j-1];
//LCA的倍增预处理(这个不用我多说了吧)
q.push(y);
}
}
}
int LCA(int x,int y) {
if(d[x]>d[y]) swap(x,y);
fo(i,p,0) if(d[f[y][i]]>=d[x]) y=f[y][i];
if(x==y) return x;
fo(i,p,0) if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
void dfs(int x) {
v[x]=1;
int cnt1=c1[w[x]+d[x]],cnt2=c2[w[x]-d[x]+M];
//cnt2加这个M是因为w[x]-d[x]有可能是负的,要加偏移值或离散化
//先保存好扫描前的数量
for(int i=0;i<ap1[x].size();i++) c1[ap1[x][i]]++;
for(int i=0;i<dap1[x].size();i++) c1[dap1[x][i]]--;
for(int i=0;i<ap2[x].size();i++) c2[ap2[x][i]+M]++;
//当然后面的也要加偏移值
for(int i=0;i<dap2[x].size();i++) c2[dap2[x][i]+M]--;
//其实就是个差分吧
for(int i=head[x];i;i=Next[i]) {
int y=ver[i];
if(v[y]) continue;
dfs(y);
}
ans[x]=c1[w[x]+d[x]]-cnt1+c2[w[x]-d[x]+M]-cnt2;
//相减即为答案
//容易发现从两个端点到lca是互不干扰的
}
int main(){
std::ios::sync_with_stdio(false);
scanf("%d%d",&n,&m);
p=(int)log(n)/log(2)+1;
fr(i,1,n-1) {
int x,y;
scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
bfs();
fr(i,1,n) scanf("%d",&w[i]);
fr(i,1,m) {
int s,t;
scanf("%d%d",&s,&t);
int lca=LCA(s,t);
ap1[s].push_back(d[s]);
dap1[f[lca][0]].push_back(d[s]);
ap2[t].push_back(d[s]-2*d[lca]);
dap2[lca].push_back(d[s]-2*d[lca]);
//记录物品出现和消失的类型
}
dfs(1);
fr(i,1,n) printf("%d ",ans[i]);
return 0;
}