朴素版Prim
时间复杂度:$O(n^{2})$
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <map>
#include <set>
#include <queue>
#define rint register int
#define ll long long
using namespace std;
// Capacity bounds: vertices are numbered 1..n and n must be < N. M is unused in this dense version.
constexpr int N = 510, M = 2e4 + 10;
// Adjacency matrix: graph[a][b] is the weight of the edge between a and b.
// Callers must fill every entry with 0x3f3f3f3f ("no edge"), e.g. memset(graph, 0x3f, ...).
int graph[N][N];
// dist[v]: weight of the lightest edge joining v to the tree built so far.
int dist[N];
// state[v]: true once v has been added to the tree.
bool state[N];
using pii = std::pair<int, int>;

// Dense Prim's algorithm on `graph`, rooted at vertex 1; O(n^2).
// @param n  number of vertices (1..n).
// @return {1, total MST weight} if every vertex 1..n is connected to vertex 1,
//         otherwise {0, 0}.
// Overwrites the global `dist` and `state` arrays.
pii prim(int n){
    // Value produced by the 0x3f byte pattern; means "not reachable yet".
    constexpr int kInf = 0x3f3f3f3f;
    std::memset(state, false, sizeof state);
    std::memset(dist, 0x3f, sizeof dist);
    dist[1] = 0;
    int sum = 0;
    // Plain `int` loop variables: `register` (the old `rint` macro) is ill-formed since C++17.
    for (int k = 0; k < n; ++k) {
        // Pick the unvisited vertex closest to the tree (first one wins on ties).
        int nx = -1;
        for (int i = 1; i <= n; ++i) {
            if (!state[i] && (nx == -1 || dist[i] < dist[nx])) nx = i;
        }
        state[nx] = true;
        // The closest remaining vertex is unreachable, so the graph is disconnected.
        if (dist[nx] == kInf) return {0, 0};
        sum += dist[nx];
        // Relax distances of the remaining vertices through the newly added vertex.
        for (int i = 1; i <= n; ++i) {
            if (!state[i]) dist[i] = std::min(dist[i], graph[nx][i]);
        }
    }
    return {1, sum};
}
int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
int n, m; cin >> n >> m;
memset(graph, 0x3f, sizeof graph);
while(m -- ){
int a, b, c; cin >> a >> b >> c;
graph[a][b] = graph[b][a] = min(graph[a][b], c);
}
pii res = prim(n);
if(!res.first) cout << "impossible" << endl;
else cout << res.second << endl;
return 0;
}
堆优化版Prim
时间复杂度:$O(m\log m)$
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <map>
#include <set>
#include <queue>
#define rint register int
#define ll long long
using namespace std;
// Capacity bounds: vertices are numbered 1..n (n < N); M bounds the number of directed edge slots.
constexpr int N = 510, M = 2e5 + 10;
// Array-based adjacency list. h[v] is the index of v's most recently added edge (-1 = none);
// for edge slot i: e[i] is the target, w[i] the weight, ne[i] the next edge from the same source.
// idx is the next free slot. Callers must fill h with -1 before adding edges.
int h[N], e[M], ne[M], w[M], idx;
using pii = std::pair<int, int>;

// Appends the directed edge a -> b with weight c (call twice for an undirected edge).
// No capacity check: at most M edges may be added in total.
void add(int a, int b, int c){
    e[idx] = b;
    w[idx] = c;
    ne[idx] = h[a];
    h[a] = idx++;
}

// state[v]: true once v has been added to the tree.
bool state[N];
// dist[v]: weight of the lightest known edge joining v to the tree.
int dist[N];

// Prim's algorithm with a binary heap, rooted at vertex 1; O(m log m) for m stored edges.
// @param n  number of vertices (1..n).
// @return {1, total MST weight} if every vertex 1..n is reachable from vertex 1, otherwise {0, 0}.
// Overwrites the global `dist` and `state` arrays.
pii heap_prim(int n){
    // Value produced by the 0x3f byte pattern; means "not reachable".
    constexpr int kInf = 0x3f3f3f3f;
    // Min-heap of {candidate edge weight, vertex}.
    std::priority_queue<pii, std::vector<pii>, std::greater<pii>> heap;
    std::memset(dist, 0x3f, sizeof dist);
    std::memset(state, false, sizeof state);
    dist[1] = 0;
    heap.push({0, 1});
    while (!heap.empty()) {
        const int ver = heap.top().second;
        heap.pop();
        // Lazy deletion: a vertex may be queued several times; only its first pop counts.
        if (state[ver]) continue;
        state[ver] = true;
        // Plain `int` loop variable: `register` (the old `rint` macro) is ill-formed since C++17.
        for (int i = h[ver]; i != -1; i = ne[i]) {
            const int j = e[i];
            // Once j is in the tree its dist[] is final, so it becomes the weight of j's tree edge.
            if (!state[j] && dist[j] > w[i]) {
                dist[j] = w[i];
                heap.push({dist[j], j});
            }
        }
    }
    int sum = 0;
    for (int i = 1; i <= n; ++i) {
        if (dist[i] == kInf) return {0, 0};  // vertex i never reached: disconnected
        sum += dist[i];
    }
    return {1, sum};
}
int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
int n, m; cin >> n >> m;
memset(h, -1, sizeof h);
while(m -- ){
int a, b, c; cin >> a >> b >> c;
add(a, b, c), add(b, a, c);
}
pii res = heap_prim(n);
if(res.first) cout << res.second << endl;
else cout << "impossible" << endl;
return 0;
}