C++代码
#include<cstdio>
#include <cstring>
#include<iostream>
#include <algorithm>
using namespace std;
// Pool capacity, and the magnitude of the -INF / +INF guard keys.
constexpr int N = 100010;
constexpr int INF = 100000000;  // same value as 1e8
int n;  // number of operations, read in main()

// A treap node living in the global pool tr[]. Slot 0 stays all-zero and acts as the
// null child: its size/cnt are 0, so pushup() and the rank queries can read it blindly.
struct Node {
    int l, r;       // left / right child indices into tr[] (0 = none)
    int key, val;   // BST key and heap priority (a child with larger val is rotated up)
    int cnt, size;  // copies of key stored here, total copies in this subtree
} tr[N];

int root, idx;  // index of the current root; last pool slot handed out by get_node()
void pushup(int p){
tr[p].size = tr[tr[p].l].size + tr[tr[p].r].size + tr[p].cnt;
}
int get_node(int key){
tr[ ++ idx].key = key;
tr[idx].val =rand();
tr[idx].cnt =tr[idx].size = 1;
return idx;
}
// Start the treap with two guard nodes so every real query has a predecessor/successor:
// -INF as the root (slot 1) and +INF as its right child (slot 2), on a fresh pool.
void build() {
    get_node(-INF);  // slot 1 on a fresh pool
    get_node(INF);   // slot 2 on a fresh pool
    root = 1;
    tr[root].r = 2;
    pushup(root);
}
// Right rotation: lift p's left child to the subtree root; p becomes its right child.
// p is passed by reference so the parent's link (or root) is redirected in place.
void zig(int &p) {
    const int pivot = tr[p].l;
    tr[p].l = tr[pivot].r;
    tr[pivot].r = p;
    pushup(p);  // old root first: it is now below pivot
    p = pivot;
    pushup(p);
}
// Left rotation: lift p's right child to the subtree root; p becomes its left child.
// p is passed by reference so the parent's link (or root) is redirected in place.
void zag(int &p) {
    const int pivot = tr[p].r;
    tr[p].r = tr[pivot].l;
    tr[pivot].l = p;
    pushup(p);  // old root first: it is now below pivot
    p = pivot;
    pushup(p);
}
// Insert one copy of `key` into the subtree rooted at p (p is updated in place).
// Equal keys bump cnt; otherwise descend, then rotate the child up if its priority
// exceeds the parent's, keeping the larger-val-on-top heap order.
void insert(int &p, int key) {
    if (p == 0) {
        p = get_node(key);
    } else if (key == tr[p].key) {
        ++tr[p].cnt;
    } else if (key < tr[p].key) {
        insert(tr[p].l, key);
        if (tr[p].val < tr[tr[p].l].val) zig(p);
    } else {
        insert(tr[p].r, key);
        if (tr[p].val < tr[tr[p].r].val) zag(p);
    }
    pushup(p);
}
// Remove one copy of `key` from the subtree rooted at p (p is updated in place).
// Does nothing if the key is absent. A node whose last copy is removed is rotated
// down (toward its higher-priority child) until it is a leaf, then unlinked.
void remove(int &p, int key) {
    if (!p) return;
    if (key < tr[p].key) {
        remove(tr[p].l, key);
    } else if (key > tr[p].key) {
        remove(tr[p].r, key);
    } else if (tr[p].cnt > 1) {
        --tr[p].cnt;
    } else if (!tr[p].l && !tr[p].r) {
        p = 0;  // leaf holding the last copy: unlink it
    } else if (!tr[p].r || tr[tr[p].l].val > tr[tr[p].r].val) {
        // left child exists and should rise: rotate right, target moves to the right side
        zig(p);
        remove(tr[p].r, key);
    } else {
        zag(p);
        remove(tr[p].l, key);
    }
    pushup(p);
}
// Rank of `key` in the subtree rooted at p: (number of stored copies strictly smaller
// than key) + 1. The -INF guard is counted, so callers subtract 1 for the user rank.
// Works whether or not `key` is present: for an absent key, every copy counted on the
// way down is smaller than key, so reaching a null child contributes the final "+1".
int get_rank_by_key(int p, int key) {
    if (!p) return 1;
    if (tr[p].key == key)
        return tr[tr[p].l].size + 1;
    if (tr[p].key > key)
        return get_rank_by_key(tr[p].l, key);
    return tr[tr[p].l].size + tr[p].cnt + get_rank_by_key(tr[p].r, key);
}
// Key holding position `rank` (1-based, guard nodes included) in the subtree at p.
// Returns INF if rank exceeds the number of stored copies.
int get_key_by_rank(int p, int rank) {
    while (p) {
        const int left = tr[tr[p].l].size;
        if (rank <= left) {
            p = tr[p].l;
        } else if (rank <= left + tr[p].cnt) {
            return tr[p].key;
        } else {
            rank -= left + tr[p].cnt;
            p = tr[p].r;
        }
    }
    return INF;
}
// Largest stored key strictly less than `key`, or -INF if there is none.
int get_prev(int p, int key) {
    int best = -INF;
    while (p) {
        if (tr[p].key < key) {
            best = std::max(best, tr[p].key);
            p = tr[p].r;  // a larger candidate can only be to the right
        } else {
            p = tr[p].l;
        }
    }
    return best;
}
// Smallest stored key strictly greater than `key`, or INF if there is none.
int get_next(int p, int key) {
    int best = INF;
    while (p) {
        if (tr[p].key > key) {
            best = std::min(best, tr[p].key);
            p = tr[p].l;  // a smaller candidate can only be to the left
        } else {
            p = tr[p].r;
        }
    }
    return best;
}
// Reads n operations "opt x" and applies them to the treap:
// 1 insert, 2 remove, 3 rank of x, 4 value at rank x, 5 predecessor, otherwise successor.
// Ranks are shifted by one to account for the -INF guard node.
int main() {
    build();
    scanf("%d", &n);
    while (n--) {
        int opt, x;
        scanf("%d%d", &opt, &x);
        switch (opt) {
        case 1:
            insert(root, x);
            break;
        case 2:
            remove(root, x);
            break;
        case 3:
            printf("%d\n", get_rank_by_key(root, x) - 1);
            break;
        case 4:
            printf("%d\n", get_key_by_rank(root, x + 1));
            break;
        case 5:
            printf("%d\n", get_prev(root, x));
            break;
        default:
            printf("%d\n", get_next(root, x));
            break;
        }
    }
    return 0;
}
Java代码
import java.util.*;
class Main {
    // Pool capacity and magnitude of the -INF / +INF guard keys.
    static final int N = 100010, INF = 100000000;
    static int n;

