区间dp从零开始
什么是区间DP?简单来说,就是一个问题,它能被拆分成一个个更小的、连续的子区间问题,并且最终大区间的最优解可以由小区间的最优解推导出来。它的典型特征就是,我们求解dp[i][j](表示区间[i, j]上的最优解)时,需要依赖所有比[i, j]更短的子区间的最-优解,比如dp[i][k]和dp[k+1][j]。
听着还是有点抽象?没关系,我们直接上题,用题目来撕开它的神秘面纱。
经典例题:石子合并
在一条直线上有N堆石子,每堆石子的重量已知。现在需要将所有石子合并成一堆,每次只能合并相邻的两堆石子,合并的代价为这两堆石子的重量之和。求将所有石子合并成一堆的最小总代价。
比如,有4堆石子,重量分别为 [3, 2, 4, 1]。
如果我们先合并
3和2(代价3+2=5),石子堆变为[5, 4, 1]。再合并
5和4(代价5+4=9),石子堆变为[9, 1]。最后合并
9和1(代价9+1=10),石子堆变为[10]。总代价 = 5 + 9 + 10 = 24。
但这不是唯一的合并方式。
如果先合并
2和4(代价6),石子堆变为[3, 6, 1]。再合并
6和1(代价7),石子堆变为[3, 7]。最后合并
3和7(代价10),石子堆变为[10]。总代价 = 6 + 7 + 10 = 23。
显然,不同的合并顺序会导致不同的总代价。我们的目标就是找到这个最小的总代价。
解法一:暴力递归(从顶向下)
拿到一个算法题,如果你一点思路都没有,怎么办?别慌,先想暴力解法怎么搞。对于这道题,最暴力的方式就是把所有可能的合并方式都试一遍。
我们来定义一个函数 f(i, j),它的功能是返回将arr[i...j]范围内的所有石子合并成一堆的最小代价。我们的最终目标就是求 f(0, N-1)。
好,那怎么求f(i, j)呢?
我们想,[i...j]这一整个大区间,不管中间怎么合,它必然有最后一次合并。这次合并一定是把某个已经合并好的[i...k]部分和另一个已经合并好的[k+1...j]部分合在一起。这里的k就是所谓的“分割点”。
k可以从i一直到j-1。
如果
k=i,就是[i]和[i+1...j]合并。如果
k=i+1,就是[i...i+1]和[i+2...j]合并。...
如果
k=j-1,就是[i...j-1]和[j]合并。
对于任意一个分割点k,将[i...k]合并的最小代价是f(i, k),将[k+1...j]合并的最小代价是f(k+1, j)。当这两部分最终合并时,产生的代价是[i...k]的总重量加上[k+1...j]的总重量,其实就是[i...j]区间的总重量。
所以,我们可以得到一个递推关系: f(i, j) = min( f(i, k) + f(k+1, j) ) + sum(i, j),其中 i <= k < j。
我们把所有可能的k都试一遍,取其中的最小值,就是f(i, j)的答案。
那递归的 base case (出口) 是什么? 当i == j时,区间里只有一堆石子,不需要合并,所以代价是0。即 f(i, i) = 0。
为了快速计算sum(i, j),我们可以预处理一个前缀和数组prefixSum。sum(i, j) 就等于 prefixSum[j+1] - prefixSum[i]。
好了,思路有了,直接上代码。
public class MergeStones {
// 解法一:暴力递归
public static int minCost1(int[] stones) {
if (stones == null || stones.length < 2) {
return 0;
}
int n = stones.length;
// 预处理前缀和数组
int[] prefixSum = new int[n + 1];
for (int i = 0; i < n; i++) {
prefixSum[i + 1] = prefixSum[i] + stones[i];
}
return process(0, n - 1, prefixSum);
}
// process(i, j) 返回 arr[i...j] 范围上合并的最小代价
private static int process(int i, int j, int[] prefixSum) {
// base case: 范围内只有一堆石子,无需合并
if (i == j) {
return 0;
}
int min = Integer.MAX_VALUE;
// [i...j]范围上石子的总重量
int sum = prefixSum[j + 1] - prefixSum[i];
// 尝试所有可能的分割点 k
// 左部分: [i...k], 右部分: [k+1...j]
for (int k = i; k < j; k++) {
int cost = process(i, k, prefixSum) + process(k + 1, j, prefixSum);
min = Math.min(min, cost);
}
return min + sum;
}
}
这个暴力递归有什么问题? 老问题了,大量的重复计算。比如在计算f(0, 5)时,我们可能会尝试分割点k=2,需要求f(0, 2)和f(3, 5)。而在计算f(0, 6)时,尝试分割点k=2,又需要求一遍f(0, 2)。这就造成了巨大的浪费。
解法二:记忆化搜索(自顶向下DP)
怎么优化暴力递归?加缓存!这就是记忆化搜索,也是一种动态规划的实现方式。
我们搞一个二维数组dp[N][N],dp[i][j]用来存f(i, j)的计算结果。初始时,dp表里所有值都设为一个特殊值(比如-1),表示没算过。
每次进入process(i, j)函数时:
先查
dp[i][j]是不是算过了。如果算过了(值不是-1),直接返回
dp[i][j]的值。如果没算过,就老老实实地走暴力递归的计算逻辑。
算出来结果后,在
return之前,把结果存到dp[i][j]里。
这样,每个子问题f(i, j)就只会被计算一次。
我们来改造一下代码:
public class MergeStones {
// 解法二:记忆化搜索
public static int minCost2(int[] stones) {
if (stones == null || stones.length < 2) {
return 0;
}
int n = stones.length;
int[] prefixSum = new int[n + 1];
for (int i = 0; i < n; i++) {
prefixSum[i + 1] = prefixSum[i] + stones[i];
}
// dp[i][j] 缓存 process(i, j) 的结果
int[][] dp = new int[n][n];
return process(0, n - 1, prefixSum, dp);
}
private static int process(int i, int j, int[] prefixSum, int[][] dp) {
// 如果缓存命中,直接返回
if (dp[i][j] != 0) {
return dp[i][j];
}
if (i == j) {
return 0; // base case 不需要存,因为dp默认是0
}
int min = Integer.MAX_VALUE;
int sum = prefixSum[j + 1] - prefixSum[i];
for (int k = i; k < j; k++) {
int cost = process(i, k, prefixSum, dp) + process(k + 1, j, prefixSum, dp);
min = Math.min(min, cost);
}
// 算出的结果存入缓存
dp[i][j] = min + sum;
return dp[i][j];
}
}
这种写法,就是典型的“自顶向下”的动态规划。它和暴力递归的结构几乎一模一样,只是加了个“傻瓜缓存”,但效率天差地别。时间复杂度分析一下:i和j的组合有 O(N^2) 种,每个状态dp[i][j]内部有一个 O(N) 的循环,所以总的时间复杂度是 O(N^3)。
解法三:经典DP(自底向上)
有了上面的分析,要写出纯粹的、自底向上的DP版本就很容易了。这也是我们常说的“标准”区间DP形态。
递归是“从上往下”求解,遇到子问题再解决子问题。而迭代DP是“从下往上”,我们先把小区间的问题都解决了,再用它们来解决大区间的问题。
定义DP数组:
dp[i][j]的含义不变,还是代表合并arr[i...j]的最小代价。初始化:
dp表对角线dp[i][i]都为0,因为单个石子堆不需要合并。确定迭代顺序:这是区间DP最关键的一步。我们发现,计算
dp[i][j]时,需要用到所有比它短的子区间的dp值。所以,我们应该按照区间长度从小到大的顺序来填表。区间长度
L从2开始,一直到N。当
L固定时,我们遍历所有长度为L的区间。区间的起始点i可以从0到N-L。终点
j自然就是i + L - 1。然后,我们再用内层循环遍历分割点
k,从i到j-1,套用状态转移方程。
整个dp表的填充顺序,是从对角线开始,一层一层向右上角扩展,最后dp[0][N-1]就是我们想要的答案。
看代码实现:
public class MergeStones {
// 解法三:经典动态规划
public static int minCost3(int[] stones) {
if (stones == null || stones.length < 2) {
return 0;
}
int n = stones.length;
int[] prefixSum = new int[n + 1];
for (int i = 0; i < n; i++) {
prefixSum[i + 1] = prefixSum[i] + stones[i];
}
int[][] dp = new int[n][n];
// L 是区间长度
for (int L = 2; L <= n; L++) {
// i 是区间左端点
for (int i = 0; i <= n - L; i++) {
// j 是区间右端点
int j = i + L - 1;
dp[i][j] = Integer.MAX_VALUE;
int sum = prefixSum[j + 1] - prefixSum[i];
// k 是分割点
for (int k = i; k < j; k++) {
dp[i][j] = Math.min(dp[i][j], dp[i][k] + dp[k + 1][j]);
}
dp[i][j] += sum;
}
}
return dp[0][n - 1];
}
}
这个版本和记忆化搜索的本质是一样的,时间复杂度也是 O(N^3),但它是纯迭代的,没有递归开销,在某些情况下可能会更快一点。这三种解法,体现了解决一个问题从暴力尝试到优化,再到标准化DP的完整思考链路。
总结一下
区间DP的套路感非常强,基本就是三层循环:
第一层循环:枚举区间长度
L。第二层循环:枚举区间起始点
i(或终点j)。第三层循环:枚举区间内的分割点
k。
状态转移方程也很有辨识度: dp[i][j] = min/max (dp[i][k] + dp[k+1][j] + ...)