优先队列PriorityQueue

MrHe··14 min read

很多时候,我们处理数据不只是简单地存进去、取出来,而是希望每次取出的都是当前这批数据里"最"牛的那个,比如最大、最小、最重要等等。普通队列是先进先出,栈是后进先出,都满足不了这个需求。这时候,优先队列就登场了。

你可以把它想象成一个VIP候机室,不管你什么时候来的,只要你的"优先级"最高(比如是头等舱、白金会员),你就能最先登机。在算法里,这个"优先级"通常就是元素的大小。

Java里 PriorityQueue 就是它的标准实现,底层是一棵二叉堆(小顶堆)。这意味着,它能用 O(logN) 的时间复杂度完成插入和删除堆顶元素的操作,用 O(1) 的时间看到堆顶是谁。这个性质,是解决很多问题的关键。

第 9 章 优先队列

9.2 优先队列的实现

9.2.1 排序数组

在聊 PriorityQueue 的底层实现(堆)之前,咱们先用一个最朴素的思路想一想,如果让你自己实现一个优先队列,你会怎么做?

最无脑的方法,就是用一个数组。每次 add 一个新元素,我就把它加到数组末尾。每次要 peek (看优先级最高的)或者 poll (取出优先级最高的),我就得遍历整个数组,找到那个最值。这个方法,add 是 O(1),但 pollpeek 都是 O(N),太慢了,尤其是在需要频繁取最值的场景。

稍微聪明一点,我可以用一个有序数组。为了保持数组有序,每次 add 新元素,我得先找到它该插入的位置,然后把后面的元素都往后挪一位。这个 add 操作,因为要查找和移动,时间复杂度是 O(N)。好处是 peekpoll (取数组末尾或开头的元素)就变成了 O(1) 或者 O(N)(如果需要移动元素来填补空位)。

总结一下,用数组来模拟,总有一头会是 O(N),性能瓶颈很明显。而堆结构,能把 addpoll 都维持在 O(logN),这才是我们想要的平衡。

9.2.2 数组中的第 k 个最大元素

问题描述

给定整数数组 nums 和整数 k,请返回数组中第 k 个最大的元素。

请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。

示例 1: 输入: [3,2,1,5,6,4]k = 2 输出: 5

示例 2: 输入: [3,2,3,1,2,4,5,5,6]k = 4 输出: 4


思路一:排序大法

看到"第k大",最直观的想法就是,我先把整个数组排个序,那不就什么都清楚了吗?

比如 [3,2,1,5,6,4],排个序变成 [1,2,3,4,5,6]。要找第2大的元素,那就是从后往前数第2个,也就是 5。简单粗暴,逻辑清晰。

import java.util.Arrays;

class Solution1 {
    public int findKthLargest(int[] nums, int k) {
        if (nums == null || nums.length == 0 || k <= 0 || k > nums.length) {
            // 随便抛个异常,或者返回特定值,看题目要求
            throw new IllegalArgumentException("Invalid input");
        }
        Arrays.sort(nums);
        // 数组长度为 n,排序后下标是 0 到 n-1
        // 第1大是 nums[n-1]
        // 第2大是 nums[n-2]
        // ...
        // 第k大是 nums[n-k]
        return nums[nums.length - k];
    }
}

这个解法的时间复杂度主要来自排序,所以是 O(N log N),空间复杂度是 O(1) 或 O(log N),取决于排序算法的实现。面试时这么写,肯定没问题,但面试官大概率会追问:"还有没有更优的解法?"

思路二:优先队列(小顶堆)

题目要我们找"第k大"的元素。咱们换个角度想,这意味着有 k-1 个元素比它大,剩下的 N-k 个元素都比它小或等于它。

我们可以维护一个大小为 k 的集合,里面存放着到目前为止我们见过的最大的 k 个数。当我们遍历完整个数组,这个集合里最小的那个数,不就是整个数组的第 k 大元素吗?

举个例子,[3,2,1,5,6,4]k=2。 我要维护一个大小为2的集合。

  1. 先看 3,集合不满,加进去。集合:{3}

  2. 再看 2,集合不满,加进去。集合:{3, 2}

  3. 再看 1,集合满了,1 比集合里最小的 2 还小,不理它。集合:{3, 2}

  4. 再看 5,集合满了,5 比集合里最小的 2 大,把 2 踢出去,5 加进来。集合:{3, 5}

  5. 再看 6,集合满了,6 比集合里最小的 3 大,把 3 踢出去,6 加进来。集合:{5, 6}

  6. 再看 4,集合满了,4 比集合里最小的 5 小,不理它。集合:{5, 6}

遍历结束,集合里是 {5, 6},最小的是 5。所以答案就是 5

这个"维护大小为k的集合,并能快速找到最小值踢出去"的操作,不就是小顶堆的完美应用场景吗?

所以,我们的思路来了:

  1. 创建一个大小为 k 的小顶堆。

  2. 遍历数组,先把前 k 个元素加到堆里。

  3. 从第 k+1 个元素开始,如果当前元素比堆顶(也就是 k 个元素里的最小值)要大,那就把堆顶弹出来,把当前元素加进去。

  4. 遍历完整个数组后,堆顶的那个元素,就是我们想要的答案。