    // One shared priority generator instead of a new Random per node.
    static final Random RNG = new Random();

    /**
     * A treap node stored in the pool {@code tr}. Children are pool indices; index 0 is the
     * null sentinel installed by {@link #build()}.
     */
    static class Node {
        int l, r;       // left / right child indices (0 = none)
        int key, val;   // BST key and heap priority (a child with larger val is rotated up)
        int cnt, size;  // copies of key in this node, total copies in the subtree

        Node(int key) {
            this.key = key;
            this.val = RNG.nextInt();
            this.cnt = this.size = 1;
        }
    }

    static Node[] tr = new Node[N];
    static int root, idx;

    /** Recomputes tr[p].size from its children and its own count. */
    static void pushup(int p) {
        tr[p].size = tr[tr[p].l].size + tr[tr[p].r].size + tr[p].cnt;
    }

    /** Allocates the next pool slot holding {@code key}; returns its index. */
    static int get_node(int key) {
        tr[++idx] = new Node(key);
        return idx;
    }

    /**
     * Resets the pool to an empty treap holding only the guard nodes
     * -INF (index 1, the root) and +INF (index 2, its right child).
     */
    static void build() {
        idx = 0;
        // Slot 0 is the null child. Children are read as tr[0] without null checks, so it
        // must exist, hold size/cnt 0, and carry the minimum priority so an absent child
        // never wins the priority comparison in remove().
        Node nil = new Node(0);
        nil.cnt = nil.size = 0;
        nil.val = Integer.MIN_VALUE;
        tr[0] = nil;
        get_node(-INF);
        get_node(INF);
        root = 1;
        tr[root].r = 2;
        pushup(root);
    }

    /** Right rotation of the subtree whose root index is held in p[0]; updates p[0]. */
    static void zig(int[] p) {
        int q = tr[p[0]].l;
        tr[p[0]].l = tr[q].r;
        tr[q].r = p[0];
        p[0] = q;
        pushup(tr[p[0]].r);
        pushup(p[0]);
    }

    /** Left rotation of the subtree whose root index is held in p[0]; updates p[0]. */
    static void zag(int[] p) {
        int q = tr[p[0]].r;
        tr[p[0]].r = tr[q].l;
        tr[q].l = p[0];
        p[0] = q;
        pushup(tr[p[0]].l);
        pushup(p[0]);
    }

