一般情况下可以使用长链剖分来优化的 DP 会有一维状态为深度维。
我们可以考虑使用长链剖分优化树上 DP。
具体的,我们每个节点的状态直接继承其重儿子的节点状态,同时将轻儿子的 DP 状态暴力合并。
CF1009F
https://codeforces.com/contest/1009/problem/F
题意:设$d(u,k)$是u节点的k-son的数量,对每个节点,找到k使得$d(u,k)$最大,如果有多个最大的,取最小的那个k。
很自然的想到dp状态:$dp[i][j]=i节点子树内,深度往下j的节点的个数$。这样状态转移的话:$$对每个j,dp[i][j] = \sum dp[son][j-1]$$那么时间复杂度就是$O(n^2)$的
我们考虑每次转移我们直接继承重儿子的 DP数组和答案,并且考虑在此基础上进行更新。
首先我们需要将重儿子的 DP 数组前面插入一个元素 1, 这代表着当前节点。
然后我们将所有轻儿子的 DP 数组暴力和当前节点的 DP 数组合并。注意到因为轻儿子的 DP 数组长度为轻儿子所在重链长度,而所有重链长度和为 n。也就是说,我们直接暴力合并轻儿子的总时间复杂度为 O(n)。
总时间复杂度$O(n(遍历所有节点) + n(暴力合并所有轻儿子))$
btw,这里用指针的实现方式有点妙。每个节点的子树在自己的dfn序范围内使用同一个数组val,兄弟节点之间互不影响,往上传之后整个dfn序范围的空间就会保存父节点信息。
ps.关于直接继承重子节点状态的问题,可以看到代码中并没有合并重子节点的过程,而此时直接访问$dp[u][mxp[u]]$能得到正确答案,因为我们指针实现实际上已经帮我们完成了继承,这意味着此时直接访问dp[u]就能得到重子节点的信息(访问dp[u][0]则需要手动插入)。
原因如下:
可以看到求dfn序的时候,先遍历重子节点。因此节点u的dfn序为x,则节点u的重子节点dfn序为x+1,指针的实现方式直接帮我们完成了继承这个事情,越看越牛(
#include<bits/stdc++.h>
#define LOCAL//delete when submit!!!!!!
using namespace std;
using ll = long long;
using pii = pair<int,int>;
using pll = pair<ll,ll>;
const int N = 1e6+10;
int dep[N],fa[N],mx[N],son[N];
int top[N],len[N];
int dfn[N];
int d[N];//d是i节点能往下的最大次数
int *dp[N],val[N],mxp[N];//mxp是最大d(u,k)的k val是dp的实际内存空间
vector<vector<int>> g;
int n;
void dfs1(int u){// 第一次插入一个1
mx[u] = dep[u];son[u] = -1;
for(auto v:g[u]){
if(!dep[v]){
dep[v] = dep[u]+1;
fa[v] = u;
dfs1(v);
if(son[u]==-1||mx[v]>mx[son[u]]) mx[u] = mx[v],son[u] = v;
}
}
}
void dfs2(int u,int t){
top[u] = t;
dfn[u] = ++*dfn;//++*dfn:将 dfn[0] 自增 1,并将其新值作为当前 DFS 访问顺序的标记。
len[u] = mx[u] - dep[t] + 1;
d[u] = mx[u] - dep[u];
dp[u] = val + dfn[u];//申请val上从下标dfn[u]开始的内存
if(son[u]!=-1) dfs2(son[u],t);
for(auto v:g[u]){
if(v==son[u]||v==fa[u]) continue;
dfs2(v,v);
}
}
void getans(int u){// 暴力合并算答案
if(son[u]!=-1){
getans(son[u]);
mxp[u] = mxp[son[u]] + 1;
}
dp[u][0] = 1;
//如果从叶节点继承过来,那么此时mxp[1/0]都是1,要特判为0
if(dp[u][mxp[u]]<=1) mxp[u] = 0;
for(auto v:g[u]){
if(v==fa[u]||v==son[u]) continue;
getans(v);
for(int i=0;i<=d[v];i++){//注意这里是能往下的深度d[v]
dp[u][i+1] += dp[v][i];
if(dp[u][i+1]>dp[u][mxp[u]]) mxp[u] = i+1;
if(dp[u][i+1]==dp[u][mxp[u]]&& i+1<mxp[u]) mxp[u] = i+1;
}
}
}
void solve(){
cin>>n;
g = vector<vector<int>>(n+1);
for(int i=1;i<n;i++){
int x,y;
cin>>x>>y;
g[x].push_back(y);
g[y].push_back(x);
}
dep[1] = 1;
dfs1(1);
dfs2(1,1);
getans(1);
for(int i=1;i<=n;i++) cout<<mxp[i]<<endl;
}
int main(){
#ifdef LOCAL
freopen("in.txt", "r", stdin);
freopen("output.txt", "w", stdout);
#endif
std::ios::sync_with_stdio(0);std::cout.tie(0);std::cin.tie(0);
int T = 1;
#ifdef MULTI_TEST
cin>>T;
#endif
while(T--){
solve();
}
return 0;
}