树状数组( Binary Indexed Tree)

MrHe··10 min read

树状数组(Fenwick Tree / Binary Indexed Tree)这个东西。这玩意儿听起来挺唬人,但本质上就是个用数组模拟树形结构的工具,专门用来解决一类问题:频繁的单点更新和区间查询。如果我们用普通数组,更新是 O(1),但查区间和就是 O(N);用前缀和数组,查区间和是 O(1),但你更新一个点,后面所有的前缀和都得改,又是 O(N)。两边都不讨好。树状数组就把这两个操作都干到了 O(logN),一下就平衡了。

废话不多说,咱们从题入手,看看这东西到底怎么用,思路是怎么一步步升级的。

计算右侧小于当前元素的个数

问题描述:

给定一个整数数组 nums,按要求返回一个新数组 counts。数组 counts 有该性质: counts[i] 的值是 nums[i] 右侧小于 nums[i] 的元素的数量。

示例:

输入: nums = [5, 2, 6, 1] 输出: [2, 1, 1, 0] 解释: 5 的右侧有 2 个更小的元素 (2 和 1)。 2 的右侧仅有 1 个更小的元素 (1)。 6 的右侧有 1 个更小的元素 (1)。 1 的右侧有 0 个更小的元素。


思路一:暴力,硬来

拿到这个题,最直观的想法是什么?就是老老实实按题目说的做。对于每个 nums[i],我直接写个循环,往它右边走,一个一个地数,看有多少个比它小。

// 暴力解法 O(N^2)
public List<Integer> countSmallerBruteForce(int[] nums) {
    List<Integer> counts = new ArrayList<>();
    for (int i = 0; i < nums.length; i++) {
        int count = 0;
        for (int j = i + 1; j < nums.length; j++) {
            if (nums[j] < nums[i]) {
                count++;
            }
        }
        counts.add(count);
    }
    return counts;
}

这代码逻辑清晰,简单粗暴。但时间复杂度呢?两层循环,妥妥的 O(N^2)。如果 nums 的长度是 10^5,这直接就超时了。瓶颈在哪?瓶颈在于对于每个 i,我们都在做一次重复的、低效的扫描。

思路二:换个角度,用空间换时间

O(N^2) 的瓶颈在于,当我们考察 nums[i] 时,我们对它右边的信息一无所知,只能傻傻地去遍历。那我们能不能换个角度?

我们从右往左遍历 nums 数组。当我们处理 nums[i] 时,它右边的所有数字我们都已经“见过”了。现在的问题就变成了:在所有我们“见过”的数字里,有多少个比 nums[i] 小?

这一下就把问题转化成了一个动态的查询问题:

  1. 来一个数,把它“记录”下来。

  2. 查询一下,在所有已记录的数里,比当前数小的有多少个。

这不就是树状数组的拿手好戏吗?

树状数组可以高效地维护一个序列的前缀和。我们可以把数字本身看作是序列的下标。

  • 当我们“见到”一个数 x,就执行 add(x, 1) 操作,表示 x 这个数出现了一次。

  • 当我们想查询比 y 小的数有多少个,就执行 query(y - 1) 操作,这会返回 [1, y-1] 这个区间所有数的出现次数之和。

但是,这里有个坑nums 里的数可能很大,也可能是负数,我们不能直接当数组下标。怎么办?离散化

离散化说白了就是搞个排名。比如 [100, -5, 50],这三个数我们不关心它们具体的值,只关心它们的相对大小。排个序就是 [-5, 50, 100],我们给它们映射成 1, 2, 3。这样就把原始值域映射到了一个紧凑的、从1开始的整数区间,完美适配树状数组的下标。

