分治算法思维

MrHe··8 min read

一提到分治,很多人脑子里第一个蹦出来的就是“归并排序”。没错,归并排序就是分治思想最经典的体现。但如果只停留在“哦,分治就是归并排序”,那格局就小了。

分治的精髓在于,你别总想着一口吃成个胖子。一个大问题,咋看咋没思路,怎么办?把它搞成两个规模小一点的、但是同样性质的子问题,如果子问题还大,就继续往下搞,直到问题规模小到你一眼就能看出的答案。然后,再想办法把这些小子问题的答案,一层一层合并成大问题的答案。

梦开始的地方:归并排序

问题:给你一个无序数组,让它变有序。

解法一:无脑暴力法

这个就不多说了,选择、冒泡、插入,写个两层 for 循环,O(N^2) 的复杂度。代码就不贴了,有点占地方,思路也很直接,就是挨个比较。对于初学者来说,这当然是第一想法,但我们得追求点儿不一样的东西。

解法二:分治思想登场

我的第一反应是,一个数组排序不好搞,那半个数组呢?如果我把左半边排好序,右半边也排好序,然后想办法把这两个有序的半边合并成一个大的有序数组,这事儿不就成了吗?

这思路很“分治”。

  1. 分(Divide):把数组 arr[L...R] 从中间 mid 一分为二,变成 arr[L...mid]arr[mid+1...R]

  2. 治(Conquer):递归地去把左右两个子数组排好序。递归的尽头是啥?就是子数组只剩一个数了,一个数天然有序,直接返回就行。

  3. 合(Combine):这是最关键的一步,也是归并排序的灵魂——merge 操作。现在我手上有两个已经排好序的子数组,怎么把它们合并成一个大的有序数组?

merge 的过程其实很简单。搞个辅助数组 help,大小和要合并的范围一样大。再用两个指针 p1p2,一个指向左边有序数组的开头,一个指向右边有序数组的开头。

然后俩指针开始 PK:

  • 谁指向的数小,就把谁的数拷贝到 help 数组里,然后这个指针往后挪一位。

  • 一直 PK,直到某一个指针越界了。

  • 把另一个没越界的指针剩下的所有数,直接依次拷贝到 help 数组的末尾。

  • 最后,把 help 数组里的数再拷贝回原数组 arr 的对应位置。

整个过程,merge 操作的复杂度是 O(N),因为每个元素都只进出 help 数组一次。递归的深度是 logN。所以总的复杂度就是 O(N*logN)。

// 解法二和解法三其实是同一思想,这里合并讲解,代码是最终形态

public class MergeSort {

    public static void sort(int[] arr) {
        if (arr == null || arr.length < 2) {
            return;
        }
        process(arr, 0, arr.length - 1);
    }

    // 递归过程:让 arr[L...R] 变有序
    private static void process(int[] arr, int L, int R) {
        if (L == R) { // base case: 只剩一个数,天然有序
            return;
        }
        int mid = L + ((R - L) >> 1); // 防止溢出
        process(arr, L, mid);
        process(arr, mid + 1, R);
        merge(arr, L, mid, R);
    }

    // 合并操作
    private static void merge(int[] arr, int L, int M, int R) {
        int[] help = new int[R - L + 1];
        int i = 0;
        int p1 = L;
        int p2 = M + 1;

        while (p1 <= M && p2 <= R) {
            help[i++] = arr[p1] <= arr[p2] ? arr[p1++] : arr[p2++];
        }

        // p1 或 p2 必有一个越界,另一个没越界
        while (p1 <= M) {
            help[i++] = arr[p1++];
        }
        while (p2 <= R) {
            help[i++] = arr[p2++];
        }

        // 把 help 数组拷回 arr
        for (i = 0; i < help.length; i++) {
            arr[L + i] = help[i];
        }
    }
}

解法三:非递归实现

