题目描述
给定三个整数数组
A=[A1,A2,…AN],
B=[B1,B2,…BN],
C=[C1,C2,…CN],
请你统计有多少个三元组 (i,j,k) 满足:
1≤i,j,k≤N
Ai<Bj<Ck
输入格式
第一行包含一个整数 N。
第二行包含 N 个整数 A1,A2,…AN。
第三行包含 N 个整数 B1,B2,…BN。
第四行包含 N 个整数 C1,C2,…CN。
输出格式
一个整数表示答案。
数据范围
1≤N≤105,
0≤Ai,Bi,Ci≤105
样例
输入样例:
3
1 1 1
2 2 2
3 3 3
输出样例:
27
算法1
(二分) $O(nlogn)$
对于b数组中每一个元素, 找到在a数组中比b[i]小的最后一个数, 若找到的下标为l, 有两种情况:
若找到的那个数比b[i]小, 则计算a数组中有多少个数比b[i]小, 公式为: l + 1(因为下标从0开始)
若找到的那个数比b[i]大, 则a数组中比b[i]小的个数为0
对于b数组中每一个元素, 找到在c数组中比b[i]大的第一个数, 若找到的下标为l, 有两种情况:
若找到的那个数比b[i]大, 则计算c数组中有多少个数比b[i]大, 公式为: n - l
若找到的那个数比b[i]小, 则a数组中比b[i]大的个数为0
C++ 代码
#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long LL;
const int N = 1e5 + 10;
int n;
int a[N], b[N], c[N];
int main() {
scanf("%d", &n);
for(int i = 0; i < n; i++) {
scanf("%d", &a[i]);
}
for(int i = 0; i < n; i++) {
scanf("%d", &b[i]);
}
for(int i = 0; i < n; i++) {
scanf("%d", &c[i]);
}
sort(a, a + n);
sort(c, c + n);
LL res = 0;
for(int i = 0; i < n; i++) {
//找出a数组中比b[i]小的最后一个数
int l = 0, r = n - 1;
while(l < r) {
int mid = l + r + 1 >> 1;
if(a[mid] < b[i]) l = mid;
else r = mid - 1;
}
int ta = 0;
//若找到了a数组中比b[i]小的第一个数, 则计算一共几个数; 否则a数组中没有比b[i]小的数
if(a[l] < b[i]) {
ta = l + 1;
}
//找出c数组中比b[i]大的第一个数
l = 0, r = n - 1;
while(l < r) {
int mid = l + r >> 1;
if(c[mid] > b[i]) r = mid;
else l = mid + 1;
}
int tc = 0;
//若找到了c数组中比b[i]大的第一个数, 则计算一共几个数; 否则c数组中没有比b[i]大的数
if(c[l] > b[i]) {
tc = n - l;
}
res += (LL) ta * tc;
}
cout << res;
return 0;
}