滚动哈希:字符串匹配的艺术

MrHe··4 min read

搞算法的,谁还没被字符串匹配问题捶过?给你一个长文本串 T 和一个模式串 P,让你找出 PT 中所有出现的位置。这问题太经典了,经典到面试官闭着眼都能问出来。

比如,T = "abababa"P = "aba",那 PT 中的起始位置就是 0, 2, 4。

拿到这种题,脑子里第一反应是啥?别想太多,先上暴力,这是对题目最基本的尊重。

解法一:简单粗暴的暴力匹配

这个思路最直接,也最符合人类的直觉。

我就把 P 当成一个尺子,从 T 的开头开始,一位一位地往后对。

  1. P 对准 T 的第 0 位,看看从 T[0] 开始长度为 len(P) 的子串是不是和 P 完全一样。

  2. 是,就记录下位置 0。

  3. 不是,就把 P 往后挪一位,对准 T 的第 1 位,再比较。

  4. 重复这个过程,直到 P 没法再往后挪了。

这个过程就像一个长度固定的滑动窗口,在 T 上滑来滑去,每次滑动都停下来做一次完整的比较。

代码写出来大概是这个样子:

import java.util.ArrayList;
import java.util.List;

public class StringMatching {

    // 解法一:暴力解
    public List<Integer> bruteForceSearch(String T, String P) {
        List<Integer> result = new ArrayList<>();
        int n = T.length();
        int m = P.length();
        if (n < m) {
            return result;
        }

        // 外层循环控制窗口的起始位置
        for (int i = 0; i <= n - m; i++) {
            // 内层循环逐个字符比较
            int j = 0;
            while (j < m) {
                if (T.charAt(i + j) != P.charAt(j)) {
                    break; // 一旦有字符不匹配,立刻中断
                }
                j++;
            }
            if (j == m) {
                // 如果 j 走到了头,说明 P 中所有字符都匹配上了
                result.add(i);
            }
        }
        return result;
    }
}

这代码没啥毛病,简单清晰。但我们得分析一下它的性能。假设 T 的长度是 N,P 的长度是 M。外层循环最多走 N-M+1 次,内层循环最坏情况下每次都要比较 M 个字符。所以总的时间复杂度是 O((N-M) * M),近似看就是 O(N*M)

N 和 M 要是都挺大,比如都是 10^5 级别,这复杂度直接就原地爆炸了,肯定不是面试官想要的答案。

那瓶颈在哪?瓶颈就在于那个 while 循环。我们每次移动窗口,都要花费 O(M) 的时间去比较两个字符串是否相等。这个比较太昂贵了。

有没有办法把这个比较操作优化到 O(1) 呢?

解法二:朴素哈希,想法很好但没啥用

要想 O(1) 比较两个东西是否相等,一个常见的骚操作就是“降维打击”——把一个复杂的东西映射成一个简单的东西,比如一个数字。

对,就是哈希。

我们可以设计一个哈希函数,把任意一个字符串,都转换成一个整数。如果两个字符串的哈希值不一样,那它们肯定不是同一个字符串。如果哈希值一样,那它们大概率是同一个字符串。

为什么是大概率?因为有哈希冲突。不过我们可以先不管冲突,假设我们的哈希函数足够牛逼,不会冲突。

那思路就变成了:

  1. 先计算出模式串 P 的哈希值 hashP

  2. 然后,在 T 上滑动窗口,每次都计算出当前窗口内子串的哈希值 hashT_sub

  3. 比较 hashT_subhashP 是否相等。如果相等,就认为找到了一个匹配。

我们来设计一个简单的哈希函数。比如,对于字符串 "aba",我们可以把它看成一个 26 进制的数(假设只有小写字母):a*26^2 + b*26^1 + a*26^0。当然,为了方便,我们可以用 ASCII 码值,就当成一个 128 进制的数。

代码大概是这样:

// 解法二的思路演示,并非完整代码
public List<Integer> naiveHashSearch(String T, String P) {
    List<Integer> result = new ArrayList<>();
    int n = T.length();
    int m = P.length();
    if (n < m) {
        return result;
    }

    // 1. 计算 P 的哈希值
    long hashP = calculateHash(P);

    // 2. 遍历 T 的所有子串
    for (int i = 0; i <= n - m; i++) {
        String sub = T.substring(i, i + m);
        // 3. 计算子串的哈希值并比较
        if (calculateHash(sub) == hashP) {
            // 如果哈希值相同,我们还应该再确认一下字符串本身是否相等,防止哈希冲突
            if (sub.equals(P)) {
                result.add(i);
            }
        }
    }
    return result;
}

// 一个简单的哈希函数
private long calculateHash(String s) {
    long hash = 0;
    long base = 31; // 选一个质数作为基数
    for (int i = 0; i < s.length(); i++) {
        hash = hash * base + s.charAt(i);
    }
    return hash;
}

我们来分析一下这个“优化”。外层循环还是 O(N)。内层呢?每次 T.substring(i, i + m) 是 O(M)(在某些版本的 Java 中是 O(1),但现代版本通常是 O(M) 的拷贝),然后 calculateHash(sub) 又是 O(M)。

兜兜转转,时间复杂度还是 O(N*M)。白忙活一场?

不,思路是对的,只是实现上太笨了。我们每次移动窗口,都重新计算了整个窗口的哈希值。但实际上,相邻的两个窗口只有一个字符的差别!

