题目描述
自己看题去,反正很短(bushi
样例
$\texttt{Input:}$
2
5 1 49
8 2 1 7 9
5 1 64
8 2 1 7 9
$\texttt{Output:}$
2
1
分析
计算校验值
首先考虑校验值怎么计算。
显然大家可以感觉到应该大小配对,那么我们可以使用邻项交换法证明。
设有四个数 $a$,$b$,$c$,$d$ 满足 $a \le b \le c \le d$,只需证明 $(d-a)^2+(c-b)^2 \ge (d-b)^2+(c-a)^2$ 即可。我们变形一下:
$(d-a)^2+(c-b)^2=a^2+b^2+c^2+d^2-2ad-2bc$,$(d-b)^2+(c-a)^2=a^2+b^2+c^2+d^2-2bd-2ac$
只需比较 $ad+bc$ 和 $ac+bd$ 即可。
作差:
$ad+bc-ac-bd=a(d-c)+b(c-d)=(a-b)(d-c)\le 0$,所以 $ad+bc\le ac+bd$,故 $(d-a)^2+(c-b)^2 \ge (d-b)^2+(c-a)^2$。
证毕。
计算答案
这里很容易想到让每一段区间尽量长,这样段数就会尽量少。显然区间越长,这个区间的校验值就越大,所以我们可以考虑二分 or 倍增。
这道题倍增更好。感性理解一下。如果区间长度只能是 2 的话,二分会排序 $\log_2 n$ 次,时间复杂度为 $O(n \log_2^2 n)$,不够优秀。倍增的话,采用走 $2^k$ 步,如果能成功就加大步伐,也就是 $k \gets k+1$,否则 $k \gets k-1$ 的策略。如果区间长度比较短,那么很快就可以得到答案;如果区间长度长,那么尽管找这一个区间的时间长,但是剩下的长度较短,总时间也比较短,平均一下还是倍增比较快。而且倍增如果成功,只需要排序新的区间再合并,没有必要排序整个区间。
解决
需要注意每次扩展的时候要还原数组:新扩展的区域还原成原序列,旧的还原成排好序后的序列。
这道题时间限制 10s,但是我们可以在 1s 内解决。
首先很容易写出这样的代码(好吧我写了一个小时):
#include <iostream>
#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 500005;
int n, m, K, a[N], backup[N], ordered[N], tmp[N];
long long T;
// Merges the two adjacent sorted runs a[l..mid] and a[mid+1..r] into a single
// sorted run a[l..r], using tmp[l..r] as scratch space.
// When both heads are equal the element from the left run is taken first.
inline void merge(int l, int mid, int r) {
    int left = l;
    int right = mid + 1;
    for (int out = l; out <= r; ++out) {
        // The right run being exhausted is tested first so a[right] is never
        // read past r.
        const bool take_left = right > r || (left <= mid && a[left] <= a[right]);
        tmp[out] = take_left ? a[left++] : a[right++];
    }
    for (int pos = l; pos <= r; ++pos) {
        a[pos] = tmp[pos];
    }
}
// Returns true when the checksum of a[l..r] does not exceed T.
// The checksum pairs a[l] with a[r], a[l+1] with a[r-1], and so on, for at
// most m pairs, and sums the squared differences. The range is expected to be
// sorted already (expand() calls this right after merge()).
inline bool check(int l, int r) {
    // A range of len elements offers len / 2 disjoint pairs; m caps that.
    const int pairs = std::min(m, (r - l + 1) / 2);
    long long sum = 0;
    for (int p = 0; p < pairs; ++p) {
        const long long diff = static_cast<long long>(a[l + p]) - a[r - p];
        sum += diff * diff;
        if (sum > T) {
            return false;  // stop as soon as the limit is exceeded
        }
    }
    return true;
}
// Tries to grow the accepted segment [st, ed] by the next dis elements.
// Returns true (and records the new sorted segment in ordered[]) when the
// grown segment still passes check(); returns false otherwise, leaving
// ordered[] untouched.
inline bool expand(int st, int ed, int dis) {
    const int new_ed = ed + dis;

    // Rebuild a[]: the accepted part comes back in sorted form, the candidate
    // tail comes back in original input order.
    for (int pos = st; pos <= ed; ++pos) {
        a[pos] = ordered[pos];
    }
    for (int pos = ed + 1; pos <= new_ed; ++pos) {
        a[pos] = backup[pos];
    }

    // Only the new tail needs sorting; one merge then sorts the whole segment.
    std::sort(a + ed + 1, a + new_ed + 1);
    merge(st, ed, new_ed);

    if (!check(st, new_ed)) {
        return false;
    }
    for (int pos = st; pos <= new_ed; ++pos) {
        ordered[pos] = a[pos];
    }
    return true;
}
// Splits a[1..n] into consecutive segments, extending each segment as far as
// expand() accepts, and returns the number of segments produced.
// The input is first saved to backup[] because expand() rewrites a[].
inline int solve() {
    for (int pos = 1; pos <= n; ++pos) {
        backup[pos] = a[pos];
    }

    int ans = 0;
    int st = 1;
    while (st <= n) {
        int level = -1;     // current step is 2^level elements (no step at -1)
        int reach = st - 1; // last index already accepted into this segment
        for (;;) {
            if (reach == n) {
                break;
            }
            // The very first trial uses a zero-length step; it always passes
            // because an empty range has no pairs to sum.
            int step = (level == -1) ? 0 : (1 << level);
            step = std::min(step, n - reach);
            if (expand(st, reach, step)) {
                reach += step;
                ++level;  // success: double the next step
            } else {
                --level;  // failure: halve the next step
            }
            if (level == -1) {
                break;    // even a one-element step failed
            }
        }
        ++ans;
        st = reach + 1;
    }
    return ans;
}
// Reads the next non-negative decimal integer from stdin into ret, skipping
// any non-digit characters in front of it.
// If the input ends before a digit is found, ret is left at 0 and the
// function returns instead of looping forever.
template <class Int>
inline void read(Int& ret) {
    ret = 0;
    // getchar() returns int; keeping it as int lets EOF be told apart from
    // every real character.
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        ch = getchar();
    }
    // EOF is negative, so it also ends the digit loop.
    while (ch >= '0' && ch <= '9') {
        ret = ret * 10 + (ch - '0');
        ch = getchar();
    }
}
// Reads K test cases. Each case gives n, m and T followed by n values, which
// are stored 1-based in a[]; the result of solve() is printed on its own line.
int main() {
    read(K);
    while (K--) {
        read(n);
        read(m);
        read(T);
        for (int idx = 1; idx <= n; ++idx) {
            read(a[idx]);
        }
        const int segments = solve();
        printf("%d\n", segments);
    }
    return 0;
}
但是要跑 1.14s,太慢了!(尽管可以通过这道题了)
我们注意到 $a$ 的范围不大,所以我们在排序的时候考虑比较 sort 和计数排序的效率,使用更快的那一个。代码:
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <assert.h>
using namespace std;
const int N = 500005, SIZE = (1 << 20) + 5;
int n, m, K, a[N], backup[N], ordered[N], tmp[N], bucket[SIZE];
long long T;
// Merges the two adjacent sorted runs a[l..mid] and a[mid+1..r] into a single
// sorted run a[l..r], using tmp[l..r] as scratch space.
// When both heads are equal the element from the left run is taken first.
inline void merge(int l, int mid, int r) {
    int left = l;
    int right = mid + 1;
    for (int out = l; out <= r; ++out) {
        // The right run being exhausted is tested first so a[right] is never
        // read past r.
        const bool take_left = right > r || (left <= mid && a[left] <= a[right]);
        tmp[out] = take_left ? a[left++] : a[right++];
    }
    for (int pos = l; pos <= r; ++pos) {
        a[pos] = tmp[pos];
    }
}
// Returns true when the checksum of a[l..r] does not exceed T.
// The checksum pairs a[l] with a[r], a[l+1] with a[r-1], and so on, for at
// most m pairs, and sums the squared differences. The range is expected to be
// sorted already (expand() calls this right after merge()).
inline bool check(int l, int r) {
    // A range of len elements offers len / 2 disjoint pairs; m caps that.
    const int pairs = std::min(m, (r - l + 1) / 2);
    long long sum = 0;
    for (int p = 0; p < pairs; ++p) {
        const long long diff = static_cast<long long>(a[l + p]) - a[r - p];
        sum += diff * diff;
        if (sum > T) {
            return false;  // stop as soon as the limit is exceeded
        }
    }
    return true;
}
// Sorts a[st..ed] in ascending order, picking whichever of std::sort and a
// counting sort is estimated to be cheaper for this range.
// The counting sort indexes bucket[] by value, so values must lie inside
// bucket's bounds; bucket[] is left all-zero again on return.
void mysort(int st, int ed) {
    if (st > ed) {
        return;
    }
    const int len = ed - st + 1;

    // Find the value span of the range.
    int lowest = 1 << 20;
    int highest = 0;
    for (int pos = st; pos <= ed; ++pos) {
        lowest = std::min(lowest, a[pos]);
        highest = std::max(highest, a[pos]);
    }

    // Comparison sort costs about len * log2(len); counting sort has to walk
    // the whole value span. Use std::sort when the span is the larger of the two.
    const bool span_is_wide = len * log2(len) <= static_cast<long long>(highest) - lowest;
    if (span_is_wide) {
        std::sort(a + st, a + ed + 1);
        return;
    }

    // Counting sort: tally each value, then write values back in increasing
    // order, clearing each counter as it is consumed.
    for (int pos = st; pos <= ed; ++pos) {
        ++bucket[a[pos]];
    }
    int out = st;
    for (int value = lowest; value <= highest && out <= ed; ++value) {
        for (int copies = bucket[value]; copies > 0; --copies) {
            a[out++] = value;
        }
        bucket[value] = 0;
    }
}
// Tries to grow the accepted segment [st, ed] by the next dis elements.
// Returns true (and records the new sorted segment in ordered[]) when the
// grown segment still passes check(); returns false otherwise, leaving
// ordered[] untouched.
inline bool expand(int st, int ed, int dis) {
    const int new_ed = ed + dis;

    // Rebuild a[]: the accepted part comes back in sorted form, the candidate
    // tail comes back in original input order.
    for (int pos = st; pos <= ed; ++pos) {
        a[pos] = ordered[pos];
    }
    for (int pos = ed + 1; pos <= new_ed; ++pos) {
        a[pos] = backup[pos];
    }

    // Only the new tail needs sorting; one merge then sorts the whole segment.
    mysort(ed + 1, new_ed);
    merge(st, ed, new_ed);

    if (!check(st, new_ed)) {
        return false;
    }
    for (int pos = st; pos <= new_ed; ++pos) {
        ordered[pos] = a[pos];
    }
    return true;
}
// Splits a[1..n] into consecutive segments, extending each segment as far as
// expand() accepts, and returns the number of segments produced.
// The input is first saved to backup[] because expand() rewrites a[].
inline int solve() {
    for (int pos = 1; pos <= n; ++pos) {
        backup[pos] = a[pos];
    }

    int ans = 0;
    int st = 1;
    while (st <= n) {
        int level = -1;     // current step is 2^level elements (no step at -1)
        int reach = st - 1; // last index already accepted into this segment
        for (;;) {
            if (reach == n) {
                break;
            }
            // The very first trial uses a zero-length step; it always passes
            // because an empty range has no pairs to sum.
            int step = (level == -1) ? 0 : (1 << level);
            step = std::min(step, n - reach);
            if (expand(st, reach, step)) {
                reach += step;
                ++level;  // success: double the next step
            } else {
                --level;  // failure: halve the next step
            }
            if (level == -1) {
                break;    // even a one-element step failed
            }
        }
        ++ans;
        st = reach + 1;
    }
    return ans;
}
// Reads the next non-negative decimal integer from stdin into ret, skipping
// any non-digit characters in front of it.
// If the input ends before a digit is found, ret is left at 0 and the
// function returns instead of looping forever.
template <class Int>
inline void read(Int& ret) {
    ret = 0;
    // getchar() returns int; keeping it as int lets EOF be told apart from
    // every real character.
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        ch = getchar();
    }
    // EOF is negative, so it also ends the digit loop.
    while (ch >= '0' && ch <= '9') {
        ret = ret * 10 + (ch - '0');
        ch = getchar();
    }
}
// Reads K test cases. Each case gives n, m and T followed by n values, which
// are stored 1-based in a[]; the result of solve() is printed on its own line.
// Assertions verify that the input fits the array sizes and the bucket range.
int main() {
    read(K);
    while (K--) {
        read(n);
        read(m);
        read(T);
        assert(n <= 5e5 && m <= 5e5 && T <= 1e18);
        for (int idx = 1; idx <= n; ++idx) {
            read(a[idx]);
            // Values index bucket[] in mysort(), so they must stay in [0, 2^20].
            assert(0 <= a[idx] && a[idx] <= (1 << 20));
        }
        const int segments = solve();
        printf("%d\n", segments);
    }
    return 0;
}
快一些了,但是还是要 1.05s。
终极大优化:register!(另外,最终代码还把 getchar 换成了 fread 缓冲读入。注意 register 关键字在 C++17 中已被移除,下面的代码需要用 C++14 及以下标准编译。)
最终代码:
#include <cstdio>
#include <algorithm>
#include <cmath>
using namespace std;
const int N = 500005, SIZE = (1 << 20) + 5;
int n, m, K, a[N], backup[N], ordered[N], tmp[N], bucket[SIZE];
long long T;
// Merges the two adjacent sorted runs a[l..mid] and a[mid+1..r] into a single
// sorted run a[l..r], using tmp[l..r] as scratch space.
// When both heads are equal the element from the left run is taken first.
inline void merge(int l, int mid, int r) {
    register int left = l;
    register int right = mid + 1;
    for (register int out = l; out <= r; ++out) {
        // The right run being exhausted is tested first so a[right] is never
        // read past r.
        const bool take_left = right > r || (left <= mid && a[left] <= a[right]);
        tmp[out] = take_left ? a[left++] : a[right++];
    }
    for (register int pos = l; pos <= r; ++pos) {
        a[pos] = tmp[pos];
    }
}
// Returns true when the checksum of a[l..r] does not exceed T.
// The checksum pairs a[l] with a[r], a[l+1] with a[r-1], and so on, for at
// most m pairs, and sums the squared differences. The range is expected to be
// sorted already (expand() calls this right after merge()).
inline bool check(int l, int r) {
    // A range of len elements offers len / 2 disjoint pairs; m caps that.
    const int pairs = std::min(m, (r - l + 1) / 2);
    long long sum = 0;
    for (register int p = 0; p < pairs; ++p) {
        const long long diff = static_cast<long long>(a[l + p]) - a[r - p];
        sum += diff * diff;
        if (sum > T) {
            return false;  // stop as soon as the limit is exceeded
        }
    }
    return true;
}
// Sorts a[st..ed] in ascending order, picking whichever of std::sort and a
// counting sort is estimated to be cheaper for this range.
// The counting sort indexes bucket[] by value, so values must lie inside
// bucket's bounds; bucket[] is left all-zero again on return.
void mysort(int st, int ed) {
    if (st > ed) {
        return;
    }
    const int len = ed - st + 1;

    // Find the value span of the range.
    int lowest = 1 << 20;
    int highest = 0;
    for (register int pos = st; pos <= ed; ++pos) {
        lowest = std::min(lowest, a[pos]);
        highest = std::max(highest, a[pos]);
    }

    // Comparison sort costs about len * log2(len); counting sort has to walk
    // the whole value span. Use std::sort when the span is the larger of the two.
    const bool span_is_wide = len * log2(len) <= static_cast<long long>(highest) - lowest;
    if (span_is_wide) {
        std::sort(a + st, a + ed + 1);
        return;
    }

    // Counting sort: tally each value, then write values back in increasing
    // order, clearing each counter as it is consumed.
    for (register int pos = st; pos <= ed; ++pos) {
        ++bucket[a[pos]];
    }
    register int out = st;
    for (register int value = lowest; value <= highest && out <= ed; ++value) {
        for (register int copies = bucket[value]; copies > 0; --copies) {
            a[out++] = value;
        }
        bucket[value] = 0;
    }
}
// Tries to grow the accepted segment [st, ed] by the next dis elements.
// Returns true (and records the new sorted segment in ordered[]) when the
// grown segment still passes check(); returns false otherwise, leaving
// ordered[] untouched.
inline bool expand(int st, int ed, int dis) {
    const int new_ed = ed + dis;

    // Rebuild a[]: the accepted part comes back in sorted form, the candidate
    // tail comes back in original input order.
    for (register int pos = st; pos <= ed; ++pos) {
        a[pos] = ordered[pos];
    }
    for (register int pos = ed + 1; pos <= new_ed; ++pos) {
        a[pos] = backup[pos];
    }

    // Only the new tail needs sorting; one merge then sorts the whole segment.
    mysort(ed + 1, new_ed);
    merge(st, ed, new_ed);

    if (!check(st, new_ed)) {
        return false;
    }
    for (register int pos = st; pos <= new_ed; ++pos) {
        ordered[pos] = a[pos];
    }
    return true;
}
// Splits a[1..n] into consecutive segments, extending each segment as far as
// expand() accepts, and returns the number of segments produced.
// The input is first saved to backup[] because expand() rewrites a[].
inline int solve() {
    for (register int pos = 1; pos <= n; ++pos) {
        backup[pos] = a[pos];
    }

    int ans = 0;
    int st = 1;
    while (st <= n) {
        int level = -1;     // current step is 2^level elements (no step at -1)
        int reach = st - 1; // last index already accepted into this segment
        for (;;) {
            if (reach == n) {
                break;
            }
            // The very first trial uses a zero-length step; it always passes
            // because an empty range has no pairs to sum.
            int step = (level == -1) ? 0 : (1 << level);
            step = std::min(step, n - reach);
            if (expand(st, reach, step)) {
                reach += step;
                ++level;  // success: double the next step
            } else {
                --level;  // failure: halve the next step
            }
            if (level == -1) {
                break;    // even a one-element step failed
            }
        }
        ++ans;
        st = reach + 1;
    }
    return ans;
}
// Input buffer used by gc(). The first chunk of stdin is read by p2's
// initializer.
char buf[1 << 21];
char* p1 = buf;                                  // next unread byte
char* p2 = p1 + fread(buf, 1, 1 << 21, stdin);   // one past the last valid byte