import java.util.PriorityQueue;

class Solution2 {
    public int findKthLargest(int[] nums, int k) {
        // Java的PriorityQueue默认就是小顶堆
        PriorityQueue<Integer> minHeap = new PriorityQueue<>(k);

        for (int num : nums) {
            if (minHeap.size() < k) {
                minHeap.add(num);
            } else if (num > minHeap.peek()) {
                minHeap.poll(); // 弹出堆顶(最小的)
                minHeap.add(num); // 加入更大的
            }
        }

        return minHeap.peek();
    }
}

这个解法的时间复杂度是 O(N log K),因为我们遍历了 N 个元素,每次对堆的操作是 log K。空间复杂度是 O(K),用来存堆里的元素。当 K 远小于 N 时,这个解法比 O(N log N) 要快。

9.2.3 相对名次

问题描述

给你一个长度为 n 的整数数组 score,其中 score[i] 是第 i 位运动员在比赛中的得分。所有得分都 互不相同

运动员将根据得分 从高到低 进行排名。排名第 1 的运动员将获得 "Gold Medal" ,排名第 2 的运动员将获得 "Silver Medal" ,排名第 3 的运动员将获得 "Bronze Medal" 。从排名第 4 到第 n 的运动员,只能获得他们的排名编号(譬如,排名第 x 的运动员获得编号 "x")。

使用长度为 n 的数组 answer 返回获奖情况,其中 answer[i] 是第 i 位运动员的获奖情况。

示例 1: 输入:score = [5,4,3,2,1] 输出:["Gold Medal","Silver Medal","Bronze Medal","4","5"]

示例 2: 输入:score = [10,3,8,9,4] 输出:["Gold Medal","5","Bronze Medal","Silver Medal","4"]


思路一:排序 + 哈希表

这个问题,本质上就是要给每个原始分数找到它在排序后的名次。但麻烦的是,最后的结果要按照原始输入的顺序来排列。

这就意味着,我既需要分数排序后的信息,又不能丢失分数和它原始位置的对应关系。

一个直接的想法是,用一个哈希表(或者一个二维数组)来存 (分数, 原始下标) 这个映射。然后我对这个结构按照分数进行降序排序。排序后,我就可以依次给它们赋予名次了。

  1. 创建一个二维数组 pairspairs[i][0] 存分数,pairs[i][1] 存原始下标。

  2. 遍历 score 数组,填充 pairs

  3. pairs 按照分数(第一维)降序排序。

  4. 创建一个结果数组 answer

  5. 遍历排序后的 pairs,根据名次 (i+1),在 answer 数组的正确位置(pairs[i][1])填入相应的奖牌或名次字符串。

import java.util.Arrays;
import java.util.Comparator;

class Solution1 {
    public String[] findRelativeRanks(int[] score) {
        int n = score.length;
        int[][] pairs = new int[n][2];
        for (int i = 0; i < n; i++) {
            pairs[i][0] = score[i]; // 分数
            pairs[i][1] = i;        // 原始索引
        }

        // 按分数降序排序
        Arrays.sort(pairs, (a, b) -> b[0] - a[0]);

        String[] answer = new String[n];
        for (int i = 0; i < n; i++) {
            int originalIndex = pairs[i][1];
            if (i == 0) {
                answer[originalIndex] = "Gold Medal";
            } else if (i == 1) {
                answer[originalIndex] = "Silver Medal";
            } else if (i == 2) {
                answer[originalIndex] = "Bronze Medal";
            } else {
                answer[originalIndex] = String.valueOf(i + 1);
            }
        }
        return answer;
    }
}

这个方法的时间复杂度是 O(N log N)(排序),空间复杂度是 O(N)(pairs 数组)。

思路二:优先队列(大顶堆)

思路一的本质是排序。凡是涉及到排序找极值的,都可以想想优先队列。

这个题要求从高到低排名,那我就用一个大顶堆。同样,为了不丢失原始位置信息,我堆里存的不能只是分数,而应该是 (分数, 原始下标) 对。

  1. 创建一个大顶堆,比较器按照分数来比较。

  2. 遍历 score 数组,把 (score[i], i) 这样的数对全部加入大顶堆。

  3. 创建一个结果数组 answer

  4. 循环 n 次,每次从大顶堆里 poll 出一个元素。这个元素就是当前未排名里分数最高的。

  5. 第一次取出的就是第一名,第二次就是第二名,以此类推。根据名次和取出的元素的原始下标,填充 answer 数组。

import java.util.PriorityQueue;