递归有时候会因为调用栈太深而出问题,虽然在排序这个场景下一般不会。但作为思维提升,我们得想想怎么把它改成非递归。

递归的本质是啥?先深入到底,再回溯合并。 process(0, 7) -> process(0, 3) -> process(0, 1) -> process(0, 0) & process(1, 1) -> merge(0, 0, 1)

我们换个思路,不从上往下“分”,而是从下往上“合”。

  1. 先让步长 step = 1。把数组里每 1 个元素看成一组,然后两两合并。arr[0]arr[1] 合并,arr[2]arr[3] 合并...

  2. 再让步长 step = 2。现在数组里每 2 个元素都是有序的了。我们把这长度为 2 的小组两两合并。arr[0..1]arr[2..3] 合并...

  3. 再让步长 step = 4, 8, 16...,直到步长大于等于数组长度的一半。

这个过程,模拟的就是递归回溯时合并的顺序,但是用迭代实现了。代码稍微复杂一点,但思路很清晰。

public static void sortNonRecursive(int[] arr) {
    if (arr == null || arr.length < 2) {
        return;
    }
    int N = arr.length;
    int step = 1; // 步长,代表当前小组的长度
    while (step < N) {
        int L = 0;
        while (L < N) {
            int M = L + step - 1;
            if (M >= N) { // 左组不够了,直接break
                break;
            }
            int R = Math.min(M + step, N - 1);
            // 只有左组,没有右组的情况,在merge里处理,或者这里判断 R > M
            merge(arr, L, M, R);
            L = R + 1;
        }
        // 防止 step * 2 溢出
        if (step > N / 2) {
            break;
        }
        step <<= 1;
    }
}

好了,归并排序讲完了。但如果分治就这点东西,那也太小看它了。归并排序的 merge 过程,其实提供了一个非常强大的能力:在一次合并中,你可以处理任何与左右两组数据相关的统计问题。

这才是杀手锏。下面看几个经典问题,体会一下这种思想的威力。


进阶一:求小和问题

问题:在一个数组中,一个数左边比它小的数的总和,叫数的小和。求一个数组的小和。 例子:[1, 3, 4, 2, 5]

  • 1 左边没数,小和 0

  • 3 左边比它小的数是 1,小和 1

  • 4 左边比它小的数是 1, 3,小和 4

  • 2 左边比它小的数是 1,小和 1

  • 5 左边比它小的数是 1, 3, 4, 2,小和 10

  • 总的小和 = 0 + 1 + 4 + 1 + 10 = 16

解法一:暴力 O(N^2)

不解释,两层循环,i 从 1 到 N-1,j 从 0 到 i-1,如果 arr[j] < arr[i] 就累加 arr[j]

public static int smallSumBruteForce(int[] arr) {
    if (arr == null || arr.length < 2) {
        return 0;
    }
    int sum = 0;
    for (int i = 1; i < arr.length; i++) {
        for (int j = 0; j < i; j++) {
            if (arr[j] < arr[i]) {
                sum += arr[j];
            }
        }
    }
    return sum;
}

解法二:换个脑子想问题

暴力解法是以每个数 arr[i] 为主角,看它左边有多少贡献者。我们能不能换个角度?以 arr[j] 为主角,看它能为右边多少个数产生小和?

对于 arr[j],它右边有多少个数比它大?假设有 k 个。那么 arr[j] 产生的小和总贡献就是 arr[j] * k

这个问题转换成了:求每个数右边有多少个数比它大

这不还是 O(N^2) 吗?别急,merge 过程的机会来了。

merge(arr, L, M, R) 的时候,我们有排好序的左组 [L...M] 和右组 [M+1...R]。 当 p1p2 PK 时,如果 arr[p1] < arr[p2],这意味着什么? 意味着 arr[p1] 这个数,比右组里从 p2 指针开始一直到末尾的所有数都要小!因为右组是有序的。

