题目描述
四平方和定理,又称为拉格朗日定理:
每个正整数都可以表示为至多 4 个正整数的平方和。
如果把 0 包括进去,就正好可以表示为 4 个数的平方和。
比如:
$5$$=$$0^2$+$0^2$+$1^2$+$2^2$
$7$$=$$1^2$+$1^2$+$1^2$+$2^2$
对于一个给定的正整数,可能存在多种平方和的表示法。
要求你对 4 个数排序:
$0≤a≤b≤c≤d$
并对所有的可能表示法按 a,b,c,d 为联合主键升序排列,最后输出第一个表示法。
输入样例
5
输出样例
0 0 1 2
算法1(暴力枚举) $O(n^3)$(y总加强了数据,这个过不了,下面哈希能够过)
这里我们可以稍微做点小小的优化,我们可以制作平方和表。题目数据要求$N<=5*10^6$,而$230^2=5290000$,所以我们可以开一个$2300$大小的数组先存储其平方和,然后暴力枚举。
C++ 代码
#include <iostream>
#include <cmath>
using namespace std;
int a[2305];
int main()
{
for (int i = 0; i <= 2300; i++)
{
a[i] = i * i;
}
int n;
cin >> n;
for (int i = 0; i < 1000; i++)
{
for (int j = i; j < 1000; j++)
{
for (int k = j; k < 2300; k++)
{
int x = n - a[i] - a[j] - a[k];
if (x == (int)sqrt(x) * (int)sqrt(x) && sqrt(x) >= k)// x是完全平方,且sqrt(x) >= k
{
cout << i << " " << j << " " << k << " " << sqrt(x) << endl;
return 0;
}
}
}
}
return 0;
}
算法2(哈希优化) $O(n^2)$
第一种做法是枚举$a,b,c$,然后计算出d,判断d是否是完全平方。那么我们要优化到$O(n^2)$,那么怎么做呢?这里我们可以枚举$c,d$,然后判断$a^2+b^2=n-c^2+d^2$是否有解。具体做法是使用unordred_map,建立一个从$c^2+d^2$到$c$的映射,代码实现如下。
C++ 代码
#include <iostream>
#include <unordered_map>
#include <cmath>
using namespace std;
unordered_map<int, int> mmp;
int main()
{
int n;
cin >> n;
// 预处理
for (int c = 0; c * c <= n / 2; c++)
{
for (int d = c; c * c + d * d <= n; d++)
{
/*c^2+d^2的值在mmp中找不到,说明之前已经存在过。
如果之前已经存在了,那么我们就不需要再预处理了。
*/
if(mmp.find(c * c + d * d) == mmp.end())
{
mmp[c * c + d * d] = c;
}
}
}
for (int a = 0; a * a <= n / 4; a++)
{
for (int b = a; a * a + b * b <= n / 2; b++)
{
if (mmp.find(n - a * a - b * b) != mmp.end())
{
/* 这里n - a * a - b * b 是计算出 c ^ c + d ^ d的值,然后
在mmp中查找是否有与之对应的 c 和 d
*/
int c = mmp[n - a * a - b * b];
int d = sqrt(n - a * a - b * b - c * c);
cout << a << " " << b << " " << c << " " << d << endl;
return 0;// 找到一个解就退出程序
}
}
}
return 0;
}
不过不知道为什么这种方法的运行时间要比第一种长
算法3(二分)
C++代码
#include <iostream>
#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 2500010;
struct Sum{
int s, c, d;
bool operator < (const Sum &t) const
{
if (s != t.s) return s < t.s;
if (c != t.c) return c < t.c;
return d < t.d;
}
}sum[N];
int n, m;
int main()
{
cin >> n;
for (int c = 0; c * c <= n; c++)
{
for (int d = c; c * c + d * d <= n; d++)
{
sum[m++] = {c * c + d * d, c, d};
}
}
sort(sum, sum + m);
for (int a = 0; a * a <= n; a++)
{
for (int b = a; a * a + b * b <= n; b++)
{
int t = n - a * a - b * b;
int l = 0, r = m - 1;
while (l < r)
{
int mid = l + r >> 1;
if (sum[mid].s >= t)
{
r = mid;
}
else
{
l = mid + 1;
}
}
if (sum[l].s == t)
{
printf("%d %d %d %d\n", a, b, sum[l].c, sum[l].d);
return 0;
}
}
}
return 0;
}