原题链接: 小猪佩奇爬树
题意:
给出一个 $n$ 个点的无根树,每个点有一个颜色 $w_i (w_i \leq n)$
对于$1-n$每种颜色,求有多少种不同的长度$>1$的路径能覆盖所有该颜色的点。
这破题写死我了
分类讨论:
1.该颜色的点数为 $0$,方案数为树上所有的路径总数 $n(n - 1)/2$
2.该颜色的点数为 $1$,方案数为所有经过该点的路径总数
记该点分割的某个连通块的点的个数为 $f_i$
答案为 $ n - 1 + \frac{1}{2}\sum f_i(n - f_i - 1)$
3.该颜色的点数$\geq2$,且所有点都是深度最深的点的祖先结点
方案数分该链的两个端点分割的两个连通块的点的个数相乘
4.该颜色的点数$\geq2$,且所有点包含于一个 $\bigwedge$ 形状的链中
方案数和3同理,但计算方法不同
5.该颜色的点数$\geq2$,不存在任何一个路径覆盖所有的点,方案数为 $0$
时间复杂度 $O(nlogn)$
C++代码
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cstdio>
#include<vector>
using namespace std;
typedef long long LL;
const int N = 1000010, M = N * 2;
int h[N], e[M], ne[M], idx;
void add(int a, int b)
{
e[idx] = b;
ne[idx] = h[a];
h[a] = idx ++;
}
int n, m;
int w[N];
vector<int> col[N];
int depth[N], fa[N][18];
int f[N];
void dfs(int u, int last)
{
f[u] = 1;
for(int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if(j == last) continue;
depth[j] = depth[u] + 1;
fa[j][0] = u;
for(int k = 1; k < 18; k ++)
fa[j][k] = fa[fa[j][k - 1]][k - 1];
dfs(j, u);
f[u] += f[j];
}
}
int lca(int a, int b)
{
if(depth[a] < depth[b]) swap(a, b);
for(int i = 17; i >= 0; i --)
if(depth[fa[a][i]] >= depth[b])
a = fa[a][i];
if(a == b) return a;
for(int i = 17; i >= 0; i --)
if(fa[a][i] != fa[b][i])
{
a = fa[a][i];
b = fa[b][i];
}
return fa[a][0];
}
bool cmp(int a, int b)
{
return depth[a] > depth[b];
}
bool check(int u, int a, int b) // u是否在a - b路径上
{
return depth[u] >= depth[lca(a, b)] and (lca(a, u) == u or lca(b, u) == u);
}
int main()
{
memset(h, -1, sizeof h);
scanf("%d", &n);
for(int i = 1; i <= n; i ++)
{
scanf("%d", &w[i]);
col[w[i]].push_back(i);
}
for(int i = 1; i < n; i ++)
{
int a, b;
scanf("%d%d", &a, &b);
add(a, b);
add(b, a);
}
depth[1] = 1;
dfs(1, -1);
for(int i = 1; i <= n; i ++)
{
LL res = 0;
vector<int> &v = col[i];
if(!v.size()) res = (LL)n * (n - 1) / 2;
else if(v.size() == 1) {
int u = v[0];
for(int j = h[u]; ~j; j = ne[j])
{
int k = e[j];
if(f[k] > f[u]) res += (LL)(n - f[u]) * (f[u] - 1);
else res += (LL)f[k] * (n - f[k] - 1);
}
res /= 2;
res += n - 1;
}
else
{
sort(v.begin(), v.end(), cmp);
int flag = 1; // 0不在一条链上,1所有点都是v[0]的祖先节点,2所有点在一条^型的链上
int a = v[0], b = 0;
for(int i : v)
{
int x = lca(a, i);
if(!b)
{
if(x != i) // i不是a的祖先节点
{
b = i; // b是[不是a的祖先节点的深度最深的节点]
flag = 2;
}
}
else if(!check(i, a, b)) // i不在a - b的路径上
{
flag = 0;
break;
}
}
if(flag == 1)
{
b = v.back(); // b是深度最小的节点
for(int i = h[b]; ~i; i = ne[i])
{
int j = e[i];
if(depth[j] > depth[b] and lca(a, j) == j) // j是a所在子树的祖先节点
{
b = j;
break;
}
}
res = (LL)f[a] * (n - f[b]);
}
else if(flag == 2)
res = (LL)f[a] * f[b];
}
printf("%lld\n", res);
}
return 0;
}