class Solution2 {
    public String[] findRelativeRanks(int[] score) {
        int n = score.length;
        // 创建一个大顶堆,元素是 int[2],int[0]是分数,int[1]是原始索引
        // 比较器 (a, b) -> b[0] - a[0] 表示按分数降序排列
        PriorityQueue<int[]> maxHeap = new PriorityQueue<>((a, b) -> b[0] - a[0]);

        for (int i = 0; i < n; i++) {
            maxHeap.add(new int[]{score[i], i});
        }

        String[] answer = new String[n];
        int rank = 1;
        while (!maxHeap.isEmpty()) {
            int[] top = maxHeap.poll();
            int originalIndex = top[1];

            if (rank == 1) {
                answer[originalIndex] = "Gold Medal";
            } else if (rank == 2) {
                answer[originalIndex] = "Silver Medal";
            } else if (rank == 3) {
                answer[originalIndex] = "Bronze Medal";
            } else {
                answer[originalIndex] = String.valueOf(rank);
            }
            rank++;
        }
        return answer;
    }
}

这个解法,建堆的过程是 O(N log N),然后取出 n 个元素,每次也是 O(log N),所以总的时间复杂度还是 O(N log N)。空间复杂度是 O(N) 用来存堆。

看起来和思路一复杂度一样,但优先队列提供了一种更动态的思路,尤其适合那些"我不需要一次性知道所有排名,而是要一个一个地按顺序处理最高分"的场景。

9.3 优先队列的应用算法设计

前面两个问题算是热身,让我们熟悉了优先队列的基本用法。接下来,咱们看几个更复杂的场景,看看优先队列是怎么大显神通的。

9.3.1 数据流中的第 k 大元素

问题描述

设计一个找到数据流中第 k 大元素的类(class)。注意是排序后的第 k 大元素,不是第 k 个不同的元素。

请实现 KthLargest 类:

  • KthLargest(int k, int[] nums) 使用整数 k 和整数流 nums 初始化对象。

  • int add(int val)val 加入数据流,并返回数据流中第 k 大的元素。

示例:

输入:
["KthLargest", "add", "add", "add", "add", "add"]
[[3, [4, 5, 8, 2]], [3], [5], [10], [9], [4]]
输出:
[null, 4, 5, 5, 8, 8]

解释:
KthLargest kthLargest = new KthLargest(3, [4, 5, 8, 2]);
kthLargest.add(3);   // return 4
kthLargest.add(5);   // return 5
kthLargest.add(10);  // return 5
kthLargest.add(9);   // return 8
kthLargest.add(4);   // return 8

这道题和 9.2.2 非常像,但加了一个 "数据流" 的概念。这意味着数据是动态增加的,我们不能每次 add 都把所有数据重新排序一遍,那效率太低了。

思路一:暴力法(每次都排序)

为了对比,我们还是先想一个最笨的方法。 我用一个 ArrayList 存所有的数据。 构造函数里,把 nums 全加到 list 里。 每次调用 add(val),我就把 val 也加到 list 里,然后对整个 list 排序,返回倒数第 k 个元素。

这个思路的问题非常明显,add 操作太重了。如果数据流非常大,每次都 O(N log N) 地排序,性能会急剧下降。这显然不是面试官想要的答案。

// 仅作思路对比,不推荐
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

class KthLargest1 {
    private int k;
    private List<Integer> list;

    public KthLargest1(int k, int[] nums) {
        this.k = k;
        this.list = new ArrayList<>();
        for (int num : nums) {
            list.add(num);
        }
    }

    public int add(int val) {
        list.add(val);
        Collections.sort(list);
        return list.get(list.size() - k);
    }
}

思路二:小顶堆(正解)

"数据流"、"第k大",这两个关键词凑在一起,就是赤裸裸地在提示我们用优先队列。

思路跟 9.2.2 的解法完全一样:我们始终维护一个大小为 k 的小顶堆。这个堆里存的就是当前数据流中最大的 k 个数。那么,堆顶元素,自然就是这 k 个数里的最小值,也就是整个数据流中的第 k 大元素。

  1. 构造函数 KthLargest(k, nums):

    • 初始化一个大小为 k 的小顶堆。

    • 遍历初始数组 nums,对每个元素调用 add 方法,来构建好初始的堆。

  2. add(val) 方法:

    • 如果当前堆的大小还不到 k,直接把 val 加进去。

    • 如果堆的大小已经是 k了,就比较 val 和堆顶元素 heap.peek()

      • 如果 val 比堆顶元素大,说明 val 有资格进入"前k大"的行列,而原来的堆顶(第k大)就要被淘汰了。所以,poll() 出堆顶,再 add(val)

      • 如果 val 小于或等于堆顶元素,那它连第 k 大都算不上,直接忽略。

    • 经过这些操作后,heap.peek() 就是当前数据流中的第k大元素,返回它即可。

import java.util.PriorityQueue;

class KthLargest {
    private final PriorityQueue<Integer> minHeap;
    private final int k;

    public KthLargest(int k, int[] nums) {
        this.k = k;
        this.minHeap = new PriorityQueue<>(k);
        for (int num : nums) {
            add(num); // 复用add方法的逻辑来初始化
        }
    }

    public int add(int val) {
        if (minHeap.size() < k) {
            minHeap.add(val);
        } else if (val > minHeap.peek()) {
            minHeap.poll();
            minHeap.add(val);
        }
        return minHeap.peek();
    }
}

