题意
题意是给两个不降序的数组a,b,然后让我们找a,b两个数组中所有元素的中位数。
解题
设n为a,b两数组元素的总个数。如果n是奇数,那么就返回第 n / 2 + 1 小的数;如果n为偶数,那么就返回第n / 2小和第n / 2 + 1小的数的平均值。
所以说,这道题的本质,就是让我们在两个有序数组中找到第k小的数。
算法1 O((n + m)log(n + m))
所以说,最直接的解法,就是将两个数组的元素合并到一个数组中,然后排序,找到对应的数值,返回结果。
以下的算法2和3都是为了帮助理解算法4
算法2 O(n + m)
我们也可以不排序,因为a,b都是有序的,所以我们可以遍历a和b,然后找到第k个数,代码如下。
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int n = nums1.size() + nums2.size();
if(n % 2) return findKth(nums1, nums2, n / 2 + 1);
int a = findKth(nums1, nums2, n / 2);
int b = findKth(nums1, nums2, n / 2 + 1);
cout << a << b << endl;
return (a + b) / 2.0;
}
int findKth(vector<int> &a, vector<int> &b, int k){
if(a.size() > b.size()) return findKth(b, a, k);
int i = 0, j = 0;
while(-- k){//使min(a[i], b[j])是第t小的数
if(i == a.size()) j ++ ;
else if(j == b.size()) i ++ ;
else if(a[i] > b[j]) j ++ ;
else i ++ ;
}
if(i == a.size()) return b[j];
if(j == b.size()) return a[i];
return min(a[i], b[j]);
}
};
算法三 同上,递归实现
我们可以换一个思路,算法2中的每次循环,就相当于在a和b数组中淘汰掉一个剩余的数中最小数,所以找到第k小的数就需要淘汰掉k - 1个数,此时剩余的数中最小的数就是第k小的数。基于这个思路,我们可以递归的写出这段代码。
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int n = nums1.size() + nums2.size();
if(n % 2) return findKth(nums1, 0, nums2, 0, n / 2 + 1);
int a = findKth(nums1, 0, nums2, 0, n / 2);
int b = findKth(nums1, 0, nums2, 0, n / 2 + 1);
return (a + b) / 2.0;
}
int findKth(vector<int> &a, int i, vector<int> &b, int j, int k){
if(a.size() - i > b.size() - j) return findKth(b, j, a, i, k); //保持第二个数组剩余元素更多
if(i == a.size()) return b[j + k - 1]; //a数组中没有元素了,返回数组b中第k小的数
if(k == 1) return min(a[i], b[j]); //找到从a[i], b[j]开始最小的数
if(a[i] < b[j]) return findKth(a, i + 1, b, j, k - 1);//在数组a中淘汰一个数
return findKth(a, i, b, j + 1, k - 1); //在数组b中淘汰一个数
}
};
算法四 O(log(m + n))
要降低算法的时间复杂度,那么就要在每次循环中淘汰掉尽可能多的数。我们要找到的是第k小的数,那么我们就可以在每次循环中最多淘汰掉k / 2个数,这样我们的时间复杂度就可以达到log级别。
我们淘汰掉a和b数组中最大值较小的k / 2个数,这样我们就可以保证第k小的数还是在剩余的数中。
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int n = nums1.size() + nums2.size();
if(n % 2) return findKth(nums1, 0, nums2, 0, n / 2 + 1);
int a = findKth(nums1, 0, nums2, 0, n / 2);
int b = findKth(nums1, 0, nums2, 0, n / 2 + 1);
cout << a << b << endl;
return (a + b) / 2.0;
}
int findKth(vector<int> &a, int i, vector<int> &b, int j, int k){
if(a.size() - i > b.size() - j) return findKth(b, j, a, i, k);
if(i == a.size()) return b[j + k - 1];
if(k == 1) return min(a[i], b[j]);
//数组a较小,防止数组si越界,因为 k <= 当前剩余元素的个数, 所以较长的数组不会越界
int si = min((int)a.size(), i + k / 2), sj = j + k / 2;
//si和sj前面的数,都是可能会被淘汰的
if(a[si - 1] < b[sj - 1]) return findKth(a, si, b, j, k - (si - 1 - i + 1)); //淘汰较小的一组数
return findKth(a, i, b, sj, k - k / 2);
}
};