堆(优先队列)

MrHe··4 min read

写了这么久代码,我发现很多数据结构本质上都是为了解决特定场景下的效率问题。数组增删慢、查询快;链表增删快、查询慢。那如果我有个需求:动态地、快速地找到一群数据里的最大值或最小值,该用啥?

你可能会说,用个数组存着,每次找最大值就遍历一遍,O(N) 嘛。可以,但如果这个操作要进行一万次呢?效率太低了。如果我每次插入新数据后都排序呢?那插入一个数就得 O(N*logN),更慢。

这时候,堆(Heap)就该登场了。它就是为了这个场景而生的。

一、掰开揉碎,啥是堆?

说白了,堆就是一个特殊的树。但你别被“树”吓到,它在物理存储上,通常就是一个数组。我们只是在逻辑上把它看成一棵完全二叉树

它有两个核心性质:

  1. 结构性:它是一棵完全二叉-树。这意味着除了最后一层,其他层都是满的,并且最后一层的节点都尽量靠左排列。这个性质保证了我们可以用数组来高效地表示它,没有空间浪费。

  2. 堆序性:所有父节点的值都必须 大于等于 (或 小于等于) 其所有子节点的值。

    • 父节点总比子节点大的,叫大根堆 (Max-Heap)。堆顶是整个堆的最大值。

    • 父节点总比子节点小的,叫小根堆 (Min-Heap)。堆顶是整个堆的最小值。

用数组怎么表示一棵完全二-叉树?

这个是关键。假设数组下标从 0 开始,对于任意一个位置 i 的节点:

  • 它的父节点下标是:(i - 1) / 2

  • 它的左子节点下标是:2 * i + 1

  • 它的右子节点下标是:2 * i + 2

你看,不需要指针,不需要复杂的节点对象,一个简单的数组就能把这棵树的关系表达得明明白白。这也是堆效率高的一个原因。

堆的核心操作:上浮(heapInsert)和下沉(heapify)

一个新元素要加入堆怎么办? 很简单,两步走:

  1. 先把新元素放到数组的最后面。

  2. 然后让它和自己的父节点比较,如果它比父节点“优秀”(大根堆里它更大,小根堆里它更小),就和父节点交换。一路“上浮”,直到它不再比父节点优秀,或者自己已经到了堆顶。这个过程叫 heapInsert

那如果我把堆顶元素拿走了(比如取最大值),这个结构怎么维持? 也简单,也分两步:

  1. 把数组最后一个元素挪到堆顶的位置。

  2. 现在堆顶可能不满足堆序性了。让它和它的两个子节点中更“优秀”的那个比较,如果它不如那个子节点优秀,就交换。一路“下沉”,直到它的所有子节点都比它“弱”,或者它已经没有子节点了。这个过程叫 heapify

Java 里,PriorityQueue 就是堆的官方实现,默认是小根堆。想用大根堆?传个比较器进去就行。

// 默认小根堆 PriorityQueue<Integer> minHeap = new PriorityQueue<>();

// 实现大根堆 PriorityQueue<Integer> maxHeap = new PriorityQueue<>((a, b) -> b - a);

理论说完了,不上题就是耍流氓。我们来看一道经典题,用三种思路把它搞定。


实战演练:数组中的第 K 个最大元素

LeetCode 215. 数组中的第 K 个最大元素 给定整数数组 nums 和整数 k,请返回数组中第 k 个最大的元素。

比如 nums = [3,2,1,5,6,4], k = 2,第二大的元素是 5,就返回 5。

解法一:暴力排序法——最无脑,也最直接

刚看到这题,脑子里第一个蹦出来的想法是什么?

找第 K 大?那我先把整个数组排个序,不就知道了吗?

思路:

  1. 对整个数组 nums 从小到大排序。

  2. 排序后,数组的最后一个元素是第 1 大的,倒数第二个是第 2 大的,那么第 k 大的元素,自然就在下标 nums.length - k 的位置上。