右组从 p2R 一共有多少个数?R - p2 + 1 个。 所以,在这一瞬间,我们就找到了 arr[p1]R - p2 + 1 个“右边比它大的数”。arr[p1] 产生的小和一下子就算出来了:arr[p1] * (R - p2 + 1)

整个算法的流程就变成了:

  1. 写一个归并排序的架子。

  2. merge 函数里,当 arr[p1] < arr[p2] 时,把 arr[p1] * (R - p2 + 1) 累加到全局结果里。

  3. merge 过程的其它部分(拷贝、排序)照旧。

解法三:分治 O(N*logN)

这就是上面思路的实现。整个过程和归并排序一模一样,只是在 merge 时顺手做了个统计。

public class SmallSum {

    public static int getSmallSum(int[] arr) {
        if (arr == null || arr.length < 2) {
            return 0;
        }
        return process(arr, 0, arr.length - 1);
    }

    // 返回 arr[L...R] 排序且产生的小和
    private static int process(int[] arr, int L, int R) {
        if (L == R) {
            return 0;
        }
        int mid = L + ((R - L) >> 1);
        // 小和 = 左边产生的小和 + 右边产生的小和 + merge过程中产生的小和
        return process(arr, L, mid)
                + process(arr, mid + 1, R)
                + merge(arr, L, mid, R);
    }

    private static int merge(int[] arr, int L, int M, int R) {
        int[] help = new int[R - L + 1];
        int i = 0;
        int p1 = L;
        int p2 = M + 1;
        int res = 0; // 存储当前 merge 产生的小和

        while (p1 <= M && p2 <= R) {
            if (arr[p1] < arr[p2]) {
                // 核心:arr[p1]比右边(R - p2 + 1)个数都小
                res += arr[p1] * (R - p2 + 1);
                help[i++] = arr[p1++];
            } else { // arr[p1] >= arr[p2]
                help[i++] = arr[p2++];
            }
        }

        while (p1 <= M) {
            help[i++] = arr[p1++];
        }
        while (p2 <= R) {
            help[i++] = arr[p2++];
        }

        for (i = 0; i < help.length; i++) {
            arr[L + i] = help[i];
        }
        return res;
    }
}

看到没?同样是 O(N*logN) 的复杂度,我们不仅排了序,还把小和问题解决了。这就是分治思想的威力。


进阶二:逆序对问题

问题:在一个数组中,如果 i < jarr[i] > arr[j],那么 (arr[i], arr[j]) 就构成一个逆序对。求一个数组中逆序对的总数。 例子:[7, 5, 6, 4] 逆序对有:(7, 5), (7, 6), (7, 4), (5, 4), (6, 4),共 5 个。

解法一:暴力 O(N^2)

老朋友了。两层循环,i 从 0 到 N-2,ji+1 到 N-1,如果 arr[i] > arr[j],计数器加一。

解法二 & 解法三:归并分治 O(N*logN)

有了小和问题的经验,这个问题是不是感觉有点熟悉? 一个逆序对 (arr[i], arr[j])i < j, arr[i] > arr[j]。 这不就是“一个数右边有多少个数比它小”吗?

我们再次回到 merge(arr, L, M, R) 过程。 当 p1p2 PK 时:

  • 如果 arr[p1] <= arr[p2],说明 arr[p1] 和右边的数不构成逆序对,正常拷贝 arr[p1] 就行。

  • 如果 arr[p1] > arr[p2],这意味着什么? 意味着左组里,从 p1 开始一直到末尾 M 的所有数,都比 arr[p2] 要大!因为左组是有序的。 左组从 p1M 一共有多少个数?M - p1 + 1 个。 所以,在这一瞬间,我们一下子就找到了 M - p1 + 1 个逆序对,它们都以 arr[p2] 作为那个较小的数。

整个算法的流程就变成了:

  1. 写一个归并排序的架子。

  2. merge 函数里,当 arr[p1] > arr[p2] 时,把 M - p1 + 1 累加到全局结果里,然后拷贝 arr[p2]

  3. merge 过程的其它部分照旧。

