题意:给定一棵树,每个结点都有颜色,问有多少个连通块,使得连通块中的其中一个颜色的出现次数严格大于该连通块的一半。
解法:树上DP
我们一个颜色一个颜色解决即可,假设当前颜色为c,那么树上的结点就变成了1和0(要不就是颜色c,要不就不是),容易想出一个DP设计 $dp[i][j]$ 表示以i为根的子树中,和为j的连通块个数。答案就是所有j大于0的个数,这个问题可以转化为树上背包。
但是这个的复杂度是 $O(n^3)$,所以我们想办法优化一下。注意到枚举的时候,如果非当前色的个数过多可以提前结束枚举,考虑到这个性质,我们只需要枚举的时候定下范围即可,时间复杂度可以平摊到 $O(n^2)$。
#include <bits/stdc++.h>
#define endl '\n'
#define ls u << 1
#define rs u << 1 | 1
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
typedef pair<LL,LL> PLL;
const int INF = 0x3f3f3f3f, N = 3030;
const int MOD = 998244353;
const double eps = 1e-6;
const double PI = acos(-1);
inline int lowbit(int x) {return x & (-x);}
int a[N], cnt[N], sz[N];
int n;
LL dp[N][N << 1], f[N << 1];
LL res;
vector<int> g[N];
void dfs(int u, int last, int col) {
sz[u] = 1;
if (col == a[u]) dp[u][n + 1] = 1;
else dp[u][n - 1] = 1;
for (int v : g[u]) {
if (v == last) continue;
dfs(v, u, col);
//定义上下界,但凡范围大一点都会T
int bor1 = min(sz[u], cnt[col]), bor2 = min(sz[v], cnt[col]);
for (int i = -bor1; i <= bor1; i ++ ) f[i + n] = dp[u][i + n];
for (int i = -bor1; i <= bor1; i ++ ) {
for (int j = -bor2; j <= bor2; j ++ ) {
dp[u][i + j + n] += dp[v][j + n] * f[i + n] % MOD;
dp[u][i + j + n] %= MOD;
}
}
sz[u] += sz[v];
}
for (int i = 1; i <= cnt[col]; i ++ ) res = (res + dp[u][i + n]) % MOD;
}
inline void solve() {
cin >> n;
for (int i = 1; i <= n; i ++ ) {
cin >> a[i];
cnt[a[i]] ++;
}
for (int i = 1; i < n; i ++ ) {
int u, v; cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
for (int i = 1; i <= n; i ++ ) {
if (cnt[i] == 0) continue;
for (int u = 1; u <= n; u ++ ) for (int j = -cnt[i]; j <= cnt[i]; j ++ )
dp[u][j + n] = 0;
dfs(1, -1, i);
}
cout << res << endl;
}
int main() {
#ifdef DEBUG
freopen("in.txt", "r", stdin);
freopen("out.txt", "w", stdout);
auto now = clock();
#endif
ios::sync_with_stdio(false), cin.tie(nullptr);
cout << fixed << setprecision(2);
// int T; cin >> T;
// while (T -- )
solve();
#ifdef DEBUG
cout << "============================" << endl;
cout << "Program run for " << (clock() - now) / (double)CLOCKS_PER_SEC * 1000 << " ms." << endl;
#endif
return 0;
}