题意:
用给定的随机数生成器生成一个长度为 n 的数组,然后支持以下几种操作(x 的后继指未被删除的数中大于等于 x 的最小数):
操作F x:在未被删除的数中查询x的后继,若不存在则输出 1000000000000
操作D x:删除x的后继(若不存在则忽略)
操作C x:求所有未被删除且小于等于x的数的总和
操作R x:将所有小于等于x的被删除的数恢复
算法1(树状数组+并查集):
先将数组排序;用树状数组记录被删除的数的值,在 O(log n) 时间内完成单点修改和前缀求和,C 操作的答案 = 前缀和 − 前缀中已删除的数之和
用并查集(每个被删除的位置指向它后面的位置)快速找到当前位置及其之后第一个未被删除的数
题目链接:
http://acm.hdu.edu.cn/showproblem.php?pid=6902
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<ctime>
#include<cmath>
#include<string>
#include<cstring>
#include<bitset>
#include<vector>
#include<queue>
#include<stack>
#include<set>
#include<map>
#include<iomanip>
#include<algorithm>
// Debug helpers: print an expression and its value to stderr.
#define dbgfull(x) cerr << #x << " = " << x << " (line " << __LINE__ << ")"<<endl;
// Fast iostream switch; not used by this program, which does all I/O with scanf/printf.
#define IOS ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
#define dbg(x) cerr << #x " = " << (x) << endl
// Stream output via endl becomes a plain newline (no flush).
#define endl "\n"
// Every `int` below (globals, locals, parameters) is really long long; this is why
// scanf/printf use %lld everywhere and why main has to be declared `signed`.
#define int long long
#define PI acos(-1)
// CLOCKS_PER_SEC: number of clock() ticks per second
using namespace std;
// INF is not referenced by algorithm 1.
const int INF = 0x3f3f3f3f;
// N bounds tr/sum/p (index n + 1 is used as the "no successor" sentinel); M is unused.
const int N = 1e6+5,M = N * 2;
// mod is unused.
int mod = 1e9 +7;
// m: number of operations of the current test case; k, S, T are unused at file scope
// (solve() declares its own local k).
int m,k,S,T;
// n: length of the generated array a[1..n].
int n;
// Generator state, seeded by gen() from the input.
unsigned long long k1, k2;
// One xorshift128+ step: advances (k1, k2) and returns the next 64-bit value.
// Presumably copied from the problem statement, so it must stay bit-exact.
unsigned long long xorShift128Plus() {
    const unsigned long long old_second = k2;
    unsigned long long mixed = k1;
    mixed ^= mixed << 23;
    k1 = old_second;
    k2 = mixed ^ old_second ^ (mixed >> 17) ^ (old_second >> 26);
    return k2 + old_second;
}
// a[1..n]: generated values, each in [1, 999999999999]; index n must stay below 1000015.
long long a[1000015];
// Reads n and the two generator seeds, then fills a[1..n] from the generator.
void gen() {
    scanf("%lld %llu %llu", &n, &k1, &k2);
    int i = 1;
    while (i <= n) {
        a[i] = xorShift128Plus() % 999999999999 + 1;
        ++i;
    }
}
// tr:  Fenwick tree over sorted positions, holding the values of deleted elements.
// sum: prefix sums of the sorted array (sum[0] stays 0).
// p:   union-find parent; p[i] == i iff position i is not deleted, otherwise it points
//      to a later position (n + 1 stands for "no undeleted element after this").
int tr[N],sum[N],p[N];
// Lowest set bit of x (two's complement); callers pass positive Fenwick indices.
int lowbit(int x){
    return x & (~x + 1);
}
// Fenwick point update: adds v at 1-based position x (positions run up to n).
void add(int x,int v){
    while (x <= n) {
        tr[x] += v;
        x += lowbit(x);
    }
}
// Fenwick prefix sum: total stored over positions 1..x.
int query(int x){
    int total = 0;
    for (; x != 0; x -= lowbit(x)) total += tr[x];
    return total;
}
// Per-test-case setup: sorts a[1..n], builds prefix sums, clears the Fenwick tree,
// and makes every position its own union-find root (nothing deleted yet).
// Position n + 1 is a permanent root meaning "no successor".
void init(){
    sort(a + 1, a + n + 1);
    for (int pos = 1; pos <= n; ++pos) {
        tr[pos] = 0;
        p[pos] = pos;
        sum[pos] = sum[pos - 1] + a[pos];
    }
    p[n + 1] = n + 1;
}
// Union-find lookup with full path compression.
// Returns the first undeleted position >= x, or n + 1 when there is none.
// Iterative on purpose: deleting positions 1, 2, ..., n one at a time builds a
// chain of length n (up to ~10^6), which would overflow the call stack with the
// recursive formulation.
int find(int x){
    int root = x;
    while (p[root] != root) root = p[root];
    // Second pass: point every node on the path directly at the root.
    while (p[x] != root) {
        int next = p[x];
        p[x] = root;
        x = next;
    }
    return root;
}
// Handles one test case: generates and preprocesses the array, then answers m operations.
// Deleted positions are also kept in an ordered set, so R only touches positions that
// really need restoring. Each deletion is undone at most once, giving amortized
// O(log n) per element instead of an O(k) scan per R.
void solve(){
    gen();
    init();
    // Sorted positions that are currently deleted.
    set<int> deleted;
    scanf("%lld",&m);
    while(m--){
        char op[10];
        int x;
        scanf("%s%lld",op,&x);
        if(op[0] == 'F'){
            // Successor of x: first undeleted position at or after lower_bound(x).
            int k = find(lower_bound(a + 1,a + n + 1,x) - a);
            // No successor: print the problem's sentinel value.
            if(k == n + 1) puts("1000000000000");
            else printf("%lld\n",a[k]);
        }
        else if(op[0] == 'D'){
            // Delete the successor of x (the same element F would report).
            int k = find(lower_bound(a + 1,a + n + 1,x) - a);
            if(k == n + 1) continue;
            // Record the deleted value in the Fenwick tree and remember the position...
            add(k,a[k]);
            deleted.insert(k);
            // ...then link k to the next undeleted position.
            p[k] = find(k + 1);
        }
        else if(op[0] == 'C'){
            // k = last position with a[k] <= x (0 if none).
            int k = upper_bound(a + 1,a + n + 1,x) - a - 1;
            if(k == 0) puts("0");
            // Prefix total minus the values deleted among positions 1..k.
            else printf("%lld\n",sum[k] - query(k));
        }
        else if(op[0] == 'R'){
            int k = upper_bound(a + 1,a + n + 1,x) - a - 1;
            // Restore every deleted position <= k; each becomes its own root again.
            // Deleted positions > k only point forward, so they are unaffected.
            while(!deleted.empty() && *deleted.begin() <= k){
                int i = *deleted.begin();
                deleted.erase(deleted.begin());
                p[i] = i;
                add(i,-a[i]);
            }
        }
    }
}
// Entry point: reads the number of test cases, then runs solve() for each.
signed main(){
    int cases;
    scanf("%lld",&cases);
    for (; cases--; ) solve();
    return 0;
}
算法2(Splay 树 + multiset):用 Splay 树维护所有未被删除的数及其子树和(两端各放一个哨兵),用 multiset 记录被删除的数,供 R 操作恢复
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<ctime>
#include<cmath>
#include<string>
#include<cstring>
#include<bitset>
#include<vector>
#include<queue>
#include<stack>
#include<set>
#include<map>
#include<iomanip>
#include<algorithm>
// Debug helpers: print an expression and its value to stderr.
#define dbgfull(x) cerr << #x << " = " << x << " (line " << __LINE__ << ")"<<endl;
// Fast iostream switch; not used by this program, which does all I/O with scanf/printf.
#define IOS ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
#define dbg(x) cerr << #x " = " << (x) << endl
// Stream output via endl becomes a plain newline (no flush).
#define endl "\n"
// Every `int` below is really long long (hence %lld and `signed main`).
#define int long long
#define PI acos(-1)
// CLOCKS_PER_SEC: number of clock() ticks per second
using namespace std;
// Sentinel key: -INF and INF are stored at both ends of the splay tree and the
// multiset; it is far larger than any generated value (<= 999999999999).
const int INF = 0x3f3f3f3f3f3f3f3f;
// Capacity of the node pool tr[]. Nodes are never freed: build() takes n + 2 of them
// and every insert() during R takes one more, so N must cover that total.
const int N = 5e6+10000;
// mod is unused.
int mod = 1e9 +7;
// m: number of operations; idx: last allocated node index (node 0 is the null node);
// k, S, T are unused at file scope.
int m,k,S,T,idx;
// n: length of the generated array.
int n;
// Generator state, seeded by gen() from the input.
unsigned long long k1, k2;
// One xorshift128+ step: advances (k1, k2) and returns the next 64-bit value.
// Presumably copied from the problem statement, so it must stay bit-exact.
unsigned long long xorShift128Plus() {
    const unsigned long long old_second = k2;
    unsigned long long mixed = k1;
    mixed ^= mixed << 23;
    k1 = old_second;
    k2 = mixed ^ old_second ^ (mixed >> 17) ^ (old_second >> 26);
    return k2 + old_second;
}
// a[1..n]: generated values, each in [1, 999999999999]. solve() later overwrites
// a[0] and a[n + 1] with the -INF / INF sentinels, so n + 1 must stay below 1000015.
long long a[1000015];
// Reads n and the two generator seeds, then fills a[1..n] from the generator.
void gen() {
    scanf("%lld %llu %llu", &n, &k1, &k2);
    int i = 1;
    while (i <= n) {
        a[i] = xorShift128Plus() % 999999999999 + 1;
        ++i;
    }
}
// Splay tree node stored in the pool tr[]:
//   s[0] / s[1]: left / right child index (0 is the null node),
//   p: parent index, v: key, sum: sum of all keys in this node's subtree.
struct node{
    int s[2],p;
    int v;
    int sum;
    // Reset this slot as a leaf holding key _v under parent _p.
    void init(int _v,int _p){
        v = _v;
        sum = _v;
        p = _p;
        s[1] = 0;
        s[0] = 0;
    }
}tr[N];
// Index of the current tree root (reset to 0 at the start of each test case).
signed root;
void pushup(int u){
tr[u].sum = tr[tr[u].s[0]].sum + tr[u].v + tr[tr[u].s[1]].sum;
}
// Single rotation: lifts x above its parent y while preserving in-order (key) order.
// x takes y's place under z, x's inner child moves over to y, and y becomes x's child.
// When z or x's inner child is 0, some writes land on the null node tr[0]; the code
// treats tr[0] as scratch. NOTE(review): confirm nothing relies on tr[0].s / tr[0].p.
void rotate(int x){
int y = tr[x].p,z = tr[y].p;
// k == 1 iff x is y's right child.
int k = tr[y].s[1] == x;
tr[z].s[tr[z].s[1] == y] = x,tr[x].p = z;
tr[y].s[k] = tr[x].s[k ^ 1],tr[tr[x].s[k ^ 1]].p = y;
tr[x].s[k ^ 1] = y,tr[y].p = x;
// y is now below x, so its sum must be refreshed first.
pushup(y),pushup(x);
}
// Splays x upward until its parent is k; with k == 0, x becomes the root and
// the global `root` is updated.
void splay(int x,int k){
while(tr[x].p != k){
int y = tr[x].p,z = tr[y].p;
if(z != k){
// x and y on different sides of their parents (zig-zag): rotate x twice;
// same side (zig-zig): rotate y first, then x.
if((tr[y].s[1] == x) ^ (tr[z].s[1] == y)) rotate(x);
else rotate(y);
}
rotate(x);
}
if(!k) root = x;
}
// Inserts key v as a new node (keys equal to an existing one descend to the left),
// then splays the new node to the root.
void insert(int v){
    int parent = 0;
    for (int cur = root; cur; cur = tr[cur].s[v > tr[cur].v]) parent = cur;
    int fresh = ++idx;
    if (parent) tr[parent].s[v > tr[parent].v] = fresh;
    tr[fresh].init(v, parent);
    splay(fresh, 0);
}
// Builds a balanced subtree over the sorted keys a[l..r] (l <= r) hanging under
// parent p, and returns its root index. Nodes are allocated in preorder.
int build(int l,int r,int p){
    const int mid = (l + r) >> 1;
    const int u = ++idx;
    tr[u].init(a[mid], p);
    if (l < mid) tr[u].s[0] = build(l, mid - 1, u);
    if (r > mid) tr[u].s[1] = build(mid + 1, r, u);
    pushup(u);
    return u;
}
int get_pre(int v){
int u = root,res = -INF;
while(u){
if(v >= tr[u].v) res = tr[u].v,u = tr[u].s[1];
else u = tr[u].s[0];
}
return res;
}
int get_nxt(int v){
int u = root,res = INF;
while(u){
if(v <= tr[u].v) res = tr[u].v,u = tr[u].s[0];
else u = tr[u].s[1];
}
return res;
}
// Deletes one node whose key is x from the splay tree.
// Precondition: x is present and is not a sentinel, so the node has both a
// predecessor and a successor (the -INF / INF sentinels at the latest).
// NOTE(review): assumes every query value is > -INF, which holds for in-range input.
void update(int x){
int u = root;
// Locate a node with key x (any copy if x is duplicated).
while(u){
if(tr[u].v == x) break;
if(x > tr[u].v) u = tr[u].s[1];
else u = tr[u].s[0];
}
splay(u,0);
// l = in-order predecessor (rightmost node of the left subtree),
// r = in-order successor (leftmost node of the right subtree).
int l = tr[u].s[0],r = tr[u].s[1];
while(tr[l].s[1]) l = tr[l].s[1];
while(tr[r].s[0]) r = tr[r].s[0];
// After these splays u is the only node strictly between l and r,
// i.e. exactly r's left child; cut it off and refresh the sums.
splay(l,0),splay(r,l);
tr[r].s[0] = 0;
pushup(r),pushup(l);
}
// Values that are currently deleted (duplicates allowed). solve() also inserts the
// -INF and INF sentinels, so stepping back from upper_bound always stays in range.
multiset<int> s;
// Handles one test case: builds the splay tree over the sorted array plus two
// sentinels, then answers m operations.
void solve(){
    gen();
    root = 0;
    idx = 0;
    scanf("%lld",&m);
    // Two sentinels so every real key has a predecessor and a successor in the tree.
    a[0] = -INF,a[n + 1] = INF;
    sort(a,a + n + 2);
    // Build in O(n) from the sorted array; n single inserts would be too slow.
    root = build(0,n + 1,0);
    s.clear();
    s.insert(-INF);
    s.insert(INF);
    while(m--){
        // Same buffer size as algorithm 1; a 2-byte buffer overflows on any longer token.
        char op[10];
        int x;
        scanf("%s%lld",op,&x);
        if(op[0] == 'F'){
            // Successor of x: smallest undeleted value >= x.
            x = get_nxt(x);
            if(x == INF) puts("1000000000000");
            else printf("%lld\n",x);
        }
        else if(op[0] == 'D'){
            // Delete the successor of x and remember it so R can restore it.
            x = get_nxt(x);
            if(x == INF) continue;
            update(x);
            s.insert(x);
        }
        else if(op[0] == 'C'){
            // Sum of all undeleted keys <= x, accumulated along one root-to-leaf path.
            // Unlike "find one node equal to x and take its left sum", this counts every
            // copy of a duplicated key, including copies in that node's right subtree.
            int u = root,last = 0,res = 0;
            while(u){
                last = u;
                if(tr[u].v <= x) res += tr[tr[u].s[0]].sum + tr[u].v,u = tr[u].s[1];
                else u = tr[u].s[0];
            }
            // Splay the last visited node to keep the amortized bound.
            splay(last,0);
            // res always includes the -INF sentinel; cancel it (prints 0 if nothing is <= x).
            printf("%lld\n",res + INF);
        }
        else if(op[0] == 'R'){
            // Reinsert every deleted value <= x, largest first, until only -INF remains below x.
            while(1){
                auto it = s.upper_bound(x);
                --it;
                if(*it == -INF) break;
                // Copy the value before erasing: an erased iterator must not be dereferenced.
                int v = *it;
                s.erase(it);
                insert(v);
            }
        }
    }
}
// Entry point: reads the number of test cases, then runs solve() for each.
signed main(){
    int cases;
    scanf("%lld",&cases);
    for (; cases--; ) solve();
    return 0;
}