代码和小和问题几乎一样,就是计算贡献的地方逻辑稍微变了一下。

public class ReversePairs {

    public static int countReversePairs(int[] arr) {
        if (arr == null || arr.length < 2) {
            return 0;
        }
        return process(arr, 0, arr.length - 1);
    }

    // 返回 arr[L...R] 排序且产生的逆序对数量
    private static int process(int[] arr, int L, int R) {
        if (L == R) {
            return 0;
        }
        int mid = L + ((R - L) >> 1);
        return process(arr, L, mid)
                + process(arr, mid + 1, R)
                + merge(arr, L, mid, R);
    }

    private static int merge(int[] arr, int L, int M, int R) {
        int[] help = new int[R - L + 1];
        int i = 0;
        int p1 = L;
        int p2 = M + 1;
        int res = 0;

        while (p1 <= M && p2 <= R) {
            if (arr[p1] > arr[p2]) {
                // 核心:左组从p1到M的所有数都比arr[p2]大
                res += (M - p1 + 1);
                help[i++] = arr[p2++];
            } else { // arr[p1] <= arr[p2]
                help[i++] = arr[p1++];
            }
        }

        // ... 剩余部分和归并排序的merge完全一样 ...
        while (p1 <= M) { help[i++] = arr[p1++]; }
        while (p2 <= R) { help[i++] = arr[p2++]; }
        for (i = 0; i < help.length; i++) { arr[L + i] = help[i]; }

        return res;
    }
}

终极挑战:区间和的个数(LeetCode 327)

问题:给你一个整数数组 nums,一个下限 lower 和一个上限 upper。求有多少个区间和 S(i, j)ij 的和)满足 lower <= S(i, j) <= upper

解法一:暴力 O(N^3)

三层循环。i 循环区间起点,j 循环区间终点,k 循环计算 ij 的和。然后判断是否在 [lower, upper] 范围内。太慢了。

解法二:前缀和优化 O(N^2)

我们可以预处理一个前缀和数组 preSumpreSum[k] 表示 nums[0...k-1] 的和。 那么区间和 S(i, j) 就可以在 O(1) 内算出:preSum[j+1] - preSum[i]。 这样,两层循环 ij,就能算出所有区间和,总复杂度 O(N^2)。在很多场合下还是会超时。

public int countRangeSumN2(int[] nums, int lower, int upper) {
    int n = nums.length;
    long[] preSum = new long[n + 1];
    for (int i = 0; i < n; i++) {
        preSum[i + 1] = preSum[i] + nums[i];
    }

    int count = 0;
    for (int i = 0; i < n; i++) {
        for (int j = i; j < n; j++) {
            long sum = preSum[j + 1] - preSum[i];
            if (sum >= lower && sum <= upper) {
                count++;
            }
        }
    }
    return count;
}

解法三:归并分治 O(N*logN)

O(N^2) 的瓶颈在于,对于每个 j,我们都要回头去遍历所有的 i。这和前面两个问题的暴力解法何其相似。 分治思想是不是又能派上用场了?

我们的问题是要求满足 lower <= preSum[j+1] - preSum[i] <= upper(i, j) 对数。 变形一下,对于每个 j,我们要找有多少个 i <= j,使得 preSum[i] 满足: preSum[j+1] - upper <= preSum[i] <= preSum[j+1] - lower

现在,问题变成了:对于 preSum 数组中的每个数 preSum[k],它前面有多少个数落在 [preSum[k] - upper, preSum[k] - lower] 这个范围里。

这不就是前面问题的翻版吗?"右边的数"看"左边的数"。我们立刻想到对 preSum 数组进行归并排序。

