超详细的平衡树代码
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 100010, INF = 1e8;
int n;
struct node
{
int l,r;
int key,val;//节点键值 || 权值
int cnt,size;//这个数出现次数 || 每个(节点)子树里数的个数
}tr[N];
int root,idx;
void pushup(int p)
{
tr[p].size = tr[tr[p].l].size + tr[tr[p].r].size + tr[p].cnt;
//左儿子树 数个数+ 右儿子树 数个数+ 当前节点数出现次数(重复次数)
}
int get_node(int key)
{
tr[++idx].key=key;
tr[idx].val=rand();
tr[idx].cnt=tr[idx].size=1;
//默认创建时都是叶子节点
return idx;
}
/*
x (p) y(q) y(p) q=tr[p].l;
/ \ \ tr[p].l=tr[q].r
y (q) (x>z>y,右旋) x(p)交换 x(q) tr[q].l=p
\ / / p=q;
z z z
*/
void zig(int &p)
{
int q=tr[p].l;//存储图中的y
tr[p].l=tr[q].r;//把z从y的右儿子变成x的左儿子
tr[q].r=p;//把z从y的右儿子变成x的左儿子
p=q;//交换p,q
pushup(tr[p].r),pushup(p);//先更新q,再更新p
}
/*
起初y是右儿子 左旋,y是左儿子 右旋
x(p) y(q) y(p) q=tr[p].r;
\ / / tr[p].r=tr[q].l;
y(q) (y>z>x左旋) x(p) 交换 x(q) tr[p].l=q;
/ \ \ p=q;
z z z
*/
void zag(int &p)
{
int q=tr[p].r;//存储图中的y
tr[p].r=tr[q].l;//把z从y的左儿子变成x的右儿子
tr[q].l=p;//y的左儿子变成x
p=q;//交换
pushup(tr[p].l),pushup(p);
}
void build()
{
get_node(-INF),get_node(INF);
root=1,tr[1].r=2;//根节点是1号节点,1号点的右儿子是2号点
pushup(root);
if(tr[1].val<tr[2].val) zag(root);//左旋
}
void insert(int &p,int key)
{
//插入的同时要维护大根堆
if(!p) p=get_node(key);//递归到叶子节点fa的tr[fa].l or tr[fa].r为空,重新建一个节点
else if(tr[p].key==key) tr[p].cnt++;//插入已有的值,cnt++;
else if(tr[p].key>key)//插入的值放在左子树
{
insert(tr[p].l,key);
if(tr[tr[p].l].val>tr[p].val) zig(p);//右旋
}
else//插入的值放在右子树
{
insert(tr[p].r,key);
if(tr[tr[p].r].val>tr[p].val) zag(p);//左旋
}
pushup(p);//更新
}
void remove(int &p,int key)
{
if(!p) return;//如果删除的值不存在
if(tr[p].key==key)
{
if(tr[p].cnt>1) tr[p].cnt--;
else if(tr[p].l||tr[p].r)
{
if(!tr[p].r||tr[tr[p].l].val>tr[tr[p].r].val)
{
//右儿子不存在或左.val(y)>右.val(z)
zig(p);//左旋
/*
x y
/ \ 变 \ 一直把x往叶节点挤
y z x
\ / \
w w z
*/
remove(tr[p].r,key);
}
else
{
//左儿子不存在或左.val(y)<右.val(z)
zag(p);//右旋
/*
x z
/ \ 变 /
y z x 一直把x往叶节点挤
\ /
w y
\
w
*/
remove(tr[p].l,key);
}
}
else p=0;//左儿子和右儿子都不存在,叶子节点直接删除
}
else if(tr[p].key>key) remove(tr[p].l,key);//如果要删除的key<当前节点的key
//往左子树里找
else remove(tr[p].r,key);//否则往右子树里找
pushup(p);
}
int get_rank_by_key(int p,int key)
{
if(!p) return 0;//p不存在,也就是没有找到key,返回0
if(tr[p].key==key) return tr[tr[p].l].size+1;
if(tr[p].key>key) return get_rank_by_key(tr[p].l,key);
//要找的key比当前节点小,去当前节点的左子树找
return tr[tr[p].l].size+tr[p].cnt+get_rank_by_key(tr[p].r,key);
//要找的key比当前节点大,去当前节点的右子树找
}
int get_key_by_rank(int p,int rank)
{
if(!p) return INF;//如果遍历到了叶子节点的儿子,说明这个rank不在
//min~max的范围内,因为rank>0,返回INF;
if(tr[tr[p].l].size>=rank) return get_key_by_rank(tr[p].l,rank);
//如果rank小于等于左子树的节点数量,往左子树里找
if(tr[tr[p].l].size+tr[p].cnt>=rank) return tr[p].key;
//如果rank小于等于(除了右子树的所有节点数量),说明tr[p].size<rank,但是
//tr[p].size+tr[p].cnt>=rank,表明这个rank实在tr[p].cnt里出现,于是就是tr[p].key
return get_key_by_rank(tr[p].r,rank-tr[tr[p].l].size-tr[p].cnt);
//如果rank大于(除了右子树的所有节点数量),往右子树里找,注意同时往下传的rank-左子树-当前节点个数
}
int get_prev(int p,int key)
{
if(!p) return -INF;//没有前驱,说明是最小的数
if(tr[p].key>=key) return get_prev(tr[p].l,key);
//当前这个数比key大,往左子树里找
return max(tr[p].key,get_prev(tr[p].r,key));
//当前这个数比key小,往右子树里找
}
int get_next(int p,int key)
{
if(!p) return INF;//没后继,说明是最大的数
if(tr[p].key<=key) return get_next(tr[p].r,key);
//当前这个数比key小,往右子树里找
return min(tr[p].key,get_next(tr[p].l,key));
//当前这个数比key小,往右子树里找
}
int main()
{
build();
scanf("%d",&n);
while(n--)
{
int op,x;
scanf("%d%d",&op,&x);
if(op==1) insert(root,x);
else if(op==2) remove(root,x);
else if (op == 3) printf("%d\n", get_rank_by_key(root, x) - 1);
else if (op == 4) printf("%d\n", get_key_by_rank(root, x + 1));
else if (op == 5) printf("%d\n", get_prev(root, x));
else printf("%d\n", get_next(root, x));
}
return 0;
}