所以,整体流程就出来了:

  1. 离散化:把 nums 中所有不重复的数收集起来,排序,建立一个 值 -> 排名 的映射。

  2. 初始化:创建一个大小为 排名总数 + 1 的树状数组,初始全为0。

  3. 从右往左遍历

    • 对于 nums[i],先通过映射找到它的排名 rank

    • 在树状数组中查询 query(rank - 1),得到的结果就是 nums[i] 右侧比它小的数的个数。存入结果数组。

    • nums[i] 加入到“已见过”的集合中,也就是在树状数组中执行 add(rank, 1)

这个思路下,每次查询和更新都是 O(logM),其中 M 是不同数字的个数。总的遍历是 O(N),所以总时间复杂度是 O(N logM),比 O(N^2) 不知道高到哪里去了。

public class CountSmaller {

    private int[] tree;
    private int size;

    // 经典树状数组模板
    private int lowbit(int x) {
        return x & (-x);
    }

    private void add(int index, int val) {
        while (index <= size) {
            tree[index] += val;
            index += lowbit(index);
        }
    }

    private int query(int index) {
        int sum = 0;
        while (index > 0) {
            sum += tree[index];
            index -= lowbit(index);
        }
        return sum;
    }

    public List<Integer> countSmaller(int[] nums) {
        if (nums == null || nums.length == 0) {
            return new ArrayList<>();
        }

        // 1. 离散化
        Set<Integer> uniqueNums = new HashSet<>();
        for (int num : nums) {
            uniqueNums.add(num);
        }
        List<Integer> sortedUniqueNums = new ArrayList<>(uniqueNums);
        Collections.sort(sortedUniqueNums);

        Map<Integer, Integer> rankMap = new HashMap<>();
        int rank = 1;
        for (Integer num : sortedUniqueNums) {
            rankMap.put(num, rank++);
        }

        // 2. 初始化树状数组
        this.size = uniqueNums.size();
        this.tree = new int[size + 1];

        // 3. 从右往左遍历
        Integer[] result = new Integer[nums.length];
        for (int i = nums.length - 1; i >= 0; i--) {
            int currentRank = rankMap.get(nums[i]);
            // 查询比当前数排名小的数的个数
            result[i] = query(currentRank - 1);
            // 将当前数加入树状数组
            add(currentRank, 1);
        }

        return Arrays.asList(result);
    }
}

区间和的个数

问题描述:

给定一个整数数组 nums 和两个整数 lowerupper,返回区间和在 [lower, upper] 范围内的个数。

区间和 S(i, j) 表示 nums 数组中从索引 ij(包含 ij)的元素的总和。

示例:

输入: nums = [-2, 5, -1], lower = -2, upper = 2 输出: 3 解释: 存在 3 个区间和满足 [lower, upper] 的范围: [0,0], [2,2], [0,2]。它们对应的区间和分别为: -2, -1, 2


思路一:又是暴力

这个题求区间和,第一反应就是前缀和。我们先预处理一个前缀和数组 preSum,其中 preSum[i] 表示 nums[0...i-1] 的和。那么区间 [i, j] 的和就可以用 preSum[j+1] - preSum[i] 来 O(1) 计算。

然后,还是老套路,两层循环枚举所有的区间 [i, j],计算区间和,判断是否在 [lower, upper] 范围内。

// 暴力解法 O(N^2)
public int countRangeSumBruteForce(int[] nums, int lower, int upper) {
    int n = nums.length;
    long[] preSum = new long[n + 1];
    for (int i = 0; i < n; i++) {
        preSum[i + 1] = preSum[i] + nums[i];
    }

    int count = 0;
    for (int i = 0; i < n; i++) {
        for (int j = i; j < n; j++) {
            long sum = preSum[j + 1] - preSum[i];
            if (sum >= lower && sum <= upper) {
                count++;
            }
        }
    }
    return count;
}

逻辑没问题,但复杂度还是 O(N^2),肯定不是面试官想要的答案。

思路二:转换问题,再次请出树状数组

我们来分析一下 O(N^2) 的瓶颈。对于每个 j,我们都在遍历 i < j,然后计算 preSum[j+1] - preSum[i]

