什么是二分查找算法
典型的二分查找场景有:寻找一个数、寻找左侧边界、寻找右侧边界等等。看下面这个图你就大概懂了
思路看起来很简单,但是深入了解过的人都知道二分算法有很多的细节。连 Knuth 大佬(发明 KMP 算法的那位)都说:
Although the basic idea of binary search is comparatively straightforward, the details can be surprisingly tricky…
大概就是说:二分算法的思路很简单,但其中的细节是魔鬼。
这里不讨论到底有多么多的细节,直接上两套足以应对绝大部分二分问题的代码模版。
算法思路
对于整数二分,不管是寻找一个数、寻找左侧边界还是寻找右侧边界,可以抽象成这样:一个数列可以分为两部分:一部分满足某个性质和另一部分。如下图:
假设我们要二分出A点,此时我们需要有一个函数check(mid)
可以返回当前的mid是否满足这个性质。假设现在这个函数描述的是图中绿色部分的性质。那么,当函数return false
时,那么A点在mid的左侧。下一次二分只需要考虑l
到mid-1
的范围,更新范围的代码为r=mid-1
。当函数return true
时,那么A点就会在mid的右侧,下一次二分只需要考虑mid
到r
的范围,更新范围的代码为l=mid
。
这种情况的代码模版为:
int bsearch_1(int l, int r)
{
while (l < r)
{
// >>1为右移1位,等价于除以2
int mid = l + r + 1 >> 1;
if (check(mid)) l = mid;
else r = mid - 1;
}
return l;
}
假设现在我们要二分出B点,这时我们的函数check(mid)
应该与上面相反,描述的是红色部分的性质才能保证B点一直在二分的区间中,那么当函数return false
时,那么B点在mid的右侧。下一次二分只需要考虑mid+1
到r
的范围,更新范围的代码为l=mid+1
。当函数return true
时,那么B点就会在mid的左侧(也可能就在mid上),下一次二分只需要考虑l
到mid
的范围,更新范围的代码为r=mid
。
这种情况的代码模版为:
int bsearch_2(int l, int r)
{
while (l < r)
{
int mid = l + r >> 1;
if (check(mid)) r = mid;
else l = mid + 1;
}
return l;
}
两个模版的区别在于模版1在取mid的时候要+1
。这是因为,在第一种情况下存在这么一种情况。假设某一次二分的时候l=r-1
比如r=3,l=4
那么mid=l+r>>2
的话mid=3
就等于l
,当更新范围时l=mid=l
就会陷入死循环。
例题
给定一个按照升序排列的长度为n的整数数组,以及 q 个查询。
对于每个查询,返回一个元素k的起始位置和终止位置(位置从0开始计数)。
如果数组中不存在该元素,则返回“-1 -1”。
输入格式
第一行包含整数n和q,表示数组长度和询问个数。
第二行包含n个整数(均在1~10000范围内),表示完整数组。
接下来q行,每行包含一个整数k,表示一个询问元素。
输出格式
共q行,每行包含两个整数,表示所求元素的起始位置和终止位置。
如果数组中不存在该元素,则返回“-1 -1”。
数据范围
1≤n≤100000
1≤q≤10000
1≤k≤10000
输入样例:
6 3
1 2 2 3 3 4
3
4
5
输出样例:
3 4
5 5
-1 -1
解题思路
这题就可以看成是找左边界和右边界的问题,找左分界点(即x的起点)时check(mid)
写成为arr[mid]>=x
,因为左分界点后面的数都是>=x
的,同理找右分界点(即x的终止位置)时check(mid)
可以写成为arr[mid]<=x
。
java
代码
import java.util.Scanner;
public class Main {
public static void main(String[] args) {
Scanner scanner = new Scanner(System.in);
int n = scanner.nextInt();
int q = scanner.nextInt();
int[] arr = new int[n];
for (int i = 0; i < n; i++) {
arr[i] = scanner.nextInt();
}
while (q-- > 0) {
int k = scanner.nextInt();
bsearch(arr, 0, n - 1, k);
}
scanner.close();
}
private static void bsearch(int[] arr, int l, int r, int k) {
while (l < r) {
// 找左分界点
int mid = l + r >> 1;
if (arr[mid] >= k) {
r = mid;
} else {
l = mid + 1;
}
}
if (arr[l] != k)
System.out.println("-1 -1");
else {
System.out.print(l + " ");
l = 0;
r = arr.length - 1;
while (l < r) {
// 找右分界点
int mid = l + r + 1 >> 1;
if (arr[mid] <= k) {
l = mid;
} else {
r = mid - 1;
}
}
System.out.println(l);
}
}
}