这个解法,构造函数的时间复杂度是 O(N log K),add 方法的复杂度是 O(log K),空间复杂度是 O(K)。完美解决了数据流问题,每次新来数据都能高效地给出答案。

9.3.2 查找和最小的 k 对数字

问题描述

给定两个以升序排列的整数数组 nums1nums2 , 以及一个整数 k

定义一对值 (u,v),其中第一个元素来自 nums1,第二个元素来自 nums2

请找到和最小的 k 个数对 (u,v)

示例 1: 输入: nums1 = [1,7,11], nums2 = [2,4,6], k = 3 输出: [1,2],[1,4],[1,6] 解释: 返回最小的 k 个数对。 [1,2],[1,4],[1,6],[7,2],[7,4],[11,2],[7,6],[11,4],[11,6]


思路一:暴力全组合 + 排序

最直接的想法,就是把所有可能的数对都生成出来,计算它们的和,然后排个序,取前 k 个。

nums1 长度为 mnums2 长度为 n,总共有 m*n 个数对。

  1. 创建一个列表,用来存放所有数对。

  2. 双重循环遍历 nums1nums2,生成所有 (u, v) 对,加入列表。

  3. 对列表按照数对的和进行排序。

  4. 取出列表的前 k 个元素作为结果。

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

class Solution1 {
    public List<List<Integer>> kSmallestPairs(int[] nums1, int[] nums2, int k) {
        List<List<Integer>> allPairs = new ArrayList<>();
        for (int u : nums1) {
            for (int v : nums2) {
                List<Integer> pair = new ArrayList<>();
                pair.add(u);
                pair.add(v);
                allPairs.add(pair);
            }
        }

        // 按和排序
        allPairs.sort((a, b) -> (a.get(0) + a.get(1)) - (b.get(0) + b.get(1)));

        List<List<Integer>> result = new ArrayList<>();
        for (int i = 0; i < Math.min(k, allPairs.size()); i++) {
            result.add(allPairs.get(i));
        }
        return result;
    }
}

这个方法的时间复杂度是 O(MN log(MN)),因为要生成并排序 MN 个元素。空间复杂度是 O(MN)。当 M 和 N 很大时,这个方法会超时或内存溢出。

思路二:优先队列(多路归并思想)

上面的暴力法问题在于,我们生成了太多不必要的数对。题目只要前 k 小的,我们可能只需要探索一小部分组合就行了。

注意到 nums1nums2 都是升序的。那么和最小的数对一定是 (nums1[0], nums2[0])。 第二小的呢?可能是 (nums1[1], nums2[0]) 或者 (nums1[0], nums2[1])

这就像合并 m 个有序链表一样。我们可以把问题看成这样: 对于 nums1 中的每一个 u = nums1[i],都有一条与之对应的有序"链表":(u, nums2[0]), (u, nums2[1]), (u, nums2[2]), ...。 我们现在就是要在这 m 条"链表"中,找出和最小的 k 个元素。

这不就是 "合并k个有序链表" 的变种吗?优先队列是解决这类问题的神器。

  1. 创建一个小顶堆,堆里存放的是数对 (u, v) 以及它们在 nums2 中的下标,比如一个数组 [u, v, v_idx]。堆的排序规则是按照数对的和 u+v

  2. 初始时,将 (nums1[i], nums2[0]) for i from 0 to m-1(或者 min(k, m-1),一个优化)都加入堆。更具体地说,是把 [nums1[i], nums2[0], 0] 加进去。

  3. 循环 k 次(或者直到堆为空):

    • 从堆中 poll 出和最小的元素 [u, v, v_idx]

    • (u, v) 加入结果集。

    • 如果这个 v 不是它所在"链表"的最后一个元素(即 v_idx + 1 < n),那么就把它的下一个元素 (u, nums2[v_idx+1]) 加入堆中。也就是把 [u, nums2[v_idx+1], v_idx+1] 加进去。

这样,我们每次都从所有"候选"的数对中取出最小的那个,并把它的"继任者"加入候选池。这就保证了我们能按和从小到大的顺序,依次找到这 k 个数对。

import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;

class Solution2 {
    public List<List<Integer>> kSmallestPairs(int[] nums1, int[] nums2, int k) {
        // 小顶堆,存的是一个数组 int[]{nums1元素, nums2元素, nums2的索引}
        // 排序规则是按照前两个元素的和
        PriorityQueue<int[]> minHeap = new PriorityQueue<>((a, b) -> (a[0] + a[1]) - (b[0] + b[1]));

        // 初始化堆,将 nums1[i] 和 nums2[0] 的组合入堆
        // 优化:nums1 最多只需要考虑前 k 个元素,因为最终只要 k 对
        for (int i = 0; i < Math.min(nums1.length, k); i++) {
            minHeap.add(new int[]{nums1[i], nums2[0], 0});
        }

        List<List<Integer>> result = new ArrayList<>();
        while (k > 0 && !minHeap.isEmpty()) {
            int[] top = minHeap.poll();
            int u = top[0];
            int v = top[1];
            int v_idx = top[2];

            List<Integer> pair = new ArrayList<>();
            pair.add(u);
            pair.add(v);
            result.add(pair);
            k--;

            // 如果 v 后面在 nums2 中还有元素,将下一个组合入堆
            if (v_idx + 1 < nums2.length) {
                minHeap.add(new int[]{u, nums2[v_idx + 1], v_idx + 1});
            }
        }

        return result;
    }
}

