区间dp从零开始

MrHe··3 min read

什么是区间DP?简单来说,就是一个问题,它能被拆分成一个个更小的、连续的子区间问题,并且最终大区间的最优解可以由小区间的最优解推导出来。它的典型特征就是,我们求解dp[i][j](表示区间[i, j]上的最优解)时,需要依赖所有比[i, j]更短的子区间的最-优解,比如dp[i][k]dp[k+1][j]

听着还是有点抽象?没关系,我们直接上题,用题目来撕开它的神秘面纱。

经典例题:石子合并

在一条直线上有N堆石子,每堆石子的重量已知。现在需要将所有石子合并成一堆,每次只能合并相邻的两堆石子,合并的代价为这两堆石子的重量之和。求将所有石子合并成一堆的最小总代价。

比如,有4堆石子,重量分别为 [3, 2, 4, 1]

  • 如果我们先合并32(代价3+2=5),石子堆变为 [5, 4, 1]

  • 再合并54(代价5+4=9),石子堆变为 [9, 1]

  • 最后合并91(代价9+1=10),石子堆变为 [10]

  • 总代价 = 5 + 9 + 10 = 24。

但这不是唯一的合并方式。

  • 如果先合并24(代价6),石子堆变为 [3, 6, 1]

  • 再合并61(代价7),石子堆变为 [3, 7]

  • 最后合并37(代价10),石子堆变为 [10]

  • 总代价 = 6 + 7 + 10 = 23。

显然,不同的合并顺序会导致不同的总代价。我们的目标就是找到这个最小的总代价。


解法一:暴力递归(从顶向下)

拿到一个算法题,如果你一点思路都没有,怎么办?别慌,先想暴力解法怎么搞。对于这道题,最暴力的方式就是把所有可能的合并方式都试一遍。

我们来定义一个函数 f(i, j),它的功能是返回将arr[i...j]范围内的所有石子合并成一堆的最小代价。我们的最终目标就是求 f(0, N-1)

好,那怎么求f(i, j)呢?

我们想,[i...j]这一整个大区间,不管中间怎么合,它必然有最后一次合并。这次合并一定是把某个已经合并好的[i...k]部分和另一个已经合并好的[k+1...j]部分合在一起。这里的k就是所谓的“分割点”。

k可以从i一直到j-1

  • 如果k=i,就是[i][i+1...j]合并。

  • 如果k=i+1,就是[i...i+1][i+2...j]合并。

  • ...

  • 如果k=j-1,就是[i...j-1][j]合并。

对于任意一个分割点k,将[i...k]合并的最小代价是f(i, k),将[k+1...j]合并的最小代价是f(k+1, j)。当这两部分最终合并时,产生的代价是[i...k]的总重量加上[k+1...j]的总重量,其实就是[i...j]区间的总重量。

所以,我们可以得到一个递推关系: f(i, j) = min( f(i, k) + f(k+1, j) ) + sum(i, j),其中 i <= k < j

我们把所有可能的k都试一遍,取其中的最小值,就是f(i, j)的答案。

那递归的 base case (出口) 是什么? 当i == j时,区间里只有一堆石子,不需要合并,所以代价是0。即 f(i, i) = 0

为了快速计算sum(i, j),我们可以预处理一个前缀和数组prefixSumsum(i, j) 就等于 prefixSum[j+1] - prefixSum[i]

好了,思路有了,直接上代码。

public class MergeStones {

    // 解法一:暴力递归
    public static int minCost1(int[] stones) {
        if (stones == null || stones.length < 2) {
            return 0;
        }
        int n = stones.length;
        // 预处理前缀和数组
        int[] prefixSum = new int[n + 1];
        for (int i = 0; i < n; i++) {
            prefixSum[i + 1] = prefixSum[i] + stones[i];
        }

        return process(0, n - 1, prefixSum);
    }

    // process(i, j) 返回 arr[i...j] 范围上合并的最小代价
    private static int process(int i, int j, int[] prefixSum) {
        // base case: 范围内只有一堆石子,无需合并
        if (i == j) {
            return 0;
        }

        int min = Integer.MAX_VALUE;
        // [i...j]范围上石子的总重量
        int sum = prefixSum[j + 1] - prefixSum[i];

        // 尝试所有可能的分割点 k
        // 左部分: [i...k], 右部分: [k+1...j]
        for (int k = i; k < j; k++) {
            int cost = process(i, k, prefixSum) + process(k + 1, j, prefixSum);
            min = Math.min(min, cost);
        }

        return min + sum;
    }
}

这个暴力递归有什么问题? 老问题了,大量的重复计算。比如在计算f(0, 5)时,我们可能会尝试分割点k=2,需要求f(0, 2)f(3, 5)。而在计算f(0, 6)时,尝试分割点k=2,又需要求一遍f(0, 2)。这就造成了巨大的浪费。


解法二:记忆化搜索(自顶向下DP)

怎么优化暴力递归?加缓存!这就是记忆化搜索,也是一种动态规划的实现方式。