我们把前缀和数组设为 ss[i] 表示 nums[0..i-1] 的和。 那么区间和 S(i, j) 就是 s[j+1] - s[i]。 我们要找的是 lower <= s[j+1] - s[i] <= upper

对这个不等式变形,我们固定 j,要找满足条件的 i 的个数: s[j+1] - upper <= s[i] <= s[j+1] - lower

这下问题又变了!当我们遍历到 j 时,我们需要做的是:在 j 之前出现过的所有前缀和 s[0], s[1], ..., s[j] 中,有多少个落在了 [s[j+1] - upper, s[j+1] - lower] 这个区间里?

这不又回到了我们熟悉的路子上了吗?

  1. 遍历前缀和数组 s

  2. 对于当前的前缀和 s_current,我们需要查询在它之前出现过的前缀和中,有多少个值在 [s_current - upper, s_current - lower] 这个范围里。

  3. 然后把 s_current 也加入到“已见过”的集合里。

查询一个范围内的数的个数,可以用树状数组 query(end) - query(start - 1) 来实现。

同样,前缀和的值域可能非常大,而且不连续,所以离散化再次登场。这次需要离散化的值有哪些? 对于每一个前缀和 s_k,我们都关心 s_ks_k - lowers_k - upper 这三个值。所以,我们把所有可能出现的前缀和,以及由它们衍生出的查询边界 s-lowers-upper,全部收集起来,进行离散化。

流程:

  1. 计算前缀和数组 preSum

  2. 离散化:收集所有的 preSum[k], preSum[k]-lower, preSum[k]-upper,排序去重,建立 值 -> 排名 映射。

  3. 初始化树状数组。

  4. 遍历前缀和数组 preSum(从 preSum[0] 开始,它代表空前缀,值为0):

    • 对于 preSum[i]

      • 计算查询区间的上下界: L = preSum[i] - upperR = preSum[i] - lower

      • 找到 LR 对应的排名 rank_Lrank_R

      • 在树状数组中查询 query(rank_R) - query(rank_L - 1),累加到最终结果 count 中。

      • 找到 preSum[i] 自己的排名 rank_i,在树状数组中执行 add(rank_i, 1)

这里要注意,我们是先查询,再添加。因为题目要求 i <= j,当我们处理 preSum[j+1] 时,我们要找的是 preSum[i] (i <= j)。所以在处理 preSum[k] 时,我们找的是已经处理过的 preSum[0]...preSum[k-1],这就保证了 ij 的前面。

public class CountRangeSum {

    private int[] tree;
    private int size;

    // 树状数组模板...
    private int lowbit(int x) { return x & (-x); }
    private void add(int index, int val) { /* ... */ }
    private int query(int index) { /* ... */ }


    public int countRangeSum(int[] nums, int lower, int upper) {
        int n = nums.length;
        long[] preSum = new long[n + 1];
        for (int i = 0; i < n; i++) {
            preSum[i + 1] = preSum[i] + nums[i];
        }

        // 1. 离散化
        Set<Long> allNumbers = new HashSet<>();
        for (long s : preSum) {
            allNumbers.add(s);
            allNumbers.add(s - lower);
            allNumbers.add(s - upper);
        }
        List<Long> sortedNumbers = new ArrayList<>(allNumbers);
        Collections.sort(sortedNumbers);
        Map<Long, Integer> rankMap = new HashMap<>();
        int rank = 1;
        for (Long num : sortedNumbers) {
            rankMap.put(num, rank++);
        }

        // 2. 初始化树状数组
        this.size = allNumbers.size();
        this.tree = new int[size + 1];

        // 3. 遍历前缀和
        int count = 0;
        // 初始时,有一个前缀和0存在
        add(rankMap.get(0L), 1); 

        for (int i = 1; i <= n; i++) {
            long currentSum = preSum[i];
            long L = currentSum - upper;
            long R = currentSum - lower;

            // 查询 [L, R] 范围内的前缀和个数
            // 注意这里rankmap可能不包含L-1,需要找小于L的最大值的rank
            // 一个简单的方式是直接求 query(rank_R) - query(rank_L - 1)
            Integer rankR = rankMap.get(R);
            Integer rankL = rankMap.get(L);

            // query(end) - query(start - 1)
            count += query(rankR) - query(rankL - 1);

            // 把当前的前缀和加入
            add(rankMap.get(currentSum), 1);
        }
        return count;
    }