这个解法中,堆的大小最多是 k(或 m)。我们总共要进行 kpolladd 操作。 时间复杂度是 O(k log k) 或 O(k log m),空间复杂度是 O(k) 或 O(m)。这比暴力法好太多了。

9.3.3 合并 k 个有序链表

问题描述

给你一个链表数组 lists ,每个链表都已经按升序排列。

请你将所有链表合并到一个升序排列的链表中,并返回合并后的链表。

示例 1: 输入:lists = [[1,4,5],[1,3,4],[2,6]] 输出:[1,1,2,3,4,4,5,6] 解释:链表数组如下: [ 1->4->5, 1->3->4, 2->6 ] 将它们合并到一个有序链表中得到。 1->1->2->3->4->4->5->6


思路一:两两合并

这是一个很自然的分治思想。我有 k 个链表,我可以先把第1个和第2个合并,得到一个新链表;然后把这个新链表和第3个合并;再和第4个合并……直到全部合并完。

class ListNode {
    int val;
    ListNode next;
    ListNode() {}
    ListNode(int val) { this.val = val; }
    ListNode(int val, ListNode next) { this.val = val; this.next = next; }
}

class Solution1 {
    public ListNode mergeKLists(ListNode[] lists) {
        if (lists == null || lists.length == 0) {
            return null;
        }
        ListNode mergedList = lists[0];
        for (int i = 1; i < lists.length; i++) {
            mergedList = mergeTwoLists(mergedList, lists[i]);
        }
        return mergedList;
    }

    private ListNode mergeTwoLists(ListNode l1, ListNode l2) {
        ListNode dummy = new ListNode(-1);
        ListNode current = dummy;
        while (l1 != null && l2 != null) {
            if (l1.val <= l2.val) {
                current.next = l1;
                l1 = l1.next;
            } else {
                current.next = l2;
                l2 = l2.next;
            }
            current = current.next;
        }
        current.next = (l1 != null) ? l1 : l2;
        return dummy.next;
    }
}

mergeTwoLists 的时间复杂度是O(len1 + len2)。假设每个链表平均长度是 n,总节点数是 N = k*n。 第一次合并,长度 n+n = 2n。 第二次合并,长度 2n+n = 3n。 ... 总的时间复杂度是 O(2n + 3n + ... + kn) = O(n * k^2)。 如果链表长度不均,这个方法性能会比较差。

思路二:优先队列

这才是这个问题的标准解法。 想象一下,最终合并后的大链表,它的头节点,一定是 k 个链表头节点里最小的那个。 取出了这个最小的头节点后,第二小的节点,就是剩下 k-1 个头节点,和那个被取出节点的下一个节点,这 k 个节点里的最小值。

这个过程,我们其实是在 k 个"来源"中,不断地寻找当前最小的元素。这不就是优先队列的用武之地吗?

  1. 创建一个小顶堆,堆里存放的是 ListNode 节点。我们需要自定义比较器,按节点的 val 来排序。

  2. 遍历 lists 数组,把每个链表的头节点(如果不为 null)都加入到小顶堆里。

  3. 创建一个虚拟头节点 dummy 和一个指针 current,用来构建新的合并链表。

  4. 当堆不为空时,循环执行:

    • 从堆中 poll() 出值最小的节点 minNode

    • minNode 接到 current 的后面,current 指针后移。

    • 如果 minNode 还有下一个节点 minNode.next,就把它的下一个节点加入堆中。

  5. 循环结束后,dummy.next 就是合并后链表的头节点。

import java.util.PriorityQueue;

class Solution2 {
    public ListNode mergeKLists(ListNode[] lists) {
        if (lists == null || lists.length == 0) {
            return null;
        }

        // 小顶堆,按节点的值排序
        PriorityQueue<ListNode> minHeap = new PriorityQueue<>((a, b) -> a.val - b.val);

        // 初始化堆,加入所有链表的头节点
        for (ListNode head : lists) {
            if (head != null) {
                minHeap.add(head);
            }
        }

        ListNode dummy = new ListNode(-1);
        ListNode current = dummy;

        while (!minHeap.isEmpty()) {
            ListNode minNode = minHeap.poll();
            current.next = minNode;
            current = current.next;

            // 如果这个最小节点后面还有节点,把它加回堆里
            if (minNode.next != null) {
                minHeap.add(minNode.next);
            }
        }

        return dummy.next;
    }
}

