线段树-树形结构
线段树,顾名思义,就是用来处理区间或线段问题的树形结构。它能高效地解决一类问题:对一个数组进行频繁的区间查询和元素更新。如果你看到一个问题,反复地问你“从L到R这个范围内的xx是什么?”或者“把L到R这个范围内的数都改成xx”,那线札树可能就是你的答案。
它的核心思想就是“分治”。把一个大区间,一分为二,变成两个小区间,直到每个区间只包含一个元素。这样,一个长度为N的数组,就对应了一棵满二叉树(或近似满二叉树),树的深度是 O(logN)。每一次查询和更新,我们都只需要沿着树的一条路径走到底,所以复杂度也是 O(logN)。
废话不多说,我们从简单到复杂,一步步看题,把线段树这个工具彻底玩明白。
12.2 简单线段树的应用算法设计
我们先从最经典的区间求和问题入手,来感受一下线段树的构建和使用。
12.2.1 区域和检索(数组不可变)
题目 (LeetCode 303):
给定一个整数数组
nums,处理以下类型的多个查询:
- 计算索引
left和right(包含left和right)之间的nums元素的 和 ,其中left <= right。实现
NumArray类:
NumArray(int[] nums)使用数组nums初始化对象
int sumRange(int i, int j)返回数组nums中索引left和right之间的元素的 总和 ,包含left和right两点。
思路分析
看到这个题,我的第一反应是,这太简单了。
解法一:暴力法
最直接的想法,每次调用 sumRange(i, j) 时,我就写个 for 循环,从 i 遍历到 j,把路过的数全加起来。
构造函数:O(1),存一下数组就行。
sumRange:O(N),因为j-i最坏情况是N-1。
如果查询次数特别多,这个方法肯定会超时。我们需要优化 sumRange 的效率。
解法二:前缀和数组
这是一个典型的“空间换时间”的优化。既然每次都重新算很慢,那我能不能提前算好一些东西存起来?
可以!我们可以创建一个“前缀和数组” preSum。preSum[i] 存的是 nums[0...i-1] 的和。
preSum[0] = 0preSum[1] = nums[0]preSum[2] = nums[0] + nums[1]...
preSum[i] = nums[0] + ... + nums[i-1]
这样,当我要计算 sumRange(left, right),也就是 nums[left] + ... + nums[right] 的和时,它就等于 (nums[0] + ... + nums[right]) - (nums[0] + ... + nums[left-1])。
正好对应 preSum[right+1] - preSum[left]。
构造函数:O(N) 的时间去构建
preSum数组。sumRange:O(1) 的时间,做一次减法。
对于这道题,数组是“不可变”的,所以前缀和数组是完美的最优解。我们甚至不需要线段树。但是,为了学习线段树,我们来思考一个问题:如果数组可变呢? 如果一个 update(index, val) 操作把 nums[index] 改了,preSum 数组从 preSum[index+1] 开始就全错了,需要 O(N) 的时间去修正。这就引出了线段树的用武之地。
解法三:线段树 (为后续题目做铺垫)
虽然这题用线段树有点杀鸡用牛刀,但它是理解线段树工作原理的绝佳起点。
线段树的每个节点都代表一个区间 [L, R] 的和。
根节点代表整个数组
[0, N-1]的和。它的左孩子代表
[0, (N-1)/2]的和,右孩子代表[(N-1)/2 + 1, N-1]的和。依次类推,直到叶子节点,每个叶子节点代表一个单独的元素
nums[i]。
我们需要三个核心操作:
build(arr, L, R):构建树。递归地把[L, R]分成两半,先构建左右子树,然后pushUp操作,父节点的值等于左右孩子节点的值之和。update(index, val, L, R):单点更新。从根节点出发,根据index属于左半边还是右半边,递归地往下找,直到找到对应的叶子节点。更新叶子节点后,再一路pushUp更新所有祖先节点。query(queryL, queryR, L, R):区间查询。从根节点出发,看当前节点代表的区间[L, R]和要查询的区间[queryL, queryR]的关系。如果
[L, R]完全被[queryL, queryR]覆盖,直接返回当前节点存的和。否则,看查询区间和左右子区间是否有交集,有交集就递归地去子树里查,最后把结果加起来。
对于这道不可变数组的题,我们只需要 build 和 query。
代码实现 (前缀和)
class NumArray {
private int[] preSum;
public NumArray(int[] nums) {
if (nums == null || nums.length == 0) {
return;
}
preSum = new int[nums.length + 1];
for (int i = 0; i < nums.length; i++) {
preSum[i + 1] = preSum[i] + nums[i];
}
}
public int sumRange(int left, int right) {
return preSum[right + 1] - preSum[left];
}
}
12.2.2 二维区域和检索(可改)
这题我们直接上一个有难度的版本,结合 "mutable" 和 "2D",来看看线段树的威力。
题目 (LeetCode 308):
给你一个 2D 矩阵
matrix,处理以下类型的多个查询:
更新
matrix中(row, col)的单元格的值为val。计算以
(row1, col1)为左上角,(row2, col2)为右下角的子矩阵的元素总和。实现
NumMatrix类:
NumMatrix(int[][] matrix)初始化对象。
void update(int row, int col, int val)更新(row, col)的值。
int sumRegion(int row1, int col1, int row2, int col2)返回子矩阵的和。
思路分析
解法一:暴力法
老规矩,先想暴力怎么做。
update:O(1),直接修改matrix[row][col]。sumRegion:O(M*N),两层 for 循环,遍历子矩阵求和。
查询一多,就崩了。
解法二:二维前缀和
仿照一维的做法,我们也可以构建二维前缀和。定义 preSum[i][j] 为以 (0,0) 为左上角,(i-1, j-1) 为右下角的矩阵的和。 那么,sumRegion(r1, c1, r2, c2) 的和可以通过 preSum 数组加加减减在 O(1) 内算出。
公式是 sum = preSum[r2+1][c2+1] - preSum[r1][c2+1] - preSum[r2+1][c1] + preSum[r1][c1]。
问题出在哪里?出在 update。 如果 matrix[row][col] 变了,那么所有 preSum[i][j] 其中 i > row 且 j > col 的值全都要跟着变。这个更新的代价是 O(M*N)。
查询快,更新慢。暴力法是更新快,查询慢。两者都无法同时处理频繁的更新和查询。这时候,主角该登场了。
解法三:二维线段树(树状数组更简单,但我们今天主角是线段树)
如何把线段树扩展到二维? 一个非常自然的想法是“降维打击”或者叫“树套树”。
我们可以先对“行”建立一棵线段树。这棵线段树的每个节点,不再是一个简单的数值,而是另一棵线段树。
最外层的线段树(我们叫它“行树”)的节点
node_row代表行的区间[row_start, row_end]。node_row里面存的,是一棵新的线段树(我们叫它“列树”)。这棵“列树”维护的是matrix在[row_start, row_end]这些行,和[0, N-1]这些列构成的子矩阵的信息。具体来说,这棵列树的每个节点node_col代表列区间[col_start, col_end]的和。这个和是matrix中所有(r, c)的总和,其中r在[row_start, row_end]之间,c在[col_start, col_end]之间。
听起来有点绕,我们来梳理一下操作:
update(row, col, val):我们要在“行树”上,找到所有包含
row的区间。这对应着从根节点到叶子节点的一条路径,复杂度 O(logM)。对于这条路径上的每一个节点,我们都要进入它内部的“列树”,在列树上执行一次
update(col, diff)操作,diff是新旧值的差。列树的更新也是 O(logN)。总复杂度:O(logM * logN)。
sumRegion(r1, c1, r2, c2):首先在“行树”上查询,找到代表行区间
[r1, r2]的若干个节点。这个查询是 O(logM)。对于每一个找到的行树节点,我们进入它内部的“列树”,在列树上执行一次
query(c1, c2)。这个查询是 O(logN)。把所有从列树查询到的结果加起来,就是最终答案。
总复杂度:O(logM * logN)。
这个时空复杂度就非常理想了。实现起来会有点复杂,因为是树套树,但逻辑是通的。
代码实现 (二维树状数组 - 思路一致且代码更简洁)
为了代码的简洁性,我这里用二维树状数组来实现。它的思想和二维线段树一样,都是用 O(logM * logN) 的代价完成更新和查询,但代码量小得多。
class NumMatrix {
private int[][] matrix;
private int[][] tree; // 树状数组
private int rows;
private int cols;
public NumMatrix(int[][] matrix) {
if (matrix.length == 0 || matrix[0].length == 0) return;
this.rows = matrix.length;
this.cols = matrix[0].length;
this.matrix = new int[rows][cols];
this.tree = new int[rows + 1][cols + 1];
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
update(i, j, matrix[i][j]);
}
}
}
// 更新 (row, col) 的值为 val
public void update(int row, int col, int val) {
int diff = val - matrix[row][col];
matrix[row][col] = val;
for (int i = row + 1; i <= rows; i += i & -i) {
for (int j = col + 1; j <= cols; j += j & -j) {
tree[i][j] += diff;
}
}
}
// 查询 (0,0)到(row, col)的区域和
private int query(int row, int col) {
int sum = 0;
for (int i = row + 1; i > 0; i -= i & -i) {
for (int j = col + 1; j > 0; j -= j & -j) {
sum += tree[i][j];
}
}
return sum;
}
public int sumRegion(int row1, int col1, int row2, int col2) {
return query(row2, col2) - query(row1 - 1, col2) - query(row2, col1 - 1) + query(row1 - 1, col1 - 1);
}
}
12.3 复杂线段树的应用算法设计
简单的单点更新和区间查询我们已经掌握了。现在来上强度,看看涉及“区间更新”的场景,这就需要引入线段树的精髓——懒加载 (Lazy Propagation)。
懒加载的思想是:如果一次更新操作覆盖了某个节点代表的整个区间,我们没必要继续递归下去更新它所有的子孙节点。我们可以在这个节点上打一个“懒标记”,表示“这个节点下面的所有子孙都已经被更新过了,但具体值我先不往下传”。等到将来有查询操作需要进入这个节点的子树时,我们再把这个“懒标记”下推给它的直接孩子,这个过程叫 pushDown。
12.3.1 Range模块
题目 (LeetCode 715):
RangeModule是一种跟踪数字范围的模块。设计一个数据结构来跟踪表示为 半开区间[left, right)的范围集合。 实现RangeModule类:
RangeModule()初始化对象。
void addRange(int left, int right)添加半开区间[left, right),跟踪该区间中的每个实数。添加与当前跟踪的数字部分重叠的区间时,会将重叠部分合并。
boolean queryRange(int left, int right)只有在当前正在跟踪的范围集合中,每个实数都在[left, right)中时,才返回true。
void removeRange(int left, int right)停止跟踪[left, right)中当前正在跟踪的每个实数。
思路分析
坐标范围很大 (1~10^9),但操作次数不多。这通常是“离散化”的信号,但我们先不考虑离散化,看看能否用线段树直接处理。这种动态开点的线段树,可以处理大范围的坐标。
解法一:TreeMap/TreeSet 管理区间
一个比较直观的做法是,用一个 TreeMap<Integer, Integer> 来存储所有不相交的区间,key是区间的左端点,value是右端点。
addRange: 找到所有与[left, right)相交的区间,然后把它们和新区间合并成一个大区间。这需要复杂的查找和删除操作。removeRange: 找到所有与[left, right)相交的区间,可能需要把一个已有区间分裂成两半。queryRange: 找到第一个起点≤ left的区间,然后检查它是否能完全覆盖[left, right)。
这个方法逻辑非常复杂,边界情况很多,容易写错。
解法二:线段树 + 懒加载
我们可以把整个 [1, 10^9] 的数轴看作一个巨大的数组。线段树的每个节点 [L,R] 维护一个状态:这个区间是否被“完全覆盖”。
节点状态:
cover(boolean),表示该节点对应的区间是否被完整跟踪。懒标记:
lazy(int),1表示这个区间及其子区间都要被设置为覆盖,-1表示都要被设置为不覆盖,0表示没有懒操作。
核心操作:
pushUp(node): 如果一个节点的左右孩子都被完全覆盖了,那么这个节点也被完全覆盖了。node.cover = node.left.cover && node.right.cover。pushDown(node): 在访问一个节点的子节点之前,先检查该节点有没有懒标记。如果有,就把懒标记的信息传递给左右孩子,并更新孩子节点的cover状态,然后清除自己的懒标记。update(L, R, val): (这是addRange和removeRange的底层实现)。addRange(l, r)调用update(l, r, 1)removeRange(l, r)调用update(l, r, -1)update函数是个标准的区间更新模板:如果当前节点[nodeL, nodeR]完全被[L, R]覆盖,就直接更新当前节点的cover状态,并打上懒标记,然后返回。否则,先pushDown,再根据区间重叠情况递归进入左右子树,最后pushUp。
query(L, R): 标准的区间查询模板。如果当前节点[nodeL, nodeR]被[L, R]完全覆盖,直接返回自己的cover状态。否则,先pushDown,再递归查询,最后把左右子树的结果合并(这里是做&&操作)。
由于坐标范围是 10^9,我们不可能真的开这么大的数组。这里要用动态开点线段树。也就是说,我们不提前把整棵树建好,而是在需要访问一个节点时,如果它不存在,再动态地创建它。
代码实现 (动态开点线段树)
class RangeModule {
private Node root;
public RangeModule() {
// [1, 10^9] 是一个非常大的范围
root = new Node(1, 1_000_000_000);
}
public void addRange(int left, int right) {
update(root, left, right, true);
}
public boolean queryRange(int left, int right) {
return query(root, left, right);
}
public void removeRange(int left, int right) {
update(root, left, right, false);
}
// 节点定义
private class Node {
int l, r;
boolean covered;
// lazy tag: 0 - no change, 1 - set to covered, -1 - set to uncovered
int lazy;
Node leftChild, rightChild;
Node(int l, int r) {
this.l = l;
this.r = r;
}
}
private void pushDown(Node node) {
if (node.lazy == 0) return;
int mid = node.l + (node.r - node.l) / 2;
if (node.leftChild == null) node.leftChild = new Node(node.l, mid);
if (node.rightChild == null) node.rightChild = new Node(mid + 1, node.r);
node.leftChild.covered = (node.lazy == 1);
node.rightChild.covered = (node.lazy == 1);
node.leftChild.lazy = node.lazy;
node.rightChild.lazy = node.lazy;
node.lazy = 0;
}
private void pushUp(Node node) {
node.covered = node.leftChild != null && node.leftChild.covered &&
node.rightChild != null && node.rightChild.covered;
}
private void update(Node node, int L, int R, boolean state) {
// [L, R) 对应到线段树的闭区间 [L, R-1]
R--;
if (node.l >= L && node.r <= R) {
node.covered = state;
node.lazy = state ? 1 : -1;
return;
}
pushDown(node);
int mid = node.l + (node.r - node.l) / 2;
if (L <= mid) {
if (node.leftChild == null) node.leftChild = new Node(node.l, mid);
update(node.leftChild, L, R, state);
}
if (R > mid) {
if (node.rightChild == null) node.rightChild = new Node(mid + 1, node.r);
update(node.rightChild, L, R, state);
}
pushUp(node);
}
private boolean query(Node node, int L, int R) {
// [L, R) 对应到线段树的闭区间 [L, R-1]
R--;
if (node.l >= L && node.r <= R) {
return node.covered;
}
pushDown(node);
boolean result = true;
int mid = node.l + (node.r - node.l) / 2;
if (L <= mid) {
if (node.leftChild == null) return false; // 区间未被完全覆盖
result = result && query(node.leftChild, L, R);
}
if (R > mid) {
if (node.rightChild == null) return false;
result = result && query(node.rightChild, L, R);
}
return result;
}
}
12.4 离散化在线段树中的应用
上面那题我们用动态开点解决了坐标范围过大的问题。还有一种更常用的技术是离散化。当问题的坐标很大,但我们只关心这些坐标点的相对大小和它们构成的区间时,就可以用离散化把这些稀疏的大坐标映射到 [0, k-1] 的紧凑区间上。
12.4.1 区间和的个数
题目(LeetCode 327):
给你一个整数数组
nums以及两个整数lower和upper。求数组中,值位于范围[lower, upper](包含lower和upper)之间的区间和的个数。 区间和S(i, j)表示在nums中第i个元素到第j个元素的和(包含i和j),其中i ≤ j。
思路分析
解法一:暴力法
先求出前缀和数组 preSum。S(i, j) = preSum[j+1] - preSum[i]。 然后两层 for 循环遍历所有的 (i, j) 组合,判断 S(i, j) 是否在 [lower, upper] 范围内。 复杂度 O(N^2),肯定超时。
解法二:归并排序 / 平衡树
这是另一条路,利用归并排序的分治过程来统计符合条件的数对,复杂度 O(NlogN)。这个思路也很巧妙,但我们今天聚焦线段树。
解法三:离散化 + 线段树/树状数组
我们把暴力法的 O(N^2) 过程优化一下。 lower <= preSum[j] - preSum[i] <= upper (方便起见,preSum[i] 代表 nums[0..i-1] 的和) 变形得到 preSum[j] - upper <= preSum[i] <= preSum[j] - lower
算法流程就变成了: 我们从左到右遍历 j,对于每一个 preSum[j],我们需要统计在它之前的 i < j 中,有多少个 preSum[i] 满足上面的不等式。
这是一个动态的“查询某个范围内数字个数”的问题。
preSum的值域可能很大,不能直接当数组下标。我们需要一个数据结构,能快速地:
增加一个数。
查询某个范围
[a,b]内有多少个数。
线段树就是干这个的!但 preSum 的值域问题怎么解决?离散化。
离散化步骤:
遍历
nums,计算出所有的前缀和s[j]。对于每个
s[j],我们关心的查询范围是[s[j]-upper, s[j]-lower]。所以,我们需要对所有可能出现的s[j],s[j]-lower,s[j]-upper这些值进行离散化。把这些值收集起来,去重,排序。得到一个有序的、无重复的值列表
allNumbers。现在,任何一个原始值
v都可以用它在allNumbers中的下标来代替,这就是它的“排名”或“离散化后的值”。
线段树/树状数组步骤:
建立一个线段树(或树状数组),长度为离散化后不同值的个数
k。线段树的每个节点存的是对应值域范围内的数的个数。初始时,为了处理
preSum[i](i=-1) 等于0的情况,我们先将离散化后0的位置加1。遍历前缀和
s[j](从 j=0 到 n-1): a. 确定查询范围[s[j]-upper, s[j]-lower]。 b. 找到这两个值离散化后的排名rank_lower和rank_upper。 c. 在线段树中查询排名在[rank_lower, rank_upper]范围内的数字个数,累加到最终结果ans。 d. 把s[j]离散化后的排名rank_s_j对应在线段树的位置加1,表示我们又见到了一个s[j]。
这样,每次遍历的查询和更新都是 O(logK),总复杂度是 O(NlogN)(其中K最大是3N)。
代码实现 (离散化 + 树状数组)
public class CountRangeSum {
public int countRangeSum(int[] nums, int lower, int upper) {
long sum = 0;
long[] preSum = new long[nums.length + 1];
for (int i = 0; i < nums.length; i++) {
sum += nums[i];
preSum[i + 1] = sum;
}
// 离散化
Set<Long> allNumbersSet = new HashSet<>();
for (long s : preSum) {
allNumbersSet.add(s);
allNumbersSet.add(s - lower);
allNumbersSet.add(s - upper);
}
List<Long> allNumbers = new ArrayList<>(allNumbersSet);
Collections.sort(allNumbers);
Map<Long, Integer> map = new HashMap<>();
for (int i = 0; i < allNumbers.size(); i++) {
map.put(allNumbers.get(i), i + 1); // 树状数组下标从1开始
}
// 树状数组
FenwickTree tree = new FenwickTree(allNumbers.size());
int count = 0;
for (long s : preSum) {
int left = map.get(s - upper);
int right = map.get(s - lower);
count += tree.query(right) - tree.query(left - 1);
tree.update(map.get(s), 1);
}
return count;
}
private class FenwickTree {
private int[] tree;
private int size;
public FenwickTree(int n) {
this.size = n;
this.tree = new int[n + 1];
}
public void update(int index, int delta) {
while (index <= size) {
tree[index] += delta;
index += index & (-index);
}
}
public int query(int index) {
int sum = 0;
while (index > 0) {
sum += tree[index];
index -= index & (-index);
}
return sum;
}
}
}
12.4.2 计算右侧小于当前元素的个数
题目(LeetCode 315):
给定一个整数数组
nums,按要求返回一个新数组counts。数组counts有该性质:counts[i]的值是nums[i]右侧小于nums[i]的元素的数量。
思路分析
解法一:暴力法 对于每个 nums[i],都向右遍历 j from i+1 to n-1,统计 nums[j] < nums[i] 的个数。 O(N^2) 复杂度,会超时。
解法二:归并排序 这道题是归并排序求“逆序对”的经典变种。在 merge 两个有序子数组 left 和 right 时,如果 left[p1] > right[p2],那么 left[p1] 以及 left 数组中在 p1 右边的所有元素,都比 right[p2] 大。我们就可以一次性统计出很多逆序对。
解法三:离散化 + 线段树/树状数组
这个问题可以换一个角度看:从右往左遍历数组。 当我们处理 nums[i] 时,我们需要查询在 i 右边,也就是 nums[i+1...n-1] 中,有多少个数比 nums[i] 小。 这和上一题的思路非常像!
流程:
离散化:
nums数组中的值可能很大,负数,不连续。所以第一步,对nums中所有出现过的数进行离散化,将它们映射到[1, k]这个紧凑的整数范围。建立数据结构:建立一个大小为
k的线段树或树状数组,用来统计每个(离散化后的)值出现了多少次。从右往左遍历:
for i from n-1 down to 0:获取
nums[i]离散化后的排名rank。去数据结构中查询
[1, rank-1]这个范围内的数字总数。这个总数,就是到目前为止(即在i右侧)我们见过的、比nums[i]小的数的个数。把这个结果存到ans[i]。将
rank对应位置的计数加1,表示我们现在又见到了nums[i]这个数。
这个过程完美地解决了问题,每次查询和更新都是 O(logK),总复杂度 O(NlogN)。
代码实现 (离散化 + 树状数组)
class Solution {
public List<Integer> countSmaller(int[] nums) {
// 离散化
Set<Integer> set = new HashSet<>();
for (int num : nums) {
set.add(num);
}
List<Integer> sortedList = new ArrayList<>(set);
Collections.sort(sortedList);
Map<Integer, Integer> map = new HashMap<>();
for (int i = 0; i < sortedList.size(); i++) {
map.put(sortedList.get(i), i + 1);
}
// 树状数组
FenwickTree tree = new FenwickTree(map.size());
List<Integer> result = new ArrayList<>();
// 从右向左遍历
for (int i = nums.length - 1; i >= 0; i--) {
int rank = map.get(nums[i]);
// 查询 排名 < rank 的数的个数
int count = tree.query(rank - 1);
result.add(count);
// 将 rank 的计数+1
tree.update(rank, 1);
}
Collections.reverse(result);
return result;
}
// 树状数组 (Fenwick Tree) 实现
private class FenwickTree {
private int[] tree;
private int size;
public FenwickTree(int n) {
this.size = n;
this.tree = new int[n + 1];
}
public void update(int index, int delta) {
while (index <= size) {
tree[index] += delta;
index += index & (-index);
}
}
public int query(int index) {
int sum = 0;
while (index > 0) {
sum += tree[index];
index -= index & (-index);
}
return sum;
}
}
}