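// 题目描述 (Problem description)
// 线段树求多个矩形合并后的周长(要求矩形的各边平行于坐标轴)
// Use a segment tree to compute the perimeter of the union of several rectangles
// (every rectangle edge must be parallel to a coordinate axis).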
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <vector>
using namespace std;
// Capacity: seg1/seg2 hold two sweep events per rectangle (2 * N slots each),
// and tr1/tr2 reserve 8 * N segment-tree nodes.
constexpr int N = 200010;
using LL = long long;
// One rectangle edge used as a sweep event. In seg1 the sweep runs along x and
// the edge covers [y1, y2]. In seg2 the roles are swapped: "x" holds a
// y-coordinate and [y1, y2] is an x-interval.
struct Segment {
    int x, y1, y2;  // sweep coordinate and the interval this edge covers
    int k;          // +1 for the opening edge of a rectangle, -1 for its closing edge

    // Ascending by sweep coordinate. At equal coordinates opening edges (k = +1)
    // come first, so an edge shared by two touching rectangles is not counted.
    bool operator< (const Segment& ver) const
    {
        return x == ver.x ? k > ver.k : x < ver.x;
    }
} seg1[N << 1], seg2[N << 1];
// Segment-tree node over compressed leaf indices [l, r]; leaf i stands for the
// gap between coordinates alls[i] and alls[i + 1] (see pushup).
struct Node {
    int l, r;  // inclusive range of leaf indices this node spans
    int cnt;   // how many active edges cover this node's whole range
    int len;   // covered length (in real coordinates) inside this node's range
} tr1[N << 3], tr2[N << 3];
// Coordinate-compression tables, sorted and de-duplicated in main:
// alls1 holds the y values used by seg1, alls2 the x values used by seg2.
vector<int> alls1, alls2;

// Index of the first entry of alls1 that is not less than x
// (alls1.size() when every entry is smaller).
int find1(int x)
{
    const auto pos = lower_bound(alls1.begin(), alls1.end(), x);
    return static_cast<int>(distance(alls1.begin(), pos));
}

// Same lookup as find1, but against alls2.
int find2(int x)
{
    const auto pos = lower_bound(alls2.begin(), alls2.end(), x);
    return static_cast<int>(distance(alls2.begin(), pos));
}

// Number of rectangles in the current test case (read in main).
int n;
// Recomputes tr[u].len, the covered length inside node u's range.
// A node with a non-zero cnt is covered as a whole, so its length is the span
// of real coordinates it represents. Otherwise its length is the sum of its
// children's lengths, and an uncovered leaf contributes 0.
// The coordinate table is chosen by array identity: tr1 uses alls1, tr2 uses
// alls2. For any other array a covered node's len is left unchanged.
void pushup(Node tr[], int u)
{
    Node& node = tr[u];
    if (node.cnt == 0)
    {
        node.len = (node.l == node.r) ? 0 : tr[u << 1].len + tr[u << 1 | 1].len;
        return;
    }
    const vector<int>* coords = (tr == tr1) ? &alls1 : (tr == tr2) ? &alls2 : nullptr;
    if (coords != nullptr)
        node.len = (*coords)[node.r + 1] - (*coords)[node.l];
}
// Initialises the subtree rooted at u to span leaf indices [l, r] with nothing
// covered. The children of node u are stored at 2u and 2u + 1.
void build(Node tr[], int u, int l, int r)
{
    tr[u] = Node{l, r, 0, 0};
    if (l != r)
    {
        const int mid = (l + r) >> 1;
        build(tr, u << 1, l, mid);
        build(tr, u << 1 | 1, mid + 1, r);
    }
}
// Adds k to the coverage count of every leaf index in [l, r] and refreshes the
// covered lengths on the way back up.
//   tr   - tr1 or tr2 (pushup picks the coordinate table from this address)
//   u    - current node index; callers start at the root, 1
//   l, r - inclusive range of compressed leaf indices to update
//   k    - +1 when an edge opens, -1 when it closes
// An empty range (l > r) is a no-op. solve1/solve2 produce one for a
// zero-height edge (y1 == y2). Without this check the recursion can go below a
// leaf into nodes build() never initialised and run off the end of the array.
void modify(Node tr[], int u, int l, int r, int k)
{
    if (l > r) return;
    if (tr[u].l >= l && tr[u].r <= r)
    {
        // No lazy push-down is needed. Every -1 update repeats the exact range
        // of an earlier +1 update, so it lands on the same nodes.
        tr[u].cnt += k;
        pushup(tr, u);
        return;
    }
    int mid = (tr[u].l + tr[u].r) >> 1;
    if (l <= mid) modify(tr, u << 1, l, r, k);
    if (r > mid) modify(tr, u << 1 | 1, l, r, k);
    pushup(tr, u);
}
int solve1()
{
int res = 0, last = 0;
for(int i = 0 ; i < 2 * n ; i ++ )
{
modify(tr1, 1, find1(seg1[i].y1), find1(seg1[i].y2) - 1, seg1[i].k);
res += abs((tr1[1].len - last));
last = tr1[1].len;
// printf("res = %d, last = %d\n", res, last);
}
return res;
}
int solve2()
{
int res = 0, last = 0;
for(int i = 0 ; i < 2 * n ; i ++ )
{
modify(tr2, 1, find2(seg2[i].y1), find2(seg2[i].y2) - 1, seg2[i].k);
res += abs((tr2[1].len - last));
last = tr2[1].len;
}
return res;
}
int main(void)
{
while(scanf("%d", &n) != EOF)
{
alls1.clear(); alls2.clear();
for(int i = 0, idx1 = 0, idx2 = 0 ; i < n; i ++ )
{
int x1, y1, x2, y2;
scanf("%d%d%d%d", &x1, &y1, &x2, &y2);
seg1[idx1 ++] = {x1, y1, y2, 1};
seg1[idx1 ++] = {x2, y1, y2, -1};
alls1.push_back(y1);
alls1.push_back(y2); // 计算竖直方向的周长
seg2[idx2 ++] = {y1, x1, x2, 1};
seg2[idx2 ++] = {y2, x1, x2, -1};
alls2.push_back(x1);
alls2.push_back(x2); // 计算水平方向的周长
}
build(tr1, 1, 0, alls1.size());
build(tr2, 1, 0, alls2.size());
sort(seg1, seg1 + 2 * n);
sort(alls1.begin(), alls1.end());
alls1.erase(unique(alls1.begin(), alls1.end()), alls1.end());
sort(seg2, seg2 + 2 * n);
sort(alls2.begin(), alls2.end());
alls2.erase(unique(alls2.begin(), alls2.end()), alls2.end());
int tot1 = solve1();
int tot2 = solve2();
//printf("%d %d %d\n", tot1, tot2, tot1 + tot2);
printf("%d\n", tot1 + tot2);
}
return 0;
}
/*
Sample input (two unit squares sharing an edge):
2
0 0 1 1
1 0 2 1
Expected output: 6
*/