假设总共有 N 个节点,k 个链表。 建堆的过程,把 k 个头节点放进去,时间是 O(k log k)。 之后,每个节点都会进堆一次,出堆一次。总共有 N 个节点,所以总的操作次数是 N 次。每次操作堆的复杂度是 O(log k),因为堆的大小最多是 k。 所以总的时间复杂度是 O(N log k),空间复杂度是 O(k) 用于存堆。这比 O(N*k) 或 O(n * k^2) 要好得多。

9.3.4 滑动窗口最大值

问题描述

给你一个整数数组 nums,有一个大小为 k 的滑动窗口从数组的最左侧移动到数组的最右侧。你只可以看到在滑动窗口内的 k 个数字。滑动窗口每次只向右移动一位。

返回滑动窗口中的最大值。

示例 1: 输入:nums = [1,3,-1,-3,5,3,6,7], k = 3 输出:[3,3,5,5,6,7]


思路一:暴力法

最简单的方法,就是模拟窗口的滑动。 窗口每滑动一次,我就遍历一遍窗口里的 k 个元素,找到最大值,存到结果里。

import java.util.ArrayList;
import java.util.List;

class Solution1 {
    public int[] maxSlidingWindow(int[] nums, int k) {
        if (nums == null || nums.length == 0 || k <= 0) {
            return new int[0];
        }
        int n = nums.length;
        int[] result = new int[n - k + 1];

        for (int i = 0; i <= n - k; i++) {
            int max = Integer.MIN_VALUE;
            // 遍历当前窗口 [i, i+k-1]
            for (int j = i; j < i + k; j++) {
                max = Math.max(max, nums[j]);
            }
            result[i] = max;
        }
        return result;
    }
}

这个方法非常直观,但效率不高。有 n-k+1 个窗口,每个窗口都遍历 k 次,总的时间复杂度是 O((n-k+1) * k),约等于 O(n*k)。

思路二:优先队列(大顶堆)

暴力法的问题在于,窗口滑动时,很多信息被浪费了。比如 [1, 3, -1] -> [3, -1, 5],其实 3-1 我们都已经看过了,没必要重新比较。

我可以用一个数据结构来维护当前窗口内的元素,并能快速告诉我最大值。大顶堆正好能干这个事。

思路是这样的:

  1. 维护一个大顶堆。

  2. 窗口向右滑动,每进入一个新元素,就把它加到堆里。

  3. 同时,要从堆里移除那个滑出窗口的元素。

  4. 堆顶永远是当前窗口的最大值。

这里有个问题:Java 的 PriorityQueue 提供了 addpoll(删除堆顶),但没有提供高效的删除任意元素的方法。remove(Object) 方法的时间复杂度是 O(N),这里 N 是堆的大小,也就是 k。

如果我们硬要用 PriorityQueue,那么每次窗口滑动:add 一个新元素(O(log k)),remove 一个旧元素(O(k))。总的时间复杂度还是 O(n*k),没有本质提升。

但我们可以变通一下。我不真的从堆里删除滑出去的元素,而是让它们"过期"。 我往堆里存的不再是数字,而是 (数字, 下标) 对。

  1. 创建一个大顶堆,存 int[]{value, index}

  2. 初始化:先把前 k 个元素和它们的下标加到堆里。此时堆顶就是第一个窗口的最大值。

  3. 开始滑动:从第k个元素开始遍历数组。

    • add 新元素 (nums[i], i) 进堆。

    • 关键:查看堆顶。如果堆顶元素的下标已经不在当前窗口 [i-k+1, i] 的范围内了,说明它是"过期"的最大值,poll() 掉它。重复这个操作,直到堆顶的元素是有效的。

    • 此时,堆顶就是当前窗口的最大值,记录下来。

import java.util.PriorityQueue;

class Solution2 {
    public int[] maxSlidingWindow(int[] nums, int k) {
        int n = nums.length;
        int[] result = new int[n - k + 1];
        // 大顶堆,存 int[]{value, index}
        PriorityQueue<int[]> maxHeap = new PriorityQueue<>((a, b) -> b[0] - a[0]);

        for (int i = 0; i < n; i++) {
            // 加入新元素
            maxHeap.add(new int[]{nums[i], i});

            // 当窗口形成后,开始记录最大值
            if (i >= k - 1) {
                // 移除堆顶过期的元素
                while (maxHeap.peek()[1] <= i - k) {
                    maxHeap.poll();
                }
                // 此时堆顶就是当前窗口的最大值
                result[i - k + 1] = maxHeap.peek()[0];
            }
        }
        return result;
    }
}

这个解法,每个元素进堆一次,出堆一次。总的时间复杂度是 O(N log K),空间复杂度是 O(K)。这比暴力法好。

补充:最优解 - 单调双端队列 (Deque)

虽然优先队列能优化,但这题的最优解其实是用双端队列(Deque)。思路是维护一个单调递减的队列,队列里存的是元素的下标。

  1. 遍历数组,对于每个元素nums[i]

  2. 从队尾开始,把所有 nums 值小于 nums[i] 的下标都弹出去。这保证了队列的单调性。

  3. i 加到队尾。

  4. 检查队首的下标是否已经滑出窗口,如果是就从队首弹出。

  5. 当窗口形成后,队首的下标对应的 nums 值就是当前窗口的最大值。