我们搞一个二维数组dp[N][N]dp[i][j]用来存f(i, j)的计算结果。初始时,dp表里所有值都设为一个特殊值(比如-1),表示没算过。

每次进入process(i, j)函数时:

  1. 先查dp[i][j]是不是算过了。

  2. 如果算过了(值不是-1),直接返回dp[i][j]的值。

  3. 如果没算过,就老老实实地走暴力递归的计算逻辑。

  4. 算出来结果后,在return之前,把结果存到dp[i][j]里。

这样,每个子问题f(i, j)就只会被计算一次。

我们来改造一下代码:

public class MergeStones {
    // 解法二:记忆化搜索
    public static int minCost2(int[] stones) {
        if (stones == null || stones.length < 2) {
            return 0;
        }
        int n = stones.length;
        int[] prefixSum = new int[n + 1];
        for (int i = 0; i < n; i++) {
            prefixSum[i + 1] = prefixSum[i] + stones[i];
        }

        // dp[i][j] 缓存 process(i, j) 的结果
        int[][] dp = new int[n][n];

        return process(0, n - 1, prefixSum, dp);
    }

    private static int process(int i, int j, int[] prefixSum, int[][] dp) {
        // 如果缓存命中,直接返回
        if (dp[i][j] != 0) {
            return dp[i][j];
        }

        if (i == j) {
            return 0; // base case 不需要存,因为dp默认是0
        }

        int min = Integer.MAX_VALUE;
        int sum = prefixSum[j + 1] - prefixSum[i];

        for (int k = i; k < j; k++) {
            int cost = process(i, k, prefixSum, dp) + process(k + 1, j, prefixSum, dp);
            min = Math.min(min, cost);
        }

        // 算出的结果存入缓存
        dp[i][j] = min + sum;
        return dp[i][j];
    }
}

这种写法,就是典型的“自顶向下”的动态规划。它和暴力递归的结构几乎一模一样,只是加了个“傻瓜缓存”,但效率天差地别。时间复杂度分析一下:ij的组合有 O(N^2) 种,每个状态dp[i][j]内部有一个 O(N) 的循环,所以总的时间复杂度是 O(N^3)。


解法三:经典DP(自底向上)

有了上面的分析,要写出纯粹的、自底向上的DP版本就很容易了。这也是我们常说的“标准”区间DP形态。

递归是“从上往下”求解,遇到子问题再解决子问题。而迭代DP是“从下往上”,我们先把小区间的问题都解决了,再用它们来解决大区间的问题。

  1. 定义DP数组dp[i][j]的含义不变,还是代表合并arr[i...j]的最小代价。

  2. 初始化dp表对角线dp[i][i]都为0,因为单个石子堆不需要合并。

  3. 确定迭代顺序:这是区间DP最关键的一步。我们发现,计算dp[i][j]时,需要用到所有比它短的子区间的dp值。所以,我们应该按照区间长度从小到大的顺序来填表。

    • 区间长度L从2开始,一直到N

    • L固定时,我们遍历所有长度为L的区间。区间的起始点i可以从0到N-L

    • 终点j自然就是 i + L - 1

    • 然后,我们再用内层循环遍历分割点k,从ij-1,套用状态转移方程。

整个dp表的填充顺序,是从对角线开始,一层一层向右上角扩展,最后dp[0][N-1]就是我们想要的答案。

看代码实现:

public class MergeStones {
    // 解法三:经典动态规划
    public static int minCost3(int[] stones) {
        if (stones == null || stones.length < 2) {
            return 0;
        }
        int n = stones.length;
        int[] prefixSum = new int[n + 1];
        for (int i = 0; i < n; i++) {
            prefixSum[i + 1] = prefixSum[i] + stones[i];
        }

        int[][] dp = new int[n][n];

        // L 是区间长度
        for (int L = 2; L <= n; L++) {
            // i 是区间左端点
            for (int i = 0; i <= n - L; i++) {
                // j 是区间右端点
                int j = i + L - 1;

                dp[i][j] = Integer.MAX_VALUE;
                int sum = prefixSum[j + 1] - prefixSum[i];

                // k 是分割点
                for (int k = i; k < j; k++) {
                    dp[i][j] = Math.min(dp[i][j], dp[i][k] + dp[k + 1][j]);
                }
                dp[i][j] += sum;
            }
        }

        return dp[0][n - 1];
    }
}

这个版本和记忆化搜索的本质是一样的,时间复杂度也是 O(N^3),但它是纯迭代的,没有递归开销,在某些情况下可能会更快一点。这三种解法,体现了解决一个问题从暴力尝试到优化,再到标准化DP的完整思考链路。


总结一下

区间DP的套路感非常强,基本就是三层循环:

  1. 第一层循环:枚举区间长度L

  2. 第二层循环:枚举区间起始点i(或终点j)。

  3. 第三层循环:枚举区间内的分割点k

状态转移方程也很有辨识度: dp[i][j] = min/max (dp[i][k] + dp[k+1][j] + ...)