发现 $t$ 可能 $\lt 0$,即 $st$ 不再单调递增。
此时就需要维护整个下凸壳,而不是仅仅维护满足条件的一部分。因为弹出队头之后可能还要用到它。
然后在下凸壳上二分求一个点,使得它左边线段的斜率 $\lt st_i + S$,右边线段斜率 $\geq st_i + S$。
于是这道题就被秒了。
愉快地写完二分之后,发现它只过了一半的测试点。
#include <bits/stdc++.h>
using namespace std;
const int N = 3e5 + 15;
int n, S, c[N], t[N];
int sc[N], st[N];
long long dp[N];
double x(int p) { return sc[p]; }
double y(int p) { return dp[p]; }
double slope(int a, int b) {
return ((y(a) - y(b)) * 1.0 / (x(a) - x(b))) - S;
}
int q[N], l, r;
int main() {
scanf("%d%d", &n, &S);
for (int i = 1; i <= n; i++) scanf("%d%d", &t[i], &c[i]);
for (int i = 1; i <= n; i++) st[i] = st[i - 1] + t[i], sc[i] = sc[i - 1] + c[i];
memset(dp, 0x3f, sizeof dp);
dp[0] = 0;
l = r = 1; q[l] = 0;
for (int i = 1; i <= n; i++) {
// while (l < r && slope(q[l], q[l + 1]) <= st[i]) l++; //需要维护整个凸壳,所以不能弹出队头
int L = l, R = r;
while (L < R) {
int mid = L + R >> 1;
if (slope(q[mid], q[mid + 1]) <= st[i]) L = mid + 1;
else R = mid;
}
int j = q[L];
dp[i] = min(dp[i], dp[j] + (sc[i] - sc[j]) * 1ll * st[i] + (sc[n] - sc[j]) * 1ll * S);
while (l < r && slope(q[r - 1], q[r]) >= slope(q[r], i)) r--; //维护下凸壳
q[++r] = i;
}
printf("%lld\n", dp[n]);
return 0;
}
发现出题人善良地卡了精度,所以需要改成交叉相乘的形式。
然后它就过了。
#include <bits/stdc++.h>
using namespace std;
const int N = 3e5 + 15;
int n, S, c[N], t[N];
int sc[N], st[N];
long long dp[N];
long double x(int p) { return sc[p]; }
long double y(int p) { return dp[p]; }
long double slope(int a, int b) {
return ((y(a) - y(b)) * 1.0 / (x(a) - x(b))) - S;
}
int q[N], l, r;
int main() {
scanf("%d%d", &n, &S);
for (int i = 1; i <= n; i++) scanf("%d%d", &t[i], &c[i]);
for (int i = 1; i <= n; i++) st[i] = st[i - 1] + t[i], sc[i] = sc[i - 1] + c[i];
memset(dp, 0x3f, sizeof dp);
dp[0] = 0;
l = r = 1; q[l] = 0;
for (int i = 1; i <= n; i++) {
int L = l, R = r;
while (L < R) {
int mid = L + R >> 1;
int a = q[mid], b = q[mid + 1];
if (y(a) - y(b) >= (st[i] + S) * 1ll * (x(a) - x(b))) L = mid + 1;
else R = mid;
}
int j = q[L];
dp[i] = min(dp[i], dp[j] + (sc[i] - sc[j]) * 1ll * st[i] + (sc[n] - sc[j]) * 1ll * S);
while (l < r) {
int a = q[r - 1], b = q[r], c = i;
if ((y(a) - y(b)) * 1ll * (x(b) - x(c)) >= (y(b) - y(c)) * 1ll * (x(a) - x(b))) r--;
else break;
}
q[++r] = i;
}
printf("%lld\n", dp[n]);
return 0;
}