每个元素最多进队一次,出队一次,所以时间复杂度是 O(N),空间复杂度是 O(K)。这个才是这道题的终极解法。这里就不展开代码了,但知道这个思路很重要。

9.3.5 最大的团队表现值

问题描述

公司有 n 名工程师。给你两个数组 speedefficiency ,长度都是 n ,其中 speed[i]efficiency[i] 分别代表第 i 名工程师的速度和效率。

请你从中选择 最多 k 名工程师,组成一个团队,使得团队的 表现值 最大。

团队的 表现值 定义为:团队中所有工程师速度的总和 * 团队中所有工程师效率的最小值

请你返回一个整数,表示团队的最大表现值。由于答案可能很大,请你返回结果对 10^9 + 7 取余后的结果。

示例 1: 输入:n = 6, speed = [2,10,3,1,5,8], efficiency = [5,4,3,9,7,2], k = 2 输出:60 解释: 我们选择工程师 2(速度=10,效率=4)和工程师 5(速度=5,效率=7)。他们的团队表现值为 (10 + 5) * min(4, 7) = 15 * 4 = 60 。


这题有点绕,表现值的计算方式是 sum(speed) * min(efficiency)

思路一:暴力枚举 枚举所有大小从1到k的工程师组合,计算每个组合的表现值,取最大。组合数太多了,C(n, k) 级别,直接 pass。

思路二:排序 + 优先队列

公式里的 min(efficiency) 是个瓶颈。如果我们能确定一个团队的 min(efficiency),问题就变成了:在所有效率不低于这个值的工程师里,选出最多 k 个速度最快的,让 sum(speed) 最大。

这个思路启发了我们。我们可以遍历每个工程师,假设他就是那个效率最低的。

  1. 把工程师们按照 效率降序 排序。这样做的好处是,当我遍历到第 i 个工程师时,所有在他之前的工程师,效率都比他高(或相等)。

  2. 维护一个大小不超过 k 的小顶堆,用来存放我们当前考虑的团队成员的 速度

  3. 遍历排序后的工程师:

    • 对于当前工程师 i,他的效率是 E_i。我们把他看作是当前团队的效率瓶颈。

    • 把他的速度 S_i 加入小顶堆,同时累加到总速度 sumSpeed 中。

    • 如果此时堆的大小超过了 k,说明团队超员了。我们必须踢掉一个人。为了让 sumSpeed 尽可能大,我们应该踢掉那个速度最慢的。小顶堆的堆顶正好就是速度最慢的那个,poll() 出来,并从 sumSpeed 中减去。

    • 现在,堆里就是 k 个(或更少)在效率不低于 E_i 的工程师里速度最快的。计算当前表现值 sumSpeed * E_i,和全局最大值 maxPerformance 比较并更新。

  4. 遍历结束后,maxPerformance 就是答案。

import java.util.Arrays;
import java.util.PriorityQueue;

class Solution {
    public int maxPerformance(int n, int[] speed, int[] efficiency, int k) {
        // 创建工程师对象,方便绑定 speed 和 efficiency
        int[][] engineers = new int[n][2];
        for (int i = 0; i < n; i++) {
            engineers[i][0] = speed[i];
            engineers[i][1] = efficiency[i];
        }

        // 按效率降序排序
        Arrays.sort(engineers, (a, b) -> b[1] - a[1]);

        // 小顶堆,存速度
        PriorityQueue<Integer> minHeap = new PriorityQueue<>(k);
        long sumSpeed = 0;
        long maxPerformance = 0;
        long MOD = 1_000_000_007;

        for (int[] engineer : engineers) {
            int currentSpeed = engineer[0];
            int currentEfficiency = engineer[1];

            // 队员入队
            minHeap.add(currentSpeed);
            sumSpeed += currentSpeed;

            // 如果队伍超员,踢掉速度最慢的
            if (minHeap.size() > k) {
                sumSpeed -= minHeap.poll();
            }

            // 计算当前团队的表现值,并更新最大值
            maxPerformance = Math.max(maxPerformance, sumSpeed * currentEfficiency);
        }

        return (int) (maxPerformance % MOD);
    }
}

时间复杂度:排序是 O(N log N),遍历是 O(N log K),所以总体是 O(N log N)。 空间复杂度:O(N) 用于存工程师数组,O(K) 用于优先队列。

9.3.6 雇佣 k 位工人的总代价

问题描述

给你一个下标从 0 开始的整数数组 costs ,其中 costs[i] 是雇佣第 i 位工人的代价。

同时给你两个整数 kcandidates 。我们想根据以下规则恰好雇佣 k 位工人:

  • 总共进行 k 轮雇佣,且每一轮恰好雇佣一位工人。

  • 在每一轮雇佣中,从所有可雇佣的工人中选择一位代价最小的工人。如果有多位代价相同的最小工人,选择其中下标最小的一位。

  • 一位工人可被雇佣,当且仅当该工人的下标在 [0, candidates - 1][n - candidates, n - 1] 的范围内。

  • 每一轮雇佣选择了一位工人后,就不再考虑该工人了。