// Returns the next byte of stdin, refilling buf from stdin once it has been
// consumed. At end of input it returns EOF converted to char.
inline char gc() {
    if (p1 != p2) {
        return *p1++;
    }
    p1 = buf;
    p2 = buf + fread(buf, 1, 1 << 21, stdin);
    return p1 == p2 ? EOF : *p1++;
}
// Reads the next non-negative decimal integer from the gc() stream into ret,
// skipping any non-digit characters in front of it.
// If the input ends before a digit is found, ret is left at 0 and the
// function returns instead of looping forever.
template <class Int>
inline void read(Int& ret) {
    ret = 0;
    // gc() signals end of input with EOF converted to char, so compare against
    // that same converted value. A real byte with that value would also be
    // treated as end of input, which is harmless for digit/whitespace input.
    const char end_marker = static_cast<char>(EOF);
    char ch = gc();
    while (ch != end_marker && (ch < '0' || ch > '9')) {
        ch = gc();
    }
    // The end marker is not a digit, so it also ends the digit loop.
    while (ch >= '0' && ch <= '9') {
        ret = ret * 10 + (ch - '0');
        ch = gc();
    }
}
// Reads K test cases. Each case gives n, m and T followed by n values, which
// are stored 1-based in a[]; the result of solve() is printed on its own line.
int main() {
    read(K);
    while (K--) {
        read(n);
        read(m);
        read(T);
        for (register int idx = 1; idx <= n; ++idx) {
            read(a[idx]);
        }
        const int segments = solve();
        printf("%d\n", segments);
    }
    return 0;
}
这下可以跑进 1s了!
开了 O2 只需要 353ms。
%%%好文