每个顶点可以分为m块,向左边取k条路径,向右边取m-k条路径,循环m次,求出顶点u在向下取m条连通的边的最大值,并记录于
f[u][m]中,最后返回f[u][m],dfs采用的是记忆化搜索,不然会重复计算并且超时。
#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=300;
int head[N],ne[N],w[N],e[N],id=1;
int n;
int res=0;
int f[N][N];
void add(int a,int b,int c)
{
e[id]=b;
w[id]=c;
ne[id]=head[a];
head[a]=id++;
}
int dfs(int u,int m,int father)
{
if(m==0)
return 0;
if(m<0)
return -0x3f3f3f3f;
if(f[u][m])
return f[u][m];
int dist=0,j1=0,j2=0;
for(int i=head[u];i!=-1;i=ne[i])
{
if(e[i]==father)
continue;
if(!j1)
j1=i;
else if(!j2)
j2=i;
}
if(!j1 && !j2)
{
return -0x3f3f3f3f;
}
for(int k=0;k<=m;k++)//k=0,左边一条边都不选,k=1,左边选一条边
{
int i1=e[j1],i2=e[j2];
int d1=dfs(i1,k-1,u)+w[j1];
int d2=dfs(i2,m-k-1,u)+w[j2];
if(d1<0)
d1=0;
if(d2<0)
d2=0;
dist=max(dist,d1+d2);
f[u][m]=max(f[u][m],dist);
}
res=max(res,dist);
return dist;
}
int main()
{
int m;
cin>>n>>m;
memset(head,-1,sizeof head);
memset(ne,-1,sizeof ne);
memset(f,0,sizeof f);
for(int i=1;i<n;i++)
{
int a,b,c;
cin>>a>>b>>c;
add(a,b,c);
add(b,a,c);
}
dfs(1,m,-1);
cout<<res<<endl;
}