import java.util.Arrays;

class Solution1 {
    public int findKthLargest(int[] nums, int k) {
        if (nums == null || nums.length == 0 || k <= 0 || k > nums.length) {
            // 随便抛个异常,或者返回特定值,面试时要和面试官确认
            throw new IllegalArgumentException("Invalid input");
        }

        // 对整个数组排序
        Arrays.sort(nums);

        // 返回倒数第 k 个元素
        return nums[nums.length - k];
    }
}
  • 复杂度分析

    • 时间复杂度:O(N * logN),主要开销在排序。

    • 空间复杂度:O(logN),这是 Java Arrays.sort 内置排序(快速排序)的递归栈空间消耗。

这个解法虽然简单,但面试官肯定不满意。为啥?因为我们为了找区区一个第 K 大的数,把整个数组都给排好序了,杀鸡用牛刀,做了太多无用功。

解法二:堆/优先队列法——面试中的标准答案

既然排序太浪费,那我们能不能只维护一部分数据?这个思路就很自然地引向了堆。

我们要求第 K 大的元素。可以维护一个大小为 k 的数据容器,遍历整个数组,最终让这个容器里装着数组中最大的 k 个数。那么这 k 个数里最小的那个,不就是我们想要的答案吗?

“维护 K 个数中的最小值”,这不就是小根堆的拿手好戏吗?

思路:

  1. 搞一个容量为 k小根堆 PriorityQueue

  2. 遍历数组 nums 中的每个元素 num

  3. 如果堆的大小还不到 k,直接把 num 加进去。

  4. 如果堆的大小已经等于 k 了,就比较 num 和堆顶元素 heap.peek()

    • 如果 num 比堆顶元素还要小(或等于),说明它肯定不是前 K 大的,直接忽略。

    • 如果 num 比堆顶元素大,说明 num 有资格进入“前 K 大”的行列,而原来的堆顶元素(那 K 个数里最小的)就该被淘汰。于是,我们把堆顶 poll() 出来,再把 num 加进去。

  5. 遍历完整个数组后,堆里剩下的就是整个数组最大的 k 个数,而堆顶,就是这 k 个数里的最小值,也就是第 K 大的元素。

import java.util.PriorityQueue;

class Solution2 {
    public int findKthLargest(int[] nums, int k) {
        // 校验输入
        if (nums == null || nums.length == 0 || k <= 0 || k > nums.length) {
            throw new IllegalArgumentException("Invalid input");
        }

        // 创建一个大小为 k 的小根堆
        // Java 的 PriorityQueue 默认就是小根堆
        PriorityQueue<Integer> minHeap = new PriorityQueue<>(k);

        for (int num : nums) {
            if (minHeap.size() < k) {
                minHeap.add(num);
            } else if (num > minHeap.peek()) {
                // 当前元素比堆顶元素(k个最大数中的最小数)大
                // 弹出堆顶,加入当前元素
                minHeap.poll();
                minHeap.add(num);
            }
        }

        // 遍历结束后,堆顶就是第 k 大的元素
        return minHeap.peek();
    }
}
  • 复杂度分析

    • 时间复杂度:O(N * logK)。我们遍历了 N 个元素,每次操作(add, poll, peek)最多花费 O(logK) 的时间,因为堆的大小不会超过 k

    • 空间复杂度:O(K),堆需要存储 k 个元素。

这个解法比暴力排序好多了,尤其当 N 很大而 k 很小的时候。这通常是面试官期待你给出的答案。它精准地解决了问题,没有多余的操作。

解法三:快速选择算法——秀肌肉的时刻

如果面试官问:“还有没有更优的解法?比如时间复杂度能做到 O(N) 吗?”

这时候,如果你能答出快速选择(Quick Select)算法,那绝对是加分项。

这个算法是快速排序的“魔改版”。快排的精髓是 partition 操作:随机选一个 pivot,把数组分成三部分——< pivot 的、== pivot 的、> pivot 的。然后递归处理左右两边。