    /**
     * Inserts one copy of {@code key} into the subtree whose root index is p[0].
     * p[0] is updated to the (possibly new or rotated) subtree root; the caller must write
     * it back into whatever field it came from, because Java passes ints by value.
     */
    static void insert(int[] p, int key) {
        if (p[0] == 0) {
            p[0] = get_node(key);
        } else if (tr[p[0]].key == key) {
            tr[p[0]].cnt++;
        } else if (tr[p[0]].key > key) {
            int[] child = {tr[p[0]].l};
            insert(child, key);
            tr[p[0]].l = child[0];  // write back the new left-subtree root
            if (tr[tr[p[0]].l].val > tr[p[0]].val) zig(p);
        } else {
            int[] child = {tr[p[0]].r};
            insert(child, key);
            tr[p[0]].r = child[0];  // write back the new right-subtree root
            if (tr[tr[p[0]].r].val > tr[p[0]].val) zag(p);
        }
        pushup(p[0]);
    }

    /**
     * Removes one copy of {@code key} from the subtree whose root index is p[0]; no-op if
     * absent. p[0] is updated to the new subtree root (0 if the subtree became empty).
     */
    static void remove(int[] p, int key) {
        if (p[0] == 0) return;
        if (tr[p[0]].key == key) {
            if (tr[p[0]].cnt > 1) {
                tr[p[0]].cnt--;
            } else if (tr[p[0]].l != 0 || tr[p[0]].r != 0) {
                // Rotate the target down toward its higher-priority child, then recurse.
                if (tr[p[0]].r == 0 || tr[tr[p[0]].l].val > tr[tr[p[0]].r].val) {
                    zig(p);
                    int[] child = {tr[p[0]].r};
                    remove(child, key);
                    tr[p[0]].r = child[0];
                } else {
                    zag(p);
                    int[] child = {tr[p[0]].l};
                    remove(child, key);
                    tr[p[0]].l = child[0];
                }
            } else {
                p[0] = 0;  // leaf holding the last copy: unlink it
            }
        } else if (tr[p[0]].key > key) {
            int[] child = {tr[p[0]].l};
            remove(child, key);
            tr[p[0]].l = child[0];
        } else {
            int[] child = {tr[p[0]].r};
            remove(child, key);
            tr[p[0]].r = child[0];
        }
        pushup(p[0]);
    }

    /**
     * Rank of {@code key}: (copies strictly smaller than key) + 1, counting the -INF guard.
     * Correct whether or not key is present.
     */
    static int get_rank_by_key(int p, int key) {
        if (p == 0) return 1;
        if (tr[p].key == key) return tr[tr[p].l].size + 1;
        if (tr[p].key > key) return get_rank_by_key(tr[p].l, key);
        return tr[tr[p].l].size + tr[p].cnt + get_rank_by_key(tr[p].r, key);
    }

    /** Key at 1-based position {@code rank} (guards included); INF if rank is too large. */
    static int get_key_by_rank(int p, int rank) {
        if (p == 0) return INF;
        if (tr[tr[p].l].size >= rank) return get_key_by_rank(tr[p].l, rank);
        if (tr[tr[p].l].size + tr[p].cnt >= rank) return tr[p].key;
        return get_key_by_rank(tr[p].r, rank - tr[tr[p].l].size - tr[p].cnt);
    }

    /** Largest key strictly less than {@code key}, or -INF. */
    static int get_prev(int p, int key) {
        if (p == 0) return -INF;
        if (tr[p].key >= key) return get_prev(tr[p].l, key);
        return Math.max(tr[p].key, get_prev(tr[p].r, key));
    }

    /** Smallest key strictly greater than {@code key}, or INF. */
    static int get_next(int p, int key) {
        if (p == 0) return INF;
        if (tr[p].key <= key) return get_next(tr[p].r, key);
        return Math.min(tr[p].key, get_next(tr[p].l, key));
    }

    /**
     * Reads n operations "opt x": 1 insert, 2 remove, 3 rank, 4 value at rank,
     * 5 predecessor, otherwise successor. Ranks are shifted by one for the -INF guard.
     */
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        build();
        n = scanner.nextInt();
        StringBuilder out = new StringBuilder();  // one write at the end instead of per-line println
        int[] r = new int[1];
        while (n-- > 0) {
            int opt = scanner.nextInt();
            int x = scanner.nextInt();
            if (opt == 1) {
                r[0] = root;
                insert(r, x);
                root = r[0];  // rotations may change the root
            } else if (opt == 2) {
                r[0] = root;
                remove(r, x);
                root = r[0];
            } else if (opt == 3) out.append(get_rank_by_key(root, x) - 1).append('\n');
            else if (opt == 4) out.append(get_key_by_rank(root, x + 1)).append('\n');
            else if (opt == 5) out.append(get_prev(root, x)).append('\n');
            else out.append(get_next(root, x)).append('\n');
        }
        System.out.print(out);
    }
}