    // 完整的 add 和 query 方法
    private void add(int index, int val) {
        while (index <= size) {
            tree[index] += val;
            index += lowbit(index);
        }
    }

    private int query(int index) {
        int sum = 0;
        // rankMap.get()可能返回null,如果查询的边界值本身不在离散化集合里,
        // 比如L-1,我们需要找到不大于L-1的最大值的排名,这个处理比较复杂。
        // 一个更严谨(但代码更复杂)的做法是二分查找排名。
        // 但对于本题,因为所有相关的边界都已加入离-散化集合,可以直接用。
        if (index < 1) return 0;
        while (index > 0) {
            sum += tree[index];
            index -= lowbit(index);
        }
        return sum;
    }
}

注:上面代码 query(rankL-1) 是一个简化写法,如果 L 是离散化后的最小值,L-1 就不在map里。严谨的实现需要二分查找 LsortedNumbers 中的位置。但由于我们把所有 s, s-lower, s-upper 都加入了,这里的 LR 一定在 rankMap 中,所以可以直接用。 query(rankL - 1) 正确表达了小于 rankL 的所有排名的和。

这个解法的时间复杂度是 O(N logN),因为离散化排序是 O(N logN),遍历和树状数组操作也是 O(N logN)。空间复杂度是 O(N) 用来存各种值和树状数组。完美解决。


k 个关闭的灯泡

问题描述:

有 n 个灯泡,从 1 到 n 编号。最初所有灯泡都是关闭的。

每天,我们会打开一个指定的灯泡。给你一个数组 flowers,其中 flowers[i] = x 表示在第 i+1 天,我们会打开位置为 x 的灯泡。

请返回在哪一天,存在两个相邻的已打开的灯泡,且它们之间正好有 k 个未打开的灯泡。

如果不存在这样的一天,返回 -1。

示例 1:

输入: flowers = [1,3,2], k = 1 输出: 2 解释: 第 1 天,打开位置 1 的灯泡 [1,0,0]。 第 2 天,打开位置 3 的灯泡 [1,0,1]。 第 3 天,打开位置 2 的灯泡 [1,1,1]。在第 2 天,灯泡 1 和 3 之间有 1 个关闭的灯泡。

示例 2:

输入: flowers = [1,2,3], k = 1 输出: -1


思路一:模拟每一天,暴力检查

这个题直接模拟也行。我们用一个数组 status 来表示灯泡的状态,0表示关,1表示开。 每天打开一个灯泡后,就从头到尾扫一遍 status 数组,找有没有 status[i]=1, status[i+k+1]=1 的情况。

这个思路太暴力了,每天都要 O(N) 检查,总共 N 天,又是 O(N^2)。pass。

思路二:树状数组优化检查过程

我们分析一下,当第 i 天,位置为 p = flowers[i] 的灯泡打开时,我们需要检查什么? 我们需要看 p 的两个“邻居”:位置在 p - k - 1p + k + 1 的灯泡。

  1. 如果 p - k - 1 这个位置的灯泡是亮的,那我们就要看 (p-k-1, p) 这个区间内是不是空的(没有其他亮灯)。

  2. 如果 p + k + 1 这个位置的灯泡是亮的,那我们就要看 (p, p+k+1) 这个区间内是不是空的。

这就又转化为区间查询问题了! 我们可以用一个树状数组来维护亮灯的位置add(pos, 1) 表示 pos 位置的灯亮了。 query(end) - query(start) 就能得到 (start, end] 区间内亮灯的数量。

