树状数组( Binary Indexed Tree)
树状数组(Fenwick Tree / Binary Indexed Tree)这个东西。这玩意儿听起来挺唬人,但本质上就是个用数组模拟树形结构的工具,专门用来解决一类问题:频繁的单点更新和区间查询。如果我们用普通数组,更新是 O(1),但查区间和就是 O(N);用前缀和数组,查区间和是 O(1),但你更新一个点,后面所有的前缀和都得改,又是 O(N)。两边都不讨好。树状数组就把这两个操作都干到了 O(logN),一下就平衡了。
废话不多说,咱们从题入手,看看这东西到底怎么用,思路是怎么一步步升级的。
计算右侧小于当前元素的个数
问题描述:
给定一个整数数组
nums,按要求返回一个新数组counts。数组counts有该性质:counts[i]的值是nums[i]右侧小于nums[i]的元素的数量。示例:
输入:
nums = [5, 2, 6, 1]输出:[2, 1, 1, 0]解释: 5 的右侧有 2 个更小的元素 (2 和 1)。 2 的右侧仅有 1 个更小的元素 (1)。 6 的右侧有 1 个更小的元素 (1)。 1 的右侧有 0 个更小的元素。
思路一:暴力,硬来
拿到这个题,最直观的想法是什么?就是老老实实按题目说的做。对于每个 nums[i],我直接写个循环,往它右边走,一个一个地数,看有多少个比它小。
// 暴力解法 O(N^2)
public List<Integer> countSmallerBruteForce(int[] nums) {
List<Integer> counts = new ArrayList<>();
for (int i = 0; i < nums.length; i++) {
int count = 0;
for (int j = i + 1; j < nums.length; j++) {
if (nums[j] < nums[i]) {
count++;
}
}
counts.add(count);
}
return counts;
}
这代码逻辑清晰,简单粗暴。但时间复杂度呢?两层循环,妥妥的 O(N^2)。如果 nums 的长度是 10^5,这直接就超时了。瓶颈在哪?瓶颈在于对于每个 i,我们都在做一次重复的、低效的扫描。
思路二:换个角度,用空间换时间
O(N^2) 的瓶颈在于,当我们考察 nums[i] 时,我们对它右边的信息一无所知,只能傻傻地去遍历。那我们能不能换个角度?
我们从右往左遍历 nums 数组。当我们处理 nums[i] 时,它右边的所有数字我们都已经“见过”了。现在的问题就变成了:在所有我们“见过”的数字里,有多少个比 nums[i] 小?
这一下就把问题转化成了一个动态的查询问题:
来一个数,把它“记录”下来。
查询一下,在所有已记录的数里,比当前数小的有多少个。
这不就是树状数组的拿手好戏吗?
树状数组可以高效地维护一个序列的前缀和。我们可以把数字本身看作是序列的下标。
当我们“见到”一个数
x,就执行add(x, 1)操作,表示x这个数出现了一次。当我们想查询比
y小的数有多少个,就执行query(y - 1)操作,这会返回[1, y-1]这个区间所有数的出现次数之和。
但是,这里有个坑: nums 里的数可能很大,也可能是负数,我们不能直接当数组下标。怎么办?离散化。
离散化说白了就是搞个排名。比如 [100, -5, 50],这三个数我们不关心它们具体的值,只关心它们的相对大小。排个序就是 [-5, 50, 100],我们给它们映射成 1, 2, 3。这样就把原始值域映射到了一个紧凑的、从1开始的整数区间,完美适配树状数组的下标。
所以,整体流程就出来了:
离散化:把
nums中所有不重复的数收集起来,排序,建立一个值 -> 排名的映射。初始化:创建一个大小为
排名总数 + 1的树状数组,初始全为0。从右往左遍历:
对于
nums[i],先通过映射找到它的排名rank。在树状数组中查询
query(rank - 1),得到的结果就是nums[i]右侧比它小的数的个数。存入结果数组。将
nums[i]加入到“已见过”的集合中,也就是在树状数组中执行add(rank, 1)。
这个思路下,每次查询和更新都是 O(logM),其中 M 是不同数字的个数。总的遍历是 O(N),所以总时间复杂度是 O(N logM),比 O(N^2) 不知道高到哪里去了。
public class CountSmaller {
private int[] tree;
private int size;
// 经典树状数组模板
private int lowbit(int x) {
return x & (-x);
}
private void add(int index, int val) {
while (index <= size) {
tree[index] += val;
index += lowbit(index);
}
}
private int query(int index) {
int sum = 0;
while (index > 0) {
sum += tree[index];
index -= lowbit(index);
}
return sum;
}
public List<Integer> countSmaller(int[] nums) {
if (nums == null || nums.length == 0) {
return new ArrayList<>();
}
// 1. 离散化
Set<Integer> uniqueNums = new HashSet<>();
for (int num : nums) {
uniqueNums.add(num);
}
List<Integer> sortedUniqueNums = new ArrayList<>(uniqueNums);
Collections.sort(sortedUniqueNums);
Map<Integer, Integer> rankMap = new HashMap<>();
int rank = 1;
for (Integer num : sortedUniqueNums) {
rankMap.put(num, rank++);
}
// 2. 初始化树状数组
this.size = uniqueNums.size();
this.tree = new int[size + 1];
// 3. 从右往左遍历
Integer[] result = new Integer[nums.length];
for (int i = nums.length - 1; i >= 0; i--) {
int currentRank = rankMap.get(nums[i]);
// 查询比当前数排名小的数的个数
result[i] = query(currentRank - 1);
// 将当前数加入树状数组
add(currentRank, 1);
}
return Arrays.asList(result);
}
}
区间和的个数
问题描述:
给定一个整数数组
nums和两个整数lower和upper,返回区间和在[lower, upper]范围内的个数。区间和
S(i, j)表示nums数组中从索引i到j(包含i和j)的元素的总和。示例:
输入:
nums = [-2, 5, -1],lower = -2,upper = 2输出:3解释: 存在 3 个区间和满足[lower, upper]的范围:[0,0],[2,2],[0,2]。它们对应的区间和分别为:-2,-1,2。
思路一:又是暴力
这个题求区间和,第一反应就是前缀和。我们先预处理一个前缀和数组 preSum,其中 preSum[i] 表示 nums[0...i-1] 的和。那么区间 [i, j] 的和就可以用 preSum[j+1] - preSum[i] 来 O(1) 计算。
然后,还是老套路,两层循环枚举所有的区间 [i, j],计算区间和,判断是否在 [lower, upper] 范围内。
// 暴力解法 O(N^2)
public int countRangeSumBruteForce(int[] nums, int lower, int upper) {
int n = nums.length;
long[] preSum = new long[n + 1];
for (int i = 0; i < n; i++) {
preSum[i + 1] = preSum[i] + nums[i];
}
int count = 0;
for (int i = 0; i < n; i++) {
for (int j = i; j < n; j++) {
long sum = preSum[j + 1] - preSum[i];
if (sum >= lower && sum <= upper) {
count++;
}
}
}
return count;
}
逻辑没问题,但复杂度还是 O(N^2),肯定不是面试官想要的答案。
思路二:转换问题,再次请出树状数组
我们来分析一下 O(N^2) 的瓶颈。对于每个 j,我们都在遍历 i < j,然后计算 preSum[j+1] - preSum[i]。
我们把前缀和数组设为 s,s[i] 表示 nums[0..i-1] 的和。 那么区间和 S(i, j) 就是 s[j+1] - s[i]。 我们要找的是 lower <= s[j+1] - s[i] <= upper。
对这个不等式变形,我们固定 j,要找满足条件的 i 的个数: s[j+1] - upper <= s[i] <= s[j+1] - lower
这下问题又变了!当我们遍历到 j 时,我们需要做的是:在 j 之前出现过的所有前缀和 s[0], s[1], ..., s[j] 中,有多少个落在了 [s[j+1] - upper, s[j+1] - lower] 这个区间里?
这不又回到了我们熟悉的路子上了吗?
遍历前缀和数组
s。对于当前的前缀和
s_current,我们需要查询在它之前出现过的前缀和中,有多少个值在[s_current - upper, s_current - lower]这个范围里。然后把
s_current也加入到“已见过”的集合里。
查询一个范围内的数的个数,可以用树状数组 query(end) - query(start - 1) 来实现。
同样,前缀和的值域可能非常大,而且不连续,所以离散化再次登场。这次需要离散化的值有哪些? 对于每一个前缀和 s_k,我们都关心 s_k, s_k - lower 和 s_k - upper 这三个值。所以,我们把所有可能出现的前缀和,以及由它们衍生出的查询边界 s-lower 和 s-upper,全部收集起来,进行离散化。
流程:
计算前缀和数组
preSum。离散化:收集所有的
preSum[k],preSum[k]-lower,preSum[k]-upper,排序去重,建立值 -> 排名映射。初始化树状数组。
遍历前缀和数组
preSum(从preSum[0]开始,它代表空前缀,值为0):对于
preSum[i]:计算查询区间的上下界:
L = preSum[i] - upper,R = preSum[i] - lower。找到
L和R对应的排名rank_L和rank_R。在树状数组中查询
query(rank_R) - query(rank_L - 1),累加到最终结果count中。找到
preSum[i]自己的排名rank_i,在树状数组中执行add(rank_i, 1)。
这里要注意,我们是先查询,再添加。因为题目要求 i <= j,当我们处理 preSum[j+1] 时,我们要找的是 preSum[i] (i <= j)。所以在处理 preSum[k] 时,我们找的是已经处理过的 preSum[0]...preSum[k-1],这就保证了 i 在 j 的前面。
public class CountRangeSum {
private int[] tree;
private int size;
// 树状数组模板...
private int lowbit(int x) { return x & (-x); }
private void add(int index, int val) { /* ... */ }
private int query(int index) { /* ... */ }
public int countRangeSum(int[] nums, int lower, int upper) {
int n = nums.length;
long[] preSum = new long[n + 1];
for (int i = 0; i < n; i++) {
preSum[i + 1] = preSum[i] + nums[i];
}
// 1. 离散化
Set<Long> allNumbers = new HashSet<>();
for (long s : preSum) {
allNumbers.add(s);
allNumbers.add(s - lower);
allNumbers.add(s - upper);
}
List<Long> sortedNumbers = new ArrayList<>(allNumbers);
Collections.sort(sortedNumbers);
Map<Long, Integer> rankMap = new HashMap<>();
int rank = 1;
for (Long num : sortedNumbers) {
rankMap.put(num, rank++);
}
// 2. 初始化树状数组
this.size = allNumbers.size();
this.tree = new int[size + 1];
// 3. 遍历前缀和
int count = 0;
// 初始时,有一个前缀和0存在
add(rankMap.get(0L), 1);
for (int i = 1; i <= n; i++) {
long currentSum = preSum[i];
long L = currentSum - upper;
long R = currentSum - lower;
// 查询 [L, R] 范围内的前缀和个数
// 注意这里rankmap可能不包含L-1,需要找小于L的最大值的rank
// 一个简单的方式是直接求 query(rank_R) - query(rank_L - 1)
Integer rankR = rankMap.get(R);
Integer rankL = rankMap.get(L);
// query(end) - query(start - 1)
count += query(rankR) - query(rankL - 1);
// 把当前的前缀和加入
add(rankMap.get(currentSum), 1);
}
return count;
}
// 完整的 add 和 query 方法
private void add(int index, int val) {
while (index <= size) {
tree[index] += val;
index += lowbit(index);
}
}
private int query(int index) {
int sum = 0;
// rankMap.get()可能返回null,如果查询的边界值本身不在离散化集合里,
// 比如L-1,我们需要找到不大于L-1的最大值的排名,这个处理比较复杂。
// 一个更严谨(但代码更复杂)的做法是二分查找排名。
// 但对于本题,因为所有相关的边界都已加入离-散化集合,可以直接用。
if (index < 1) return 0;
while (index > 0) {
sum += tree[index];
index -= lowbit(index);
}
return sum;
}
}
注:上面代码 query(rankL-1) 是一个简化写法,如果 L 是离散化后的最小值,L-1 就不在map里。严谨的实现需要二分查找 L 在 sortedNumbers 中的位置。但由于我们把所有 s, s-lower, s-upper 都加入了,这里的 L 和 R 一定在 rankMap 中,所以可以直接用。 query(rankL - 1) 正确表达了小于 rankL 的所有排名的和。
这个解法的时间复杂度是 O(N logN),因为离散化排序是 O(N logN),遍历和树状数组操作也是 O(N logN)。空间复杂度是 O(N) 用来存各种值和树状数组。完美解决。
k 个关闭的灯泡
问题描述:
有 n 个灯泡,从 1 到 n 编号。最初所有灯泡都是关闭的。
每天,我们会打开一个指定的灯泡。给你一个数组
flowers,其中flowers[i] = x表示在第i+1天,我们会打开位置为x的灯泡。请返回在哪一天,存在两个相邻的已打开的灯泡,且它们之间正好有
k个未打开的灯泡。如果不存在这样的一天,返回 -1。
示例 1:
输入:
flowers = [1,3,2],k = 1输出:2解释: 第 1 天,打开位置 1 的灯泡 [1,0,0]。 第 2 天,打开位置 3 的灯泡 [1,0,1]。 第 3 天,打开位置 2 的灯泡 [1,1,1]。在第 2 天,灯泡 1 和 3 之间有 1 个关闭的灯泡。示例 2:
输入:
flowers = [1,2,3],k = 1输出:-1
思路一:模拟每一天,暴力检查
这个题直接模拟也行。我们用一个数组 status 来表示灯泡的状态,0表示关,1表示开。 每天打开一个灯泡后,就从头到尾扫一遍 status 数组,找有没有 status[i]=1, status[i+k+1]=1 的情况。
这个思路太暴力了,每天都要 O(N) 检查,总共 N 天,又是 O(N^2)。pass。
思路二:树状数组优化检查过程
我们分析一下,当第 i 天,位置为 p = flowers[i] 的灯泡打开时,我们需要检查什么? 我们需要看 p 的两个“邻居”:位置在 p - k - 1 和 p + k + 1 的灯泡。
如果
p - k - 1这个位置的灯泡是亮的,那我们就要看(p-k-1, p)这个区间内是不是空的(没有其他亮灯)。如果
p + k + 1这个位置的灯泡是亮的,那我们就要看(p, p+k+1)这个区间内是不是空的。
这就又转化为区间查询问题了! 我们可以用一个树状数组来维护亮灯的位置。 add(pos, 1) 表示 pos 位置的灯亮了。 query(end) - query(start) 就能得到 (start, end] 区间内亮灯的数量。
流程:
创建一个长度为
n+1的树状数组tree和一个布尔数组is_on,记录每个位置的灯是否亮。按天数遍历
flowers数组,从day = 1到n。在第
day天,灯泡p = flowers[day-1]打开。检查左邻居
left = p - k - 1:如果
left >= 1并且is_on[left]为true。查询
(left, p)区间,即query(p-1) - query(left)。如果结果为0,说明中间没有灯,找到了答案,返回day。
检查右邻居
right = p + k + 1:如果
right <= n并且is_on[right]为true。查询
(p, right)区间,即query(right-1) - query(p)。如果结果为0,找到了答案,返回day。
检查完毕后,更新状态:
is_on[p] = true,并add(p, 1)。如果循环结束都没找到,返回 -1。
这个解法,每天的操作主要是两次查询和一次更新,都是 O(logN) 的。总复杂度 O(N logN)。
public class KEmptySlots {
private int[] tree;
private int n;
// 树状数组模板...
private int lowbit(int x) { return x & (-x); }
private void add(int index, int val) { /* ... */ }
private int query(int index) { /* ... */ }
public int kEmptySlots(int[] bulbs, int k) {
this.n = bulbs.length;
if (n == 0 || k < 0) {
return -1;
}
this.tree = new int[n + 1];
boolean[] is_on = new boolean[n + 1];
for (int day = 0; day < n; day++) {
int pos = bulbs[day];
// 检查左边
int left = pos - k - 1;
if (left >= 1 && is_on[left]) {
// 查询 (left, pos) 区间
if (query(pos - 1) - query(left) == 0) {
return day + 1;
}
}
// 检查右边
int right = pos + k + 1;
if (right <= n && is_on[right]) {
// 查询 (pos, right) 区间
if (query(right - 1) - query(pos) == 0) {
return day + 1;
}
}
// 更新状态
is_on[pos] = true;
add(pos, 1);
}
return -1;
}
// 完整的 add 和 query 方法
private void add(int index, int val) {
while (index <= n) {
tree[index] += val;
index += lowbit(index);
}
}
private int query(int index) {
int sum = 0;
while (index > 0) {
sum += tree[index];
index -= lowbit(index);
}
return sum;
}
}
其实这个题还有个更巧妙的 O(N) 解法,利用滑动窗口。思路是反过来,先建一个 days 数组,days[p] 表示位置 p 的灯是在第几天亮的。问题就变成了找一个 i,使得 max(days[i], days[i+k+1]) 最小,并且对于所有 j 满足 i < j < i+k+1 都有 days[j] > max(days[i], days[i+k+1])。这是一个可以用滑动窗口解决的问题。但既然我们今天的主题是树状数组,O(N logN) 的解法已经相当不错了。
二维区域和检索 - 矩阵可修改
问题描述:
给你一个 2D 矩阵
matrix,请实现一个NumMatrix类,支持以下两种操作:
update(row, col, val):将matrix[row][col]的值更新为val。
sumRegion(row1, col1, row2, col2):返回矩阵中左上角为(row1, col1)、右下角为(row2, col2)的子矩阵的元素总和。示例:
NumMatrix numMatrix = new NumMatrix([[3, 0, 1, 4, 2], [5, 6, 3, 2, 1], [1, 2, 0, 1, 5], [4, 1, 0, 1, 7], [1, 0, 3, 0, 5]]); numMatrix.sumRegion(2, 1, 4, 3); // 返回 8 numMatrix.update(3, 2, 2); numMatrix.sumRegion(2, 1, 4, 3); // 返回 10
思路一:朴素解法们
直接用二维数组:
update是 O(1),sumRegion是 O(R*C),其中 R,C 是子矩阵的长宽。太慢。二维前缀和:
sumRegion可以用容斥原理做到 O(1)。但是update(r, c, val)会导致从(r, c)开始到右下角的所有前缀和都失效,更新成本是 O(M*N)。还是很慢。
我们需要一个能在 update 和 sumRegion 之间取得平衡的结构。
思路二:二维树状数组
一维的树状数组可以处理一维序列的单点更新和前缀和查询。那二维的自然就想到用二维树状数组。
怎么理解二维树状数组?很简单,就是“树状数组的树状数组”。 我们可以把 tree[i] 不看作一个数字,而是看作另一个一维的树状数组。 tree[i] 管理的是第 i 行的某个信息聚合,而tree[i] 本身这个树状数组,tree[i][j] 管理的是第 i 行第 j 列的某个信息聚合。
操作定义:
add(x, y, delta):给matrix[x][y]增加delta。对于
i从x到M(每次i += lowbit(i)):对于
j从y到N(每次j += lowbit(j)):tree[i][j] += delta
query(x, y):查询从(1,1)到(x,y)的矩阵和。sum = 0对于
i从x到 1 (每次i -= lowbit(i)):对于
j从y到 1 (每次j -= lowbit(j)):sum += tree[i][j]
return sum
update 操作的 val 是新值,而树状数组的 add 方法是增加一个差值 delta。所以我们还需要一个原始矩阵的拷贝来计算 delta = val - matrix[row][col]。
NumMatrix(matrix)构造函数:保存
matrix的一份拷贝。初始化一个
M+1行N+1列的二维树状数组tree。遍历
matrix,对每个matrix[i][j]调用add(i+1, j+1, matrix[i][j])来构建tree。
update(row, col, val):计算差值
delta = val - matrix_copy[row][col]。更新
matrix_copy[row][col] = val。调用
add(row+1, col+1, delta)。
sumRegion(r1, c1, r2, c2): 利用query和容斥原理:sum = query(r2+1, c2+1) - query(r1, c2+1) - query(r2+1, c1) + query(r1, c1)。 注意下标转换,我们的树状数组是从1开始的。
时间复杂度:update 和 query 都是 O(logM logN)。 空间复杂度:O(MN)。
class NumMatrix {
private int[][] tree;
private int[][] matrix;
private int rows;
private int cols;
public NumMatrix(int[][] matrix) {
if (matrix == null || matrix.length == 0 || matrix[0].length == 0) {
return;
}
this.rows = matrix.length;
this.cols = matrix[0].length;
this.matrix = new int[rows][cols];
this.tree = new int[rows + 1][cols + 1];
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
update(i, j, matrix[i][j]);
}
}
}
public void update(int row, int col, int val) {
int delta = val - this.matrix[row][col];
this.matrix[row][col] = val;
// 树状数组下标从1开始
for (int i = row + 1; i <= rows; i += i & -i) {
for (int j = col + 1; j <= cols; j += j & -j) {
tree[i][j] += delta;
}
}
}
// query返回 (0,0) 到 (row-1, col-1) 的和
private int query(int row, int col) {
int sum = 0;
for (int i = row; i > 0; i -= i & -i) {
for (int j = col; j > 0; j -= j & -j) {
sum += tree[i][j];
}
}
return sum;
}
public int sumRegion(int row1, int col1, int row2, int col2) {
// 利用容斥原理
// query的参数是树状数组的下标,比原始矩阵下标大1
return query(row2 + 1, col2 + 1)
- query(row1, col2 + 1)
- query(row2 + 1, col1)
+ query(row1, col1);
}
}