但我们找第 K 大的元素,需要递归两边吗?不需要!

思路:

  1. 我们要找的是排序后下标为 target = nums.length - k 的元素。

  2. 我们对数组进行一次 partition 操作,得到 pivot 最终停留的下标 p

  3. 比较 ptarget

    • 如果 p == target,恭喜,我们一步到位,nums[p] 就是答案。

    • 如果 p < target,说明我们要找的数在 p 的右边,我们只需要在右半部分继续找。

    • 如果 p > target,说明我们要找的数在 p 的左边,我们只需要在左半部分继续找。

  4. 每次都只处理一半的区域,规模不断减小,直到找到目标。

为了防止最坏情况(每次都选到最大或最小的 pivot,导致复杂度退化到 O(N²)),我们每次随机选一个 pivot。

import java.util.Random;

class Solution3 {
    public int findKthLargest(int[] nums, int k) {
        if (nums == null || nums.length == 0 || k <= 0 || k > nums.length) {
            throw new IllegalArgumentException("Invalid input");
        }
        // 目标索引
        int targetIndex = nums.length - k;
        return quickSelect(nums, 0, nums.length - 1, targetIndex);
    }

    private int quickSelect(int[] nums, int left, int right, int targetIndex) {
        while (left <= right) {
            // 随机选择 pivot,并交换到区域的最右边
            int pivotIndex = new Random().nextInt(right - left + 1) + left;
            swap(nums, pivotIndex, right);

            // partition 操作,返回 pivot 的最终位置
            int p = partition(nums, left, right);

            if (p == targetIndex) {
                return nums[p];
            } else if (p < targetIndex) {
                // 目标在右侧,缩小范围
                left = p + 1;
            } else { // p > targetIndex
                // 目标在左侧,缩小范围
                right = p - 1;
            }
        }
        return -1; // 理论上不会走到这里
    }

    // partition 过程,这里用荷兰国旗问题的思路
    private int partition(int[] nums, int left, int right) {
        int pivot = nums[right]; // 用区域最右边的数做 pivot
        int less = left - 1; // 小于区的右边界
        for (int i = left; i < right; i++) {
            if (nums[i] <= pivot) {
                less++;
                swap(nums, less, i);
            }
        }
        swap(nums, less + 1, right); // pivot 放到最终位置
        return less + 1;
    }

    private void swap(int[] nums, int i, int j) {
        int temp = nums[i];
        nums[i] = nums[j];
        nums[j] = temp;
    }
}
  • 复杂度分析

    • 时间复杂度:期望是 O(N)。因为每次 partition 都是 O(N) 的,但之后处理的数据规模期望减半,所以总的计算量是 N + N/2 + N/4 + ... ≈ 2N,也就是 O(N)。最坏情况(每次都选到最差的 pivot)是 O(N²),但通过随机化 pivot 可以极大概率避免。

    • 空间复杂度:O(1)。我们是原地修改数组,没有用额外的存储空间(如果递归写的话,有 O(logN) 的栈空间,我这里改成了迭代,就是 O(1))。

总结一下

从 O(NlogN) 排序,到更优的 O(NlogK) 堆,再到理论上最优的 O(N) 快速选择,一步步把解法优化到了极致。

  1. 暴力排序:简单粗暴,但性能差,通常只能作为垫底方案。

  2. :解决 Top K 问题的“万金油”,思路清晰,代码好写,性能也相当不错,是面试中最稳妥、最常见的解法。

  3. 快速选择:展现你算法功底的“杀手锏”,能想到并写出这个解法,说明你对分治思想和排序算法的理解非常深刻。

把堆这个结构彻底搞明白了,不仅是 PriorityQueue 的 API 调用,更是它底层的数组实现、上浮和下沉的操作。这样,面试遇到相关问题,就能信手拈来,游刃有余。