倍增 90tps
当确定左端点L后,倍增找到右端点R,满足区间[L, R]最长并且校验值 <= k
统计段数即可
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
const int N = 5e5 + 10;
int n, m;
LL k, P[N], tmp[N];
bool check(int l, int r)
{
LL sum = 0;
for (int i = l; i <= r; i++)
tmp[i - l] = P[i];
sort(tmp, tmp + r - l + 1);
for (int i = 0, j = r - l, k = 1; i < j && k <= m; i++, j--, k++)
sum += pow(tmp[i] - tmp[j], 2);
return sum <= k;
}
int solve()
{
if (n == 1)
return 1;
int ans = 0;
for (int l = 1; l <= n; )
{
int p = 1, r = l;
while(p != 0)
{
if (r + p <= n && check(l, r + p))
{
r += p, p *= 2;
}
else
p /= 2;
}
ans++, l = r + 1;
}
return ans;
}
int main() {
cin.tie(nullptr)->sync_with_stdio(false);
int t;
cin >> t;
while (t--)
{
cin >> n >> m >> k;
for (int i = 1; i <= n; i++)
cin >> P[i];
cout << solve() << '\n';
}
return 0;
}
倍增 + 归并优化 100pts
每次倍增求校验值时,可以不用对完整的部分进行快速排序,而只需要采用类似归并排序的方法,只对新增的长度部分排序,然后合并新旧两段
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
const int N = 5e5 + 10;
int n, m;
LL k, P[N], tmp[N], a[N]; // tmp存排序数组 a存合并的数组
bool check(int l, int r, int p)
{
// 之前[l, r]已经排好了 新的要排的区间[r + 1, r + p]
for (int i = r + 1; i <= r + p; i++)
tmp[i] = P[i];
sort(tmp + r + 1, tmp + r + p + 1);
// 合并tmp[l, r]和tmp[r + 1, r + p]到a[0, r + p - l]
int cnt = 0, i = l, j = r + 1;
while (i <= r && j <= r + p)
{
if (tmp[i] <= tmp[j])
a[cnt++] = tmp[i++];
else
a[cnt++] = tmp[j++];
}
while (i <= r)
a[cnt++] = tmp[i++];
while (j <= r + p)
a[cnt++] = tmp[j++];
LL sum = 0;
for (int i = 0, j = cnt - 1, t = 1; t <= m && i < j; i++, j--, t++)
sum += pow(a[j] - a[i], 2);
return sum <= k;
}
int solve()
{
if (n == 1)
return 1;
int ans = 0;
for (int l = 1; l <= n; )
{
int p = 1, r = l;
tmp[l] = P[l]; // 一开始P[l]本身自己已经排好序了 第一次倍增会先计算区间[l, l+1]的校验值
// check里已经默认[l, r]已经有序了 即[l, r]已经有序了 所以要把P[l]放到tmp[l]里,不然tmp[l]存的是0
while(p != 0)
{
if (r + p <= n && check(l, r, p)) {
for (int i = l; i <= r + p; i++)
tmp[i] = a[i - l];
r += p, p *= 2;
}
else
p /= 2;
}
ans++, l = r + 1;
}
return ans;
}
int main() {
cin.tie(nullptr)->sync_with_stdio(false);
int t;
cin >> t;
while (t--)
{
cin >> n >> m >> k;
for (int i = 1; i <= n; i++)
cin >> P[i];
cout << solve() << '\n';
}
return 0;
}