如果在某一轮雇佣中,[0, candidates - 1][n - candidates, n - 1] 两个范围有重叠,那么我们构造成一个范围,并从这个范围的所有工人中选择一位代价最小的工人。

返回雇佣恰好 k 位工人的总代价。

示例 1: 输入:costs = [17,12,10,2,7,2,11,20,8], k = 3, candidates = 4 输出:11


思路一:模拟 + 每次扫描

这个思路最直接。总共要雇 k 轮,那我就模拟这 k 轮。 每一轮:

  1. 确定左右两个候选区间的范围。

  2. 扫描这两个区间里的所有未被雇佣的工人。

  3. 找到其中代价最小的那个(代价相同,下标小者优先)。

  4. 把他的代价加到总代价里,并把他标记为"已雇佣"。

这个方法需要一个 boolean 数组来标记工人是否被雇佣。每一轮都要扫描最多 2 * candidates 个人,总共 k 轮。 时间复杂度是 O(k * candidates),如果 candidates 很大,这个方法会超时。

思路二:双优先队列

每次都在两个候选区里找最小值,这个操作可以用优先队列来优化。

我们可以创建两个小顶堆:leftHeaprightHeapleftHeap 存放左边候选区 [0, candidates-1] 的工人信息。 rightHeap 存放右边候选区 [n-candidates, n-1] 的工人信息。

堆里存的元素是 (cost, index),排序规则是先按 cost 从小到大,cost 相同再按 index 从小到大。

  1. 初始化两个指针 left = 0, right = n-1

  2. 初始化堆

    • 从左边取 candidates 个工人 (costs[i], i) 加入 leftHeapleft 指针前进到 candidates

    • 从右边取 candidates 个工人 (costs[i], i) 加入 rightHeapright 指针后退到 n-1-candidates。注意,要保证左右两边不重叠,即 left <= right

  3. 进行 k 轮选择

    • 比较 leftHeap.peek()rightHeap.peek()

    • 选择代价更小(或下标更小)的那个工人。假设从 leftHeap 选出了。

    • 把他的代价累加到总和 totalCost

    • leftHeappoll() 出这个工人。

    • 补充新人:因为左边少了一个人,如果 left <= right 还成立,说明中间还有未被考虑的工人,就把 (costs[left], left) 加入 leftHeapleft++

    • 如果从 rightHeap 选出,则做对称的操作。

  4. 循环 k 次后,返回 totalCost

import java.util.PriorityQueue;

class Solution {
    public long totalCost(int[] costs, int k, int candidates) {
        int n = costs.length;
        // 小顶堆,存 int[]{cost, index}
        // 按 cost 升序,cost 相同按 index 升序
        PriorityQueue<int[]> leftHeap = new PriorityQueue<>((a, b) -> {
            if (a[0] != b[0]) {
                return a[0] - b[0];
            }
            return a[1] - b[1];
        });

        PriorityQueue<int[]> rightHeap = new PriorityQueue<>((a, b) -> {
            if (a[0] != b[0]) {
                return a[0] - b[0];
            }
            return a[1] - b[1];
        });

        int left = 0;
        int right = n - 1;

        // 初始化左堆
        for (int i = 0; i < candidates; i++) {
            if (left <= right) {
                leftHeap.add(new int[]{costs[left], left});
                left++;
            }
        }

        // 初始化右堆
        for (int i = 0; i < candidates; i++) {
            if (left <= right) {
                rightHeap.add(new int[]{costs[right], right});
                right--;
            }
        }

        long totalCost = 0;
        for (int i = 0; i < k; i++) {
            int[] leftCand = leftHeap.isEmpty() ? null : leftHeap.peek();
            int[] rightCand = rightHeap.isEmpty() ? null : rightHeap.peek();

            if (leftCand != null && (rightCand == null || 
                leftCand[0] < rightCand[0] || 
                (leftCand[0] == rightCand[0] && leftCand[1] < rightCand[1]))) {

                totalCost += leftHeap.poll()[0];
                if (left <= right) {
                    leftHeap.add(new int[]{costs[left], left});
                    left++;
                }

            } else if (rightCand != null) {
                totalCost += rightHeap.poll()[0];
                if (left <= right) {
                    rightHeap.add(new int[]{costs[right], right});
                    right--;
                }
            } else {
                // 如果两个堆都空了但k还没到,说明工人不够
                break;
            }
        }
        return totalCost;
    }
}

时间复杂度分析: 初始化堆是 O(candidates * log(candidates))。 之后进行 k 轮,每轮从堆里取一个,再加一个,都是 O(log(candidates))。 所以总时间复杂度是 O(k * log(candidates) + candidates * log(candidates)),简化为 O((k + candidates) * log(candidates))。 空间复杂度是 O(candidates) 用于两个堆。这比 O(k * candidates) 的模拟法要高效得多。