试题 C: 蚂蚁开会
时间限制: 1.0s 内存限制: 256.0MB 本题总分:10 分
错误思路:
三层循环,遍历所有可能的点,每个点逐个判断是否有超过两个线段经过,然后记录个数。
时间复杂度O(mx*my*n)
, 最坏为5 × 10^10
#include<iostream>
#include<bits/stdc++.h>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
const int N = 505;
int n, m;
double ux[N], uy[N], vx[N], vy[N];
//检查点(x,y)是否在第i条线段上
bool check(int x, int y, int i)
{
int mmx = max(ux[i], vx[i]);
int mix = min(ux[i], vx[i]);
int mmy = max(uy[i], vy[i]);
int miy = min(uy[i], vy[i]);
if (x < mix || x > mmx || y < miy || y > mmy) return false;
return (uy[i] - y) * (vx[i] - x) == (vy[i] - y) * (ux[i] - x);
}
void solve()
{
cin >> n;
int mx = 0, my = 0; //记录范围
for (int i = 0; i < n; i++)
{
cin >> ux[i] >> uy[i] >> vx[i] >> vy[i];
if (mx < ux[i]) mx = ux[i];
if (mx < vx[i]) mx = vx[i];
if (my < uy[i]) my = uy[i];
if (my < vy[i]) my = vy[i];
}
// 暴力遍历判断
int ans = 0;
for (int i = 0; i <= mx; i++) {
for (int j = 0; j <= my; j++) {
int cnt = 0;
for (int k = 0; k < n; k++) {
if (check(i, j, k)) {
cnt++;
}
if (cnt >= 2) {
ans++;
break;
}
}
}
}
cout << ans << endl;
}
int main()
{
solve();
return 0;
}
优化思路:
不再遍历所有可能点,而是遍历每条线段上的整数点(关键),将其出现的次数存到哈希表中,再遍历哈希表统计交点的数量。
时间复杂度最坏为O(n*max(x,y)), 5×10^7
#include<iostream>
#include<cstring>
#include<algorithm>
#include<map>
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
const int N = 505;
int n, m;
int ux[N], uy[N], vx[N], vy[N];
map<PII, int> cnt_mp;
int gcd(int a, int b)
{
return b ? gcd(b, a % b) : a; //当b为0时,a和0的最大公约数即为a,故返回a
}
// 记录每条线段经过的整数点
void count(int i)
{
int x1 = ux[i], x2 = vx[i];
int y1 = uy[i], y2 = vy[i];
int dx = x2 - x1, dy = y2 - y1;
int d = gcd(abs(dx), abs(dy));
dx = dx / d, dy = dy / d;
for (int i = 0; ; i++) {
int x = x1 + i * dx, y = y1 + i * dy;
cnt_mp[{x, y}]++; // 第i条线段经历的所有整数点的次数存到 map 中
if (x == x2 && y == y2) break;
}
}
void solve()
{
cin >> n;
int mx = 0, my = 0;
for (int i = 0; i < n; i++) {
cin >> ux[i] >> uy[i] >> vx[i] >> vy[i];
}
// 遍历每条线段经过的整数点
for (int i = 0; i < n; i++) {
count(i);
}
// 遍历 map 统计交点数量
int ans = 0;
for (map<PII, int>::iterator it = cnt_mp.begin(); it != cnt_mp.end(); it++) {
if (it->second >= 2) ans++;
}
// 输出答案
cout << ans << endl;
}
int main()
{
solve();
return 0;
}