流程:

  1. 创建一个长度为 n+1 的树状数组 tree 和一个布尔数组 is_on,记录每个位置的灯是否亮。

  2. 按天数遍历 flowers 数组,从 day = 1n

  3. 在第 day 天,灯泡 p = flowers[day-1] 打开。

  4. 检查左邻居 left = p - k - 1

    • 如果 left >= 1 并且 is_on[left]true

    • 查询 (left, p) 区间,即 query(p-1) - query(left)。如果结果为0,说明中间没有灯,找到了答案,返回 day

  5. 检查右邻居 right = p + k + 1

    • 如果 right <= n 并且 is_on[right]true

    • 查询 (p, right) 区间,即 query(right-1) - query(p)。如果结果为0,找到了答案,返回 day

  6. 检查完毕后,更新状态:is_on[p] = true,并 add(p, 1)

  7. 如果循环结束都没找到,返回 -1。

这个解法,每天的操作主要是两次查询和一次更新,都是 O(logN) 的。总复杂度 O(N logN)。

public class KEmptySlots {

    private int[] tree;
    private int n;

    // 树状数组模板...
    private int lowbit(int x) { return x & (-x); }
    private void add(int index, int val) { /* ... */ }
    private int query(int index) { /* ... */ }

    public int kEmptySlots(int[] bulbs, int k) {
        this.n = bulbs.length;
        if (n == 0 || k < 0) {
            return -1;
        }
        this.tree = new int[n + 1];
        boolean[] is_on = new boolean[n + 1];

        for (int day = 0; day < n; day++) {
            int pos = bulbs[day];

            // 检查左边
            int left = pos - k - 1;
            if (left >= 1 && is_on[left]) {
                // 查询 (left, pos) 区间
                if (query(pos - 1) - query(left) == 0) {
                    return day + 1;
                }
            }

            // 检查右边
            int right = pos + k + 1;
            if (right <= n && is_on[right]) {
                // 查询 (pos, right) 区间
                if (query(right - 1) - query(pos) == 0) {
                    return day + 1;
                }
            }

            // 更新状态
            is_on[pos] = true;
            add(pos, 1);
        }

        return -1;
    }

    // 完整的 add 和 query 方法
    private void add(int index, int val) {
        while (index <= n) {
            tree[index] += val;
            index += lowbit(index);
        }
    }

    private int query(int index) {
        int sum = 0;
        while (index > 0) {
            sum += tree[index];
            index -= lowbit(index);
        }
        return sum;
    }
}

其实这个题还有个更巧妙的 O(N) 解法,利用滑动窗口。思路是反过来,先建一个 days 数组,days[p] 表示位置 p 的灯是在第几天亮的。问题就变成了找一个 i,使得 max(days[i], days[i+k+1]) 最小,并且对于所有 j 满足 i < j < i+k+1 都有 days[j] > max(days[i], days[i+k+1])。这是一个可以用滑动窗口解决的问题。但既然我们今天的主题是树状数组,O(N logN) 的解法已经相当不错了。


二维区域和检索 - 矩阵可修改

问题描述:

给你一个 2D 矩阵 matrix,请实现一个 NumMatrix 类,支持以下两种操作:

  1. update(row, col, val):将 matrix[row][col] 的值更新为 val

  2. sumRegion(row1, col1, row2, col2):返回矩阵中左上角为 (row1, col1)、右下角为 (row2, col2) 的子矩阵的元素总和。

示例:

NumMatrix numMatrix = new NumMatrix([[3, 0, 1, 4, 2], [5, 6, 3, 2, 1], [1, 2, 0, 1, 5], [4, 1, 0, 1, 7], [1, 0, 3, 0, 5]]);
numMatrix.sumRegion(2, 1, 4, 3); // 返回 8
numMatrix.update(3, 2, 2);
numMatrix.sumRegion(2, 1, 4, 3); // 返回 10