merge(preSum, L, M, R) 的过程中,对于右组的每一个数 preSum[p2]p2M+1R),我们要在左组 [L...M] 中,找到有多少个数 preSum[p1] 满足 preSum[p2] - upper <= preSum[p1] <= preSum[p2] - lower

因为左组 [L...M] 是有序的,所以对于固定的 preSum[p2],我们可以在左组里用二分查找,或者更优的办法——滑动窗口来找到满足条件的 p1 的范围。

具体操作: 对于右组的 preSum[p2]

  1. 在左组中,找到第一个大于等于 preSum[p2] - upper 的位置,记为 windowL

  2. 在左组中,找到第一个大于 preSum[p2] - lower 的位置,记为 windowR

  3. 那么,左组中满足条件的数的范围就是 [windowL, windowR),总共有 windowR - windowL 个。

  4. 把这个数量累加到结果中。

最妙的是,当 p2M+1R 移动时,preSum[p2] 是递增的(因为右组部分有序),所以 preSum[p2] - upperpreSum[p2] - lower 也是递增的。这意味着 windowLwindowR 在左组里也只会单调向右移动!

所以,在 merge 的过程中,我们只需要一个 for 循环遍历右组的所有数,内部的 windowLwindowR 指针不需要回溯,整个统计过程的复杂度是 O(N),而不是 O(N*logN)。

因此,merge 函数总复杂度是 O(N)(统计O(N) + 排序合并O(N)),递归总复杂度还是 O(N*logN)。

public class CountRangeSum {

    public int countRangeSum(int[] nums, int lower, int upper) {
        if (nums == null || nums.length == 0) {
            return 0;
        }
        long[] preSum = new long[nums.length + 1];
        for (int i = 0; i < nums.length; i++) {
            preSum[i + 1] = preSum[i] + nums[i];
        }
        return process(preSum, 0, preSum.length - 1, lower, upper);
    }

    private int process(long[] sum, int L, int R, int lower, int upper) {
        if (L == R) {
            return 0;
        }
        int M = L + ((R - L) >> 1);
        return process(sum, L, M, lower, upper)
                + process(sum, M + 1, R, lower, upper)
                + merge(sum, L, M, R, lower, upper);
    }

    private int merge(long[] sum, int L, int M, int R, int lower, int upper) {
        int count = 0;
        // [windowL, windowR) 是满足条件的左组范围
        int windowL = L;
        int windowR = L;

        // 对右组的每个数 sum[p2],在左组中找符合条件的
        for (int p2 = M + 1; p2 <= R; p2++) {
            long min = sum[p2] - upper;
            long max = sum[p2] - lower;
            while (windowL <= M && sum[windowL] < min) {
                windowL++;
            }
            while (windowR <= M && sum[windowR] <= max) {
                windowR++;
            }
            count += (windowR - windowL);
        }

        // 正常的归并排序合并流程
        long[] help = new long[R - L + 1];
        int i = 0;
        int p1 = L;
        int p2 = M + 1;
        while (p1 <= M && p2 <= R) {
            help[i++] = sum[p1] <= sum[p2] ? sum[p1++] : sum[p2++];
        }
        while (p1 <= M) { help[i++] = sum[p1++]; }
        while (p2 <= R) { help[i++] = sum[p2++]; }
        for (i = 0; i < help.length; i++) { sum[L + i] = help[i]; }

        return count;
    }
}

总结

从归并排序,到小和,到逆序对,再到区间和,我们看到的是同一个“爹”生出来的不同“娃”。

分治的套路,特别是基于归并排序的套路,它的核心不在于排序本身,而在于 merge 这个环节提供了一个上帝视角。在这个环节,我们手握两个有序的子集合,可以非常高效地处理那些需要“左边”和“右边”数据进行比较和统计的问题。

下次再遇到一个问题,如果它的暴力解法是 O(N^2) 的,且涉及到 iji<j)的某种关系,不妨想一想,这个问题能不能被扔进归并排序的 merge 过程里,顺便解决掉。