比如,从 T[i...i+M-1] 移动到 T[i+1...i+M],只是去掉了头部的 T[i],在尾部增加了一个 T[i+M]。这里面大有文章可做!

解法三:滚动哈希,让哈希值“滚动”起来!

这才是今天的主角。我们的目标,是在 O(1) 的时间内,根据上一个窗口的哈希值,计算出当前窗口的哈希值。

我们继续用刚才的进制思想。假设我们的基数是 b,窗口长度是 Mhash("abc") = a * b^2 + b * b^1 + c * b^0

现在窗口向右移动一位,变成了 "bcd"。我们希望从 hash("abc") 快速得到 hash("bcd")hash("bcd") = b * b^2 + c * b^1 + d * b^0

观察一下: hash("bcd")hash("abc") 的关系是啥?

hash("abc") - a * b^2 = b * b^1 + c * b^0 (hash("abc") - a * b^2) * b = b * b^2 + c * b^1 (hash("abc") - a * b^2) * b + d = b * b^2 + c * b^1 + d * b^0 = hash("bcd")

看出来了吗? new_hash = (old_hash - T[i] * b^(M-1)) * b + T[i+M]

这就是“滚动”的精髓!每次计算新哈希值,只需要减去最高位的影响,整体乘以基数(相当于左移一位),再加上新来的最低位。这个计算过程是 O(1) 的!

这里还有两个小问题要处理:

  1. 数值溢出:哈希值会变得非常大,long 都存不下。怎么办?取模。我们选一个大的质数 Q,所有计算都在模 Q 的意义下进行。

  2. 负数取模old_hash - T[i] * b^(M-1) 可能会是负数。在 Java 里,负数取模还是负数,这不符合我们的预期。处理方法是 (a - b) % Q 写成 (a - b % Q + Q) % Q,保证结果是正数。

好了,万事俱备,上代码!

import java.util.ArrayList;
import java.util.List;

public class RollingHashSearch {

    // 解法三:滚动哈希 (Rabin-Karp)
    public List<Integer> rollingHashSearch(String T, String P) {
        List<Integer> result = new ArrayList<>();
        int n = T.length();
        int m = P.length();
        if (n < m) {
            return result;
        }

        // 选择一个合适的基数和一个大的质数模数
        long base = 31;
        long Q = 1_000_000_007L;

        // 1. 计算 b^(m-1) % Q,用于移除最高位字符
        long h = 1;
        for (int i = 0; i < m - 1; i++) {
            h = (h * base) % Q;
        }

        // 2. 计算 P 和 T 第一个窗口的哈希值
        long hashP = 0;
        long hashT = 0;
        for (int i = 0; i < m; i++) {
            hashP = (hashP * base + P.charAt(i)) % Q;
            hashT = (hashT * base + T.charAt(i)) % Q;
        }

        // 3. 开始滑动窗口
        for (int i = 0; i <= n - m; i++) {
            // 首先比较当前窗口的哈希值
            if (hashP == hashT) {
                // 哈希值相同,再逐一比较字符串,防止哈希冲突!
                // 这一步是必须的,虽然触发概率低,但不能省略
                if (T.substring(i, i + m).equals(P)) {
                    result.add(i);
                }
            }

            // 如果窗口还没滑到头,计算下一个窗口的哈希值
            if (i < n - m) {
                // 移除最高位 T[i]
                long termToRemove = (h * T.charAt(i)) % Q;
                hashT = (hashT - termToRemove + Q) % Q; // 保证正数
                // 左移一位,并加上新来的最低位 T[i+m]
                hashT = (hashT * base + T.charAt(i + m)) % Q;
            }
        }

        return result;
    }
}

我们再来分析一下这个最终版的复杂度。

  • 计算 h (即 b^(m-1)) 是 O(M)。

  • 计算初始的 hashPhashT 是 O(M)。

  • 主循环,滑动窗口,跑了 N-M+1 次,大约 O(N) 次。

  • 循环内部,更新哈希值的操作是 O(1)。

  • 唯一可能拖后腿的是哈希冲突后的 equals 比较,最坏情况是每次都冲突,每次都比较 O(M) 次,那复杂度就退化回 O(N*M) 了。但在哈希函数设计良好的情况下(比如 baseQ 选得好),冲突概率极低,可以近似认为 equals 的总开销不大。

所以,在平均情况下,滚动哈希的时间复杂度是 O(N+M)。这不就起飞了吗!从 O(N*M) 干到了 O(N+M),这是一个质的飞跃。

总结一下

今天我们从一个经典的字符串匹配问题出发,体验了一把思维升级的过程:

  1. 暴力解:O(N*M),直观但低效,是我们思考的基石。

  2. 朴素哈希:O(N*M),引入了哈希思想,但实现上没有抓住优化的关键点,暴露了“重复计算”的瓶颈。

  3. 滚动哈希:O(N+M),抓住了“增量更新”这个核心,利用相邻窗口的高度重叠性,实现了 O(1) 的哈希值更新,最终达到了线性的时间复杂度。

这个从“暴力”到“发现瓶颈”再到“针对性优化”的思路,是算法学习里非常重要的一环。滚动哈希这个思想本身也很有用,不仅仅是字符串匹配,在很多需要对一个滑动窗口内的内容进行快速计算的场景,都能看到它的影子。