平衡树详解 https://blog.csdn.net/sjystone/article/details/115443239
操作
1.插入数x
2.删除一个数(若有多个数相同,只删除一个)
3.查询x的排名(若有多个相同的数,输出最小的排名)
4.查询排名为x的数
5.求x的前驱
6.求x的后继
0.pushup用儿子计算父节点信息
void pushup(int p)
{
tr[p].size = tr[tr[p].l].size + tr[tr[p].r].size + tr[p].cnt;
}
1.储存信息
struct Node{
int l, r;
int key, val;
int cnt; 该数出现的次数
int size; 子树中数的个数(含当前的点)
}
2.创建节点(叶节点)
int get_node(int key)
{
tr[++ idx].key = key; //分配一个节点
tr[idx].val = rand(); //val是一个随机值
tr[idx].cnt = 1;
tr[idx].size = 1;
return idx;
}
3.build建立
void build()
{
get_node(-INF), get_node(INF); //哨兵,注意:排名时考虑加一减一
root = 1; //跟节点是一号点
tr[1].r = 2; //跟节点的右儿子是二号点
pushup(root);
}
4.zig右旋
void zig(int &p) //p相当于指针
{
int q = tr[p].l;
tr[p].l = tr[q].r //q的右儿子,接到p的左儿子
tr[q].r = p; //q的右儿子是p
p = q; //q变成子树的根,即传过来指向p的指针变成指向q
pushup(tr[p].r); /需要通过子节点更新的节点为p和p的右儿子
pushup(p);
}
5.zag左旋
void zag(int &p) //跟右旋反过来
{
int q = tr[p].r;
tr[p].r = tr[q].l;
tr[q].l = p;
p = q;
pushup(tr[p].l);
pushup(p);
}
6.insert插入
void insert(int &p, int key) //引用!!
{
if ( !p ) p = get_node(key); //若p不存在,则创建叶节点
else if ( tr[p].key == key ) tr[p].cnt ++; //若已经存在key,则该节点cnt++
else if ( tr[p].key > key )
{
insert(tr[p].l, key); //往左走
if ( tr[tr[p].l].val > tr[p].val ) zig(p) //插入后,判断是否满足堆,不满足的话右旋
}
else
{
insert(tr[p].r, key);
if ( tr[tr[p].r].val > tr[p].val ) zag(p);
}
pushup(p); //更新信息
}
7.remove删除
void remove(int &p, int key) //引用
{
if ( !p ) return; //如果不存在,返回
if ( tr[p].key == key ) //如果key相同
{
if ( tr[p].cnt > 1 ) tr[p].cnt --; //若该数的个数大于一,直接减去一个
else if ( tr[p].l || tr[p].r ) //否则,若当前节点不是叶子节点(左儿子/右儿子存在)
{
if ( !tr[p].r || tr[tr[p].l].val > tr[tr[p].r].val ) //若右子树是空的,或左儿子val大于右儿子val,右旋
{
zig(p);
remove(tr[p].r, key); //右旋完后,要删除的数成了原本位置的右儿子,继续删除此数
}
else //否则左旋
{
zag(p);
remove(tr[p].l, key);
}
}
else p = 0; //若是叶子节点。直接删除
}
else if ( tr[p].key > key ) remove(tr[p].l, key); //往左边找
else remove(tr[p].r, key); //往右边找
pushup(p); //更新信息
}
8.get_rank根据数值查排名
int get_rank(int p, int key)
{
if ( !p ) return 0; //若p不存在(该题不存在这种情况)
if ( tr[p].key == key ) return tr[tr[p].l].size + 1; //返回左子树数的个数加1
if ( tr[p].key > key ) return get_rank(tr[p].l, key); //往左子树中找,返回左子树中的排名
return tr[tr[p].l].size + tr[p].cnt + get_rank(tr[p].r, key);
//往右子树找,返回左子树的size + 当前节点的cnt + 右子树里的排名
}
9.get_key根据排名查数值
int get_key(int p, int rank)
{
if ( !p ) return INF; //若p不存在(该题不存在这种情况)
if ( tr[tr[p].l].size >= rank ) return get_key(tr[p].l, rank); //在左子树找
if ( tr[tr[p].l].size + tr[p].cnt >= rank ) return tr[p].key; //就是当前的key
return get_key(tr[p].r, rank - tr[tr[p].l].size - tr[p].cnt); //找在右子树中排名x的数
//x等于原本排名,减去左子树的size,减去当前节点的cnt
}
10.get_prev找前驱
int get_prev(int p, int key)
{
if ( !p ) return -INF;
if ( tr[p].key >= key ) return get_prev(tr[p].l, key); //右边所有数不用考虑
return max(tr[p].key, get_prev(tr[p].r, key)); //若当前值比key小,可能是当前值,也可能是当前节点的右子树中的值,取max
}
11.get_next找后继
int get_next(int p, int key)
{
if ( !p ) return INF;
if ( tr[p].key <= key ) return get_next(tr[p].r, key);
return min(tr[p].key, get_next(tr[p].l, key));
}
代码如下:
#include <iostream>
#include <cstdio>
#include <ctime>
using namespace std;
const int N = 100010, INF = 1e8;
int m;
int root, idx;
struct Node{
int l, r;
int key, val;
int size, cnt;
}tr[N];
void pushup(int p)
{
tr[p].size = tr[tr[p].l].size + tr[tr[p].r].size + tr[p].cnt;
}
void zig(int &p)
{
int q = tr[p].l;
tr[p].l = tr[q].r;
tr[q].r = p;
p = q;
pushup(tr[p].r);
pushup(p);
}
void zag(int &p)
{
int q = tr[p].r;
tr[p].r = tr[q].l;
tr[q].l = p;
p = q;
pushup(tr[p].l);
pushup(p);
}
int get_node(int key)
{
tr[++ idx].key = key;
tr[idx].val = rand();
tr[idx].cnt = 1;
tr[idx].size = 1;
return idx;
}
void build()
{
root = get_node(-INF);
tr[root].r = get_node(INF);
pushup(root);
}
void insert(int &p, int key)
{
if ( !p ) p = get_node(key);
else if ( tr[p].key == key ) tr[p].cnt ++;
else if ( tr[p].key > key )
{
insert(tr[p].l, key);
if ( tr[tr[p].l].val > tr[p].val ) zig(p);
}
else
{
insert(tr[p].r, key);
if ( tr[tr[p].r].val > tr[p].val ) zag(p);
}
pushup(p);
}
void remove(int &p, int key)
{
if ( !p ) return;
if ( tr[p].key == key )
{
if ( tr[p].cnt > 1 ) tr[p].cnt --;
else if ( tr[p].l || tr[p].r )
{
if ( !tr[p].r || tr[tr[p].l].val > tr[tr[p].r].val )
{
zig(p);
remove(tr[p].r, key);
}
else
{
zag(p);
remove(tr[p].l, key);
}
}
else p = 0;
}
else if ( tr[p].key > key ) remove(tr[p].l, key);
else remove(tr[p].r, key);
pushup(p);
}
int get_rank(int p, int key)
{
if ( !p ) return 0;
if ( tr[p].key == key ) return tr[tr[p].l].size + 1;
if ( tr[p].key > key ) return get_rank(tr[p].l, key);
return tr[tr[p].l].size + tr[p].cnt + get_rank(tr[p].r, key);
}
int get_key(int p, int rank)
{
if ( !p ) return INF;
if ( tr[tr[p].l].size >= rank ) return get_key(tr[p].l, rank);
if ( tr[p].cnt + tr[tr[p].l].size >= rank ) return tr[p].key;
return get_key(tr[p].r, rank - tr[p].cnt - tr[tr[p].l].size);
}
int get_prev(int p, int key)
{
if ( !p ) return -INF;
if ( tr[p].key >= key ) return get_prev(tr[p].l, key);
return max(tr[p].key, get_prev(tr[p].r, key));
}
int get_next(int p, int key)
{
if ( !p ) return INF;
if ( tr[p].key <= key ) return get_next(tr[p].r, key);
return min(tr[p].key, get_next(tr[p].l, key));
}
int main()
{
build();
cin >> m;
while ( m -- )
{
int op, x;
scanf("%d%d", &op, &x);
if ( op == 1 ) insert(root, x);
else if ( op == 2 ) remove(root, x);
else if ( op == 3 ) printf("%d\n", get_rank(root, x) - 1);
else if ( op == 4 ) printf("%d\n", get_key(root, x + 1));
else if ( op == 5 ) printf("%d\n", get_prev(root, x));
else if ( op == 6 ) printf("%d\n", get_next(root, x));
}
return 0;
}