分治算法思维
一提到分治,很多人脑子里第一个蹦出来的就是“归并排序”。没错,归并排序就是分治思想最经典的体现。但如果只停留在“哦,分治就是归并排序”,那格局就小了。
分治的精髓在于,你别总想着一口吃成个胖子。一个大问题,咋看咋没思路,怎么办?把它搞成两个规模小一点的、但是同样性质的子问题,如果子问题还大,就继续往下搞,直到问题规模小到你一眼就能看出的答案。然后,再想办法把这些小子问题的答案,一层一层合并成大问题的答案。
梦开始的地方:归并排序
问题:给你一个无序数组,让它变有序。
解法一:无脑暴力法
这个就不多说了,选择、冒泡、插入,写个两层 for 循环,O(N^2) 的复杂度。代码就不贴了,有点占地方,思路也很直接,就是挨个比较。对于初学者来说,这当然是第一想法,但我们得追求点儿不一样的东西。
解法二:分治思想登场
我的第一反应是,一个数组排序不好搞,那半个数组呢?如果我把左半边排好序,右半边也排好序,然后想办法把这两个有序的半边合并成一个大的有序数组,这事儿不就成了吗?
这思路很“分治”。
分(Divide):把数组
arr[L...R]从中间mid一分为二,变成arr[L...mid]和arr[mid+1...R]。治(Conquer):递归地去把左右两个子数组排好序。递归的尽头是啥?就是子数组只剩一个数了,一个数天然有序,直接返回就行。
合(Combine):这是最关键的一步,也是归并排序的灵魂——
merge操作。现在我手上有两个已经排好序的子数组,怎么把它们合并成一个大的有序数组?
merge 的过程其实很简单。搞个辅助数组 help,大小和要合并的范围一样大。再用两个指针 p1 和 p2,一个指向左边有序数组的开头,一个指向右边有序数组的开头。
然后俩指针开始 PK:
谁指向的数小,就把谁的数拷贝到
help数组里,然后这个指针往后挪一位。一直 PK,直到某一个指针越界了。
把另一个没越界的指针剩下的所有数,直接依次拷贝到
help数组的末尾。最后,把
help数组里的数再拷贝回原数组arr的对应位置。
整个过程,merge 操作的复杂度是 O(N),因为每个元素都只进出 help 数组一次。递归的深度是 logN。所以总的复杂度就是 O(N*logN)。
// 解法二和解法三其实是同一思想,这里合并讲解,代码是最终形态
public class MergeSort {
public static void sort(int[] arr) {
if (arr == null || arr.length < 2) {
return;
}
process(arr, 0, arr.length - 1);
}
// 递归过程:让 arr[L...R] 变有序
private static void process(int[] arr, int L, int R) {
if (L == R) { // base case: 只剩一个数,天然有序
return;
}
int mid = L + ((R - L) >> 1); // 防止溢出
process(arr, L, mid);
process(arr, mid + 1, R);
merge(arr, L, mid, R);
}
// 合并操作
private static void merge(int[] arr, int L, int M, int R) {
int[] help = new int[R - L + 1];
int i = 0;
int p1 = L;
int p2 = M + 1;
while (p1 <= M && p2 <= R) {
help[i++] = arr[p1] <= arr[p2] ? arr[p1++] : arr[p2++];
}
// p1 或 p2 必有一个越界,另一个没越界
while (p1 <= M) {
help[i++] = arr[p1++];
}
while (p2 <= R) {
help[i++] = arr[p2++];
}
// 把 help 数组拷回 arr
for (i = 0; i < help.length; i++) {
arr[L + i] = help[i];
}
}
}
解法三:非递归实现
递归有时候会因为调用栈太深而出问题,虽然在排序这个场景下一般不会。但作为思维提升,我们得想想怎么把它改成非递归。
递归的本质是啥?先深入到底,再回溯合并。 process(0, 7) -> process(0, 3) -> process(0, 1) -> process(0, 0) & process(1, 1) -> merge(0, 0, 1)
我们换个思路,不从上往下“分”,而是从下往上“合”。
先让步长
step = 1。把数组里每 1 个元素看成一组,然后两两合并。arr[0]和arr[1]合并,arr[2]和arr[3]合并...再让步长
step = 2。现在数组里每 2 个元素都是有序的了。我们把这长度为 2 的小组两两合并。arr[0..1]和arr[2..3]合并...再让步长
step = 4, 8, 16...,直到步长大于等于数组长度的一半。
这个过程,模拟的就是递归回溯时合并的顺序,但是用迭代实现了。代码稍微复杂一点,但思路很清晰。
public static void sortNonRecursive(int[] arr) {
if (arr == null || arr.length < 2) {
return;
}
int N = arr.length;
int step = 1; // 步长,代表当前小组的长度
while (step < N) {
int L = 0;
while (L < N) {
int M = L + step - 1;
if (M >= N) { // 左组不够了,直接break
break;
}
int R = Math.min(M + step, N - 1);
// 只有左组,没有右组的情况,在merge里处理,或者这里判断 R > M
merge(arr, L, M, R);
L = R + 1;
}
// 防止 step * 2 溢出
if (step > N / 2) {
break;
}
step <<= 1;
}
}
好了,归并排序讲完了。但如果分治就这点东西,那也太小看它了。归并排序的 merge 过程,其实提供了一个非常强大的能力:在一次合并中,你可以处理任何与左右两组数据相关的统计问题。
这才是杀手锏。下面看几个经典问题,体会一下这种思想的威力。
进阶一:求小和问题
问题:在一个数组中,一个数左边比它小的数的总和,叫数的小和。求一个数组的小和。 例子:[1, 3, 4, 2, 5]
1 左边没数,小和 0
3 左边比它小的数是 1,小和 1
4 左边比它小的数是 1, 3,小和 4
2 左边比它小的数是 1,小和 1
5 左边比它小的数是 1, 3, 4, 2,小和 10
总的小和 = 0 + 1 + 4 + 1 + 10 = 16
解法一:暴力 O(N^2)
不解释,两层循环,i 从 1 到 N-1,j 从 0 到 i-1,如果 arr[j] < arr[i] 就累加 arr[j]。
public static int smallSumBruteForce(int[] arr) {
if (arr == null || arr.length < 2) {
return 0;
}
int sum = 0;
for (int i = 1; i < arr.length; i++) {
for (int j = 0; j < i; j++) {
if (arr[j] < arr[i]) {
sum += arr[j];
}
}
}
return sum;
}
解法二:换个脑子想问题
暴力解法是以每个数 arr[i] 为主角,看它左边有多少贡献者。我们能不能换个角度?以 arr[j] 为主角,看它能为右边多少个数产生小和?
对于 arr[j],它右边有多少个数比它大?假设有 k 个。那么 arr[j] 产生的小和总贡献就是 arr[j] * k。
这个问题转换成了:求每个数右边有多少个数比它大。
这不还是 O(N^2) 吗?别急,merge 过程的机会来了。
在 merge(arr, L, M, R) 的时候,我们有排好序的左组 [L...M] 和右组 [M+1...R]。 当 p1 和 p2 PK 时,如果 arr[p1] < arr[p2],这意味着什么? 意味着 arr[p1] 这个数,比右组里从 p2 指针开始一直到末尾的所有数都要小!因为右组是有序的。
右组从 p2 到 R 一共有多少个数?R - p2 + 1 个。 所以,在这一瞬间,我们就找到了 arr[p1] 的 R - p2 + 1 个“右边比它大的数”。arr[p1] 产生的小和一下子就算出来了:arr[p1] * (R - p2 + 1)。
整个算法的流程就变成了:
写一个归并排序的架子。
在
merge函数里,当arr[p1] < arr[p2]时,把arr[p1] * (R - p2 + 1)累加到全局结果里。merge过程的其它部分(拷贝、排序)照旧。
解法三:分治 O(N*logN)
这就是上面思路的实现。整个过程和归并排序一模一样,只是在 merge 时顺手做了个统计。
public class SmallSum {
public static int getSmallSum(int[] arr) {
if (arr == null || arr.length < 2) {
return 0;
}
return process(arr, 0, arr.length - 1);
}
// 返回 arr[L...R] 排序且产生的小和
private static int process(int[] arr, int L, int R) {
if (L == R) {
return 0;
}
int mid = L + ((R - L) >> 1);
// 小和 = 左边产生的小和 + 右边产生的小和 + merge过程中产生的小和
return process(arr, L, mid)
+ process(arr, mid + 1, R)
+ merge(arr, L, mid, R);
}
private static int merge(int[] arr, int L, int M, int R) {
int[] help = new int[R - L + 1];
int i = 0;
int p1 = L;
int p2 = M + 1;
int res = 0; // 存储当前 merge 产生的小和
while (p1 <= M && p2 <= R) {
if (arr[p1] < arr[p2]) {
// 核心:arr[p1]比右边(R - p2 + 1)个数都小
res += arr[p1] * (R - p2 + 1);
help[i++] = arr[p1++];
} else { // arr[p1] >= arr[p2]
help[i++] = arr[p2++];
}
}
while (p1 <= M) {
help[i++] = arr[p1++];
}
while (p2 <= R) {
help[i++] = arr[p2++];
}
for (i = 0; i < help.length; i++) {
arr[L + i] = help[i];
}
return res;
}
}
看到没?同样是 O(N*logN) 的复杂度,我们不仅排了序,还把小和问题解决了。这就是分治思想的威力。
进阶二:逆序对问题
问题:在一个数组中,如果 i < j 且 arr[i] > arr[j],那么 (arr[i], arr[j]) 就构成一个逆序对。求一个数组中逆序对的总数。 例子:[7, 5, 6, 4] 逆序对有:(7, 5), (7, 6), (7, 4), (5, 4), (6, 4),共 5 个。
解法一:暴力 O(N^2)
老朋友了。两层循环,i 从 0 到 N-2,j 从 i+1 到 N-1,如果 arr[i] > arr[j],计数器加一。
解法二 & 解法三:归并分治 O(N*logN)
有了小和问题的经验,这个问题是不是感觉有点熟悉? 一个逆序对 (arr[i], arr[j]),i < j, arr[i] > arr[j]。 这不就是“一个数右边有多少个数比它小”吗?
我们再次回到 merge(arr, L, M, R) 过程。 当 p1 和 p2 PK 时:
如果
arr[p1] <= arr[p2],说明arr[p1]和右边的数不构成逆序对,正常拷贝arr[p1]就行。如果
arr[p1] > arr[p2],这意味着什么? 意味着左组里,从p1开始一直到末尾M的所有数,都比arr[p2]要大!因为左组是有序的。 左组从p1到M一共有多少个数?M - p1 + 1个。 所以,在这一瞬间,我们一下子就找到了M - p1 + 1个逆序对,它们都以arr[p2]作为那个较小的数。
整个算法的流程就变成了:
写一个归并排序的架子。
在
merge函数里,当arr[p1] > arr[p2]时,把M - p1 + 1累加到全局结果里,然后拷贝arr[p2]。merge过程的其它部分照旧。
代码和小和问题几乎一样,就是计算贡献的地方逻辑稍微变了一下。
public class ReversePairs {
public static int countReversePairs(int[] arr) {
if (arr == null || arr.length < 2) {
return 0;
}
return process(arr, 0, arr.length - 1);
}
// 返回 arr[L...R] 排序且产生的逆序对数量
private static int process(int[] arr, int L, int R) {
if (L == R) {
return 0;
}
int mid = L + ((R - L) >> 1);
return process(arr, L, mid)
+ process(arr, mid + 1, R)
+ merge(arr, L, mid, R);
}
private static int merge(int[] arr, int L, int M, int R) {
int[] help = new int[R - L + 1];
int i = 0;
int p1 = L;
int p2 = M + 1;
int res = 0;
while (p1 <= M && p2 <= R) {
if (arr[p1] > arr[p2]) {
// 核心:左组从p1到M的所有数都比arr[p2]大
res += (M - p1 + 1);
help[i++] = arr[p2++];
} else { // arr[p1] <= arr[p2]
help[i++] = arr[p1++];
}
}
// ... 剩余部分和归并排序的merge完全一样 ...
while (p1 <= M) { help[i++] = arr[p1++]; }
while (p2 <= R) { help[i++] = arr[p2++]; }
for (i = 0; i < help.length; i++) { arr[L + i] = help[i]; }
return res;
}
}
终极挑战:区间和的个数(LeetCode 327)
问题:给你一个整数数组 nums,一个下限 lower 和一个上限 upper。求有多少个区间和 S(i, j)(i 到 j 的和)满足 lower <= S(i, j) <= upper。
解法一:暴力 O(N^3)
三层循环。i 循环区间起点,j 循环区间终点,k 循环计算 i 到 j 的和。然后判断是否在 [lower, upper] 范围内。太慢了。
解法二:前缀和优化 O(N^2)
我们可以预处理一个前缀和数组 preSum,preSum[k] 表示 nums[0...k-1] 的和。 那么区间和 S(i, j) 就可以在 O(1) 内算出:preSum[j+1] - preSum[i]。 这样,两层循环 i 和 j,就能算出所有区间和,总复杂度 O(N^2)。在很多场合下还是会超时。
public int countRangeSumN2(int[] nums, int lower, int upper) {
int n = nums.length;
long[] preSum = new long[n + 1];
for (int i = 0; i < n; i++) {
preSum[i + 1] = preSum[i] + nums[i];
}
int count = 0;
for (int i = 0; i < n; i++) {
for (int j = i; j < n; j++) {
long sum = preSum[j + 1] - preSum[i];
if (sum >= lower && sum <= upper) {
count++;
}
}
}
return count;
}
解法三:归并分治 O(N*logN)
O(N^2) 的瓶颈在于,对于每个 j,我们都要回头去遍历所有的 i。这和前面两个问题的暴力解法何其相似。 分治思想是不是又能派上用场了?
我们的问题是要求满足 lower <= preSum[j+1] - preSum[i] <= upper 的 (i, j) 对数。 变形一下,对于每个 j,我们要找有多少个 i <= j,使得 preSum[i] 满足: preSum[j+1] - upper <= preSum[i] <= preSum[j+1] - lower
现在,问题变成了:对于 preSum 数组中的每个数 preSum[k],它前面有多少个数落在 [preSum[k] - upper, preSum[k] - lower] 这个范围里。
这不就是前面问题的翻版吗?"右边的数"看"左边的数"。我们立刻想到对 preSum 数组进行归并排序。
在 merge(preSum, L, M, R) 的过程中,对于右组的每一个数 preSum[p2](p2 从 M+1 到 R),我们要在左组 [L...M] 中,找到有多少个数 preSum[p1] 满足 preSum[p2] - upper <= preSum[p1] <= preSum[p2] - lower。
因为左组 [L...M] 是有序的,所以对于固定的 preSum[p2],我们可以在左组里用二分查找,或者更优的办法——滑动窗口来找到满足条件的 p1 的范围。
具体操作: 对于右组的 preSum[p2]:
在左组中,找到第一个大于等于
preSum[p2] - upper的位置,记为windowL。在左组中,找到第一个大于
preSum[p2] - lower的位置,记为windowR。那么,左组中满足条件的数的范围就是
[windowL, windowR),总共有windowR - windowL个。把这个数量累加到结果中。
最妙的是,当 p2 从 M+1 往 R 移动时,preSum[p2] 是递增的(因为右组部分有序),所以 preSum[p2] - upper 和 preSum[p2] - lower 也是递增的。这意味着 windowL 和 windowR 在左组里也只会单调向右移动!
所以,在 merge 的过程中,我们只需要一个 for 循环遍历右组的所有数,内部的 windowL 和 windowR 指针不需要回溯,整个统计过程的复杂度是 O(N),而不是 O(N*logN)。
因此,merge 函数总复杂度是 O(N)(统计O(N) + 排序合并O(N)),递归总复杂度还是 O(N*logN)。
public class CountRangeSum {
public int countRangeSum(int[] nums, int lower, int upper) {
if (nums == null || nums.length == 0) {
return 0;
}
long[] preSum = new long[nums.length + 1];
for (int i = 0; i < nums.length; i++) {
preSum[i + 1] = preSum[i] + nums[i];
}
return process(preSum, 0, preSum.length - 1, lower, upper);
}
private int process(long[] sum, int L, int R, int lower, int upper) {
if (L == R) {
return 0;
}
int M = L + ((R - L) >> 1);
return process(sum, L, M, lower, upper)
+ process(sum, M + 1, R, lower, upper)
+ merge(sum, L, M, R, lower, upper);
}
private int merge(long[] sum, int L, int M, int R, int lower, int upper) {
int count = 0;
// [windowL, windowR) 是满足条件的左组范围
int windowL = L;
int windowR = L;
// 对右组的每个数 sum[p2],在左组中找符合条件的
for (int p2 = M + 1; p2 <= R; p2++) {
long min = sum[p2] - upper;
long max = sum[p2] - lower;
while (windowL <= M && sum[windowL] < min) {
windowL++;
}
while (windowR <= M && sum[windowR] <= max) {
windowR++;
}
count += (windowR - windowL);
}
// 正常的归并排序合并流程
long[] help = new long[R - L + 1];
int i = 0;
int p1 = L;
int p2 = M + 1;
while (p1 <= M && p2 <= R) {
help[i++] = sum[p1] <= sum[p2] ? sum[p1++] : sum[p2++];
}
while (p1 <= M) { help[i++] = sum[p1++]; }
while (p2 <= R) { help[i++] = sum[p2++]; }
for (i = 0; i < help.length; i++) { sum[L + i] = help[i]; }
return count;
}
}
总结
从归并排序,到小和,到逆序对,再到区间和,我们看到的是同一个“爹”生出来的不同“娃”。
分治的套路,特别是基于归并排序的套路,它的核心不在于排序本身,而在于 merge 这个环节提供了一个上帝视角。在这个环节,我们手握两个有序的子集合,可以非常高效地处理那些需要“左边”和“右边”数据进行比较和统计的问题。
下次再遇到一个问题,如果它的暴力解法是 O(N^2) 的,且涉及到 i 和 j(i<j)的某种关系,不妨想一想,这个问题能不能被扔进归并排序的 merge 过程里,顺便解决掉。