思路一:朴素解法们

  1. 直接用二维数组update 是 O(1),sumRegion 是 O(R*C),其中 R,C 是子矩阵的长宽。太慢。

  2. 二维前缀和sumRegion 可以用容斥原理做到 O(1)。但是 update(r, c, val) 会导致从 (r, c) 开始到右下角的所有前缀和都失效,更新成本是 O(M*N)。还是很慢。

我们需要一个能在 updatesumRegion 之间取得平衡的结构。

思路二:二维树状数组

一维的树状数组可以处理一维序列的单点更新和前缀和查询。那二维的自然就想到用二维树状数组。

怎么理解二维树状数组?很简单,就是“树状数组的树状数组”。 我们可以把 tree[i] 不看作一个数字,而是看作另一个一维的树状数组。 tree[i] 管理的是第 i 行的某个信息聚合,而tree[i] 本身这个树状数组,tree[i][j] 管理的是第 i 行第 j 列的某个信息聚合。

操作定义:

  • add(x, y, delta):给 matrix[x][y] 增加 delta

    • 对于 ixM (每次 i += lowbit(i)):

      • 对于 jyN (每次 j += lowbit(j)):

        • tree[i][j] += delta
  • query(x, y):查询从 (1,1)(x,y) 的矩阵和。

    • sum = 0

    • 对于 ix 到 1 (每次 i -= lowbit(i)):

      • 对于 jy 到 1 (每次 j -= lowbit(j)):

        • sum += tree[i][j]
    • return sum

update 操作的 val 是新值,而树状数组的 add 方法是增加一个差值 delta。所以我们还需要一个原始矩阵的拷贝来计算 delta = val - matrix[row][col]

  • NumMatrix(matrix) 构造函数:

    1. 保存 matrix 的一份拷贝。

    2. 初始化一个 M+1N+1 列的二维树状数组 tree

    3. 遍历 matrix,对每个 matrix[i][j] 调用 add(i+1, j+1, matrix[i][j]) 来构建 tree

  • update(row, col, val)

    1. 计算差值 delta = val - matrix_copy[row][col]

    2. 更新 matrix_copy[row][col] = val

    3. 调用 add(row+1, col+1, delta)

  • sumRegion(r1, c1, r2, c2): 利用 query 和容斥原理: sum = query(r2+1, c2+1) - query(r1, c2+1) - query(r2+1, c1) + query(r1, c1)。 注意下标转换,我们的树状数组是从1开始的。

时间复杂度:updatequery 都是 O(logM logN)。 空间复杂度:O(MN)。

class NumMatrix {

    private int[][] tree;
    private int[][] matrix;
    private int rows;
    private int cols;

    public NumMatrix(int[][] matrix) {
        if (matrix == null || matrix.length == 0 || matrix[0].length == 0) {
            return;
        }
        this.rows = matrix.length;
        this.cols = matrix[0].length;
        this.matrix = new int[rows][cols];
        this.tree = new int[rows + 1][cols + 1];

        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                update(i, j, matrix[i][j]);
            }
        }
    }

    public void update(int row, int col, int val) {
        int delta = val - this.matrix[row][col];
        this.matrix[row][col] = val;
        // 树状数组下标从1开始
        for (int i = row + 1; i <= rows; i += i & -i) {
            for (int j = col + 1; j <= cols; j += j & -j) {
                tree[i][j] += delta;
            }
        }
    }

    // query返回 (0,0) 到 (row-1, col-1) 的和
    private int query(int row, int col) {
        int sum = 0;
        for (int i = row; i > 0; i -= i & -i) {
            for (int j = col; j > 0; j -= j & -j) {
                sum += tree[i][j];
            }
        }
        return sum;
    }

    public int sumRegion(int row1, int col1, int row2, int col2) {
        // 利用容斥原理
        // query的参数是树状数组的下标,比原始矩阵下标大1
        return query(row2 + 1, col2 + 1) 
             - query(row1, col2 + 1) 
             - query(row2 + 1, col1) 
             + query(row1, col1);
    }
}