(二分+线段树)非常巧妙的思维题!
题目描述
在 $2016$ 年,佳媛姐姐喜欢上了数字序列。因而她经常研究关于序列的一些奇奇怪怪的问题,现在她在研究一个难题,需要你来帮助她。
这个难题是这样子的:给出一个 $1$ 到 $n$ 的排列,现在对这个排列序列进行 $m$ 次局部排序,排序分为两种:
0 l r
表示将区间 [$l,r]$ 的数字升序排序1 l r
表示将区间 $[l,r]$ 的数字降序排序
注意,这里是对下标在区间 $[l,r]$ 内的数排序。
最后询问第 $q$ 位置上的数字。
输入格式
输入数据的第一行为两个整数 $n$ 和 $m$,$n$ 表示序列的长度,$m$ 表示局部排序的次数。
第二行为 $n$ 个整数,表示 $1$ 到 $n$ 的一个排列。
接下来输入 $m$ 行,每一行有三个整数 $\text{op},l,r$,为 $0$ 代表升序排序,$\text{op}$ 为 $1$ 代表降序排序, $l,r$ 表示排序的区间。
最后输入一个整数 $q$,表示排序完之后询问的位置
输出格式
输出数据仅有一行,一个整数,表示按照顺序将全部的部分排序结束后第 $q$ 位置上的数字。
输入 #1
6 3
1 6 2 5 3 4
0 1 4
1 3 6
0 2 4
3
输出 #1
5
说明/提示
河北省选2016第一天第二题。
对于 $30\%$ 的数据,$n,m\leq 1000$
对于 $100\%$ 的数据,$n,m\leq 10^5$,$1\leq q\leq n$
分析:
这是一个离线的做法,
因为最后询问的位置只有一个,我们可以假设这个位置最终的数字是$\text{ans}$,
将原序列中大于等于$\text{ans}$的数字变成1,小于的数字变成0,然后每次操作我们相当于对这个01序列排序,最后如果第$p$个位置仍然是1,说明这个答案$ans$可行
如何得到$ans$值呢,可以二分答案,这个答案一定是满足单调性的,对于比$ans$小的答案一定都满足题意,比$ans$大的答案一定都不满足题意,所以二分的正确性得证
接下来就转变成了用线段树维护01序列的问题,
设某区间$[l,r]$一共有$cnt$个1
那么降序排列就是把$l\sim l + cnt -1$修改为1,把$l+cnt\sim r$修改为0
升序排列同理
01序列中有多少个1就是区间求和
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 200005;
struct Node
{
int l, r;
int flag, sum;
}tr[N * 4];
int n, m, x;
int w[N];
int st[N];
struct Query
{
int op, l, r;
}q[N];
void pushup(int u)
{
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void build(int u, int l, int r)
{
tr[u] = {l, r, -1};
if(l == r)
{
tr[u].sum = st[l];
return;
}
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void pushdown(int u)
{
auto &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
if(root.flag != -1)
{
left.flag = right.flag = root.flag;
left.sum = (left.r - left.l + 1) * root.flag;
right.sum = (right.r - right.l + 1) * root.flag;
root.flag = -1;
}
}
void modify(int u, int l, int r, int x)
{
if(l <= tr[u].l && r >= tr[u].r)
{
tr[u].sum = (tr[u].r - tr[u].l + 1) * x;
tr[u].flag = x;
return;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if(l <= mid) modify(u << 1, l, r, x);
if(r > mid) modify(u << 1 | 1, l, r, x);
pushup(u);
}
int query(int u, int l, int r)
{
if(l <= tr[u].l && r >= tr[u].r) return tr[u].sum;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
int res = 0;
if(l <= mid) res += query(u << 1, l, r);
if(r > mid) res += query(u << 1 | 1, l, r);
return res;
}
bool check(int mid)
{
for(int i = 1; i <= n; i ++ )
{
st[i] = (w[i] >= mid);
}
build(1, 1, n);
for(int i = 0; i < m; i ++ )
{
int op = q[i].op, l = q[i].l, r = q[i].r;
if(op == 0)
{
int cnt = query(1, l, r);
modify(1, r - cnt + 1, r, 1);
modify(1, l, r - cnt, 0);
}
else
{
int cnt = query(1, l, r);
modify(1, l, l + cnt - 1, 1);
modify(1, l + cnt, r, 0);
}
}
return query(1, x, x);
}
template<typename T>void in(T &x) {
char ch = getchar();bool flag = 0;x = 0;
while(ch < '0' || ch > '9') flag |= (ch == '-'), ch = getchar();
while(ch <= '9' && ch >= '0') x = (x << 1) + (x << 3) + ch - '0', ch = getchar();
if(flag) x = -x;return ;
}
template <typename T>
inline void print(T x)
{
if(x<0)
{
putchar('-');
x=-x;
}
if(x>9)
print(x/10);
putchar(x%10+'0');
}
int main()
{
in(n), in(m);
for(int i = 1; i <= n; i ++ ) in(w[i]);
for(int i = 0; i < m; i ++ )
{
int op, l, r;
in(op), in(l), in(r);
q[i] = {op, l, r};
}
in(x);
int l = 1, r = n;
while(l < r)
{
int mid = l + r + 1 >> 1;
if(check(mid)) l = mid;
else r = mid - 1;
}
print(l);
return 0;
}