Tree
题意:
给定n个节点的一棵树,求出包含每个点的连通点集的数量。
思路:
首先分析题意,可以发现满足递推关系,可树形DP求解。
f[i] :以i为根的所有子树且必选i号点的连通点集的所有合法集合。
属性:cnt
针对当前的节点u,其子节点为son且不考虑其他节点。
$$f[i] = \prod \limits_{son}(f[son] + 1)$$
up[i]: 以i为根的所有子树外且必选i号点的联通点集的所有合法集合
属性: cnt
假设当前节点为sta,其子节点为son
$$up[son] = (up[sta] * \frac{f[sta]}{f[son] + 1}) + 1 $$
小细节
因为涉及到了除法取模,且模数为质数。
(a / b) % mod = a * ksm(b, mod - 2) % mod
(仅当gcd(b, mod) == 1成立)
观察式子,可能出现 (f[son] + 1) % mod == 0,这时候就无法进行取模了,同样可以把直接暴力求解把式子改成乘法.(这种情况非常少,直接暴力求即可)
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <vector>
#include <map>
#include <set>
#include <queue>
#include <deque>
#include <stack>
#include <bitset>
#include <unordered_map>
#define IOS ios::sync_with_stdio(false),cin.tie(0),cout.tie(0)
#define eb push_back()
using namespace std;
typedef long long ll;
const int INF = 0x3f3f3f3f;
const int maxn = 1e6 + 10;
typedef pair <int, int> PII;
int h[maxn], ne[maxn * 2], e[maxn * 2], idx, pre[maxn];
ll f[maxn], up[maxn];
const ll mod = 1e9 + 7;
ll get_mod1(ll a, ll b)
{
return (a % mod + b % mod) % mod;
}
ll get_mod2(ll a, ll b)
{
return (a % mod * (b % mod)) % mod;
}
ll ksm(ll base, ll power)
{
ll res = 1;
while(power)
{
if(power & 1)
res = get_mod2(res, base);
base = get_mod2(base, base);
power >>= 1;
}
return res;
}
void add(int u, int v)
{
e[idx] = v;
ne[idx] = h[u];
h[u] = idx ++;
}
void dfs1(int sta, int fa)
{
f[sta] = 1, pre[sta] = fa;
for(int i = h[sta] ; i != -1 ; i = ne[i])
{
int son = e[i];
if(son == fa) continue;
dfs1(son, sta);
f[sta] = get_mod2(f[sta], f[son] + 1);
}
}
void dfs2(int sta, int fa)
{
if((f[sta] + 1ll) % mod == 0) //up[sta]
{
ll temp = 1;
for(int i = h[fa] ; i != -1 ; i = ne[i])
{
int son = e[i];
if(son == pre[fa] || son == sta) continue;
temp = get_mod2(temp, f[son] + 1ll);
}
up[sta] = get_mod1(get_mod2(up[fa], temp), 1ll);
}
else
up[sta] = get_mod1(get_mod2(up[fa], get_mod2(f[fa], ksm(f[sta] + 1ll, mod - 2))), 1ll);
for(int i = h[sta] ; i != -1 ; i = ne[i])
{
int son = e[i];
if(son == fa) continue;
dfs2(son, sta);
}
}
int main()
{
int n;
scanf("%d", &n);
memset(h, -1, sizeof(h)), idx = 0;
for(int i = 1 ; i < n ; i ++)
{
int u, v;
scanf("%d %d", &u, &v);
add(u, v), add(v, u);
}
dfs1(1, 1);
dfs2(1, 1);
for(int i = 1 ; i <= n ; i ++)
{
printf("%lld\n", get_mod2(f[i], up[i]));
}
return 0;
}