题目描述
给出一个可重集 $a$(编号为 $1$),它支持以下操作:
0 p x y
:将可重集 $p$ 中大于等于 $x$ 且小于等于 $y$ 的值放入一个新的可重集中(新可重集编号为从 $2$ 开始的正整数,是上一次产生的新可重集的编号+1)。
1 p t
:将可重集 $t$ 中的数放入可重集 $p$,且清空可重集 $t$(数据保证在此后的操作中不会出现可重集 $t$)。
2 p x q
:在 $p$ 这个可重集中加入 $x$ 个数字 $q$。
3 p x y
:查询可重集 $p$ 中大于等于 $x$ 且小于等于 $y$ 的值的个数。
4 p k
:查询在 $p$ 这个可重集中第 $k$ 小的数,不存在时输出 -1
。
输入格式
第一行两个整数 $n,m$,表示可重集中的数在 $1\sim n$ 的范围内,有 $m$ 个操作。
接下来一行 $n$ 个整数,表示 $1 \sim n$ 这些数在 $a$ 中出现的次数 $(a_{i} \leq m)$。
接下来的 $m$ 行每行若干个整数,第一个数为操作的编号 $opt$($0 \leq opt \leq 4$),以题目描述为准。
输出格式
依次输出每个查询操作的答案。
输入 #1
5 12
0 0 0 0 0
2 1 1 1
2 1 1 2
2 1 1 3
3 1 1 3
4 1 2
2 1 1 4
2 1 1 5
0 1 2 4
2 2 1 4
3 2 2 4
1 1 2
4 1 3
输出 #1
3
2
4
3
说明/提示
对于 $30\%$ 的数据,$1\leq n \leq {10}^3$,$1 \le m \le {10}^3$;
对于 $100\%$ 的数据,$1 \le n \le 2 \times {10}^5$,$1 \le x, y, q \le m \le 2 \times {10}^5$。保证数据合法。
不开 long long
见祖宗!!
分析:
线段树分裂,这道题中我们按子树大小分裂
考虑函数$split(x,y,k)$,分裂以$x$为根节点的子树,另一棵线段树为$y$,将前$k$个元素分给$x$,将剩余的分给$y$
定义$v = tr[tr[x].l].sum$
- $v < k$,左端不需要修改,依然是$x$的,递归到右边处理,$split(tr[x],r,tr[y].r, k - v)$
- $v=k$,左子树正好包含前$k$个,那么将右子树归给$y$,直接交换$x,y$的右子树即可
- $v>k$,左子树多于$k$个,先将右子树归给$y$,然后递归$x$的左子树
分裂的代码如下:
void split(int x, int &y, ll k)
{
//前k个给x,后面的给y
if(!x) return;
y = newnode();
ll v = tr[tr[x].l].sum;
if(v < k) split(tr[x].r, tr[y].r, k - v);
else
{
swap(tr[x].r, tr[y].r);
if(v > k) split(tr[x].l, tr[y].l, k);
}
tr[y].sum = tr[x].sum - k;
tr[x].sum = k;
}
对于每个操作:
操作0:将$[1,x-1],[x,y],[y+1,n]$三段分离出来,再将$[1,x-1],[y+1,n]$合并回去
操作1:直接令$root[p]=merge(root[p],root[t])$
其他操作都是简单的权值线段树的操作
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N = 200005;
struct Node
{
int l, r;
ll sum;
}tr[N * 50];
int n, m;
int root[N], pool[N * 20], delcnt, idx, rtidx = 1;
int newnode()
{
if(delcnt) return pool[delcnt -- ];
return (++ idx);
}
void del(int u)
{
pool[++ delcnt] = u;
tr[u].l = tr[u].r = tr[u].sum = 0;
}
void pushup(int u)
{
tr[u].sum = tr[tr[u].l].sum + tr[tr[u].r].sum;
}
void modify(int &u, int l, int r, int pos, ll val)
{
if(!u) u = newnode();
if(l == r)
{
tr[u].sum += val;
return;
}
int mid = l + r >> 1;
if(pos <= mid) modify(tr[u].l, l, mid, pos, val);
else modify(tr[u].r, mid + 1, r, pos, val);
pushup(u);
}
void split(int x, int &y, ll k)
{
//前k个给x,后面的给y
if(!x) return;
y = newnode();
ll v = tr[tr[x].l].sum;
if(v < k) split(tr[x].r, tr[y].r, k - v);
else
{
swap(tr[x].r, tr[y].r);
if(v > k) split(tr[x].l, tr[y].l, k);
}
tr[y].sum = tr[x].sum - k;
tr[x].sum = k;
}
int merge(int a, int b, int l, int r)
{
if(!a || !b) return a + b;
if(l == r)
{
tr[a].sum += tr[b].sum;
return a;
}
int mid = l + r >> 1;
tr[a].l = merge(tr[a].l, tr[b].l, l, mid);
tr[a].r = merge(tr[a].r, tr[b].r, mid + 1, r);
pushup(a);
del(b);
return a;
}
ll query(int u, int l, int r, int x, int y)
{
if(x <= l && y >= r) return tr[u].sum;
int mid = l + r >> 1;
ll res = 0;
if(x <= mid) res += query(tr[u].l, l, mid, x, y);
if(y > mid) res += query(tr[u].r, mid + 1, r, x, y);
return res;
}
ll get_k(int u, int l, int r, int k)
{
if(l == r) return l;
int mid = l + r >> 1;
if(tr[tr[u].l].sum >= k) return get_k(tr[u].l, l, mid, k);
return get_k(tr[u].r, mid + 1, r, k - tr[tr[u].l].sum);
}
int main()
{
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; i ++ )
{
ll x;
scanf("%lld", &x);
modify(root[1], 1, n, i, x);
}
while (m -- )
{
int op;
scanf("%d", &op);
if(op == 0)
{
int p, x, y;
scanf("%d%d%d", &p, &x, &y);
ll q1 = query(root[p], 1, n, 1, y);
ll q2 = query(root[p], 1, n, x, y);
int X = ++ rtidx, Y = 0;
split(root[p], root[X], q1 - q2);
split(root[X], Y, q2);
root[p] = merge(root[p], Y, 1, n);
}
else if(op == 1)
{
int p, t;
scanf("%d%d", &p, &t);
root[p] = merge(root[p], root[t], 1, n);
}
else if(op == 2)
{
int p, x, q;
scanf("%d%d%d", &p, &x, &q);
modify(root[p], 1, n, q, x);
}
else if(op == 3)
{
int p, x, y;
scanf("%d%d%d", &p, &x, &y);
printf("%lld\n", query(root[p], 1, n, x, y));
}
else
{
int p, k;
scanf("%d%d", &p, &k);
if(tr[root[p]].sum < k) puts("-1");
else printf("%lld\n", get_k(root[p], 1, n, k));
}
}
return 0;
}