前缀树Idea

MrHe··10 min read

Part 1: 前缀树(Trie)的核心思想与应用

1. 什么是前缀树?

前缀树,又叫字典树单词查找树,英文名是 Trie。顾名思义,它是一种专门用来处理字符串的树形结构。为什么叫前缀树呢?因为这棵树的每条从根节点到任意节点的路径,都代表了一个字符串前缀

想象一下我们查英文字典的过程:要查 "car",我们先在 'c' 开头的区域找,然后找第二个字母 'a',最后找 'r'。前缀树就是把这个过程数据结构化了。

核心性质:

  1. 根节点不代表任何字符,是所有字符串的“共同祖先”。

  2. 从根节点到某个节点的路径上经过的字符连接起来,就是该节点对应的字符串前缀。

  3. 树的每一条边都代表一个字符。

  4. 每个节点除了有指向子节点的指针外,通常还有一个标记,用来标识从根到这个节点的路径是否构成了一个完整的单词。

举个例子,我们要存入 "cat", "car", "can", "dog"。这棵树看起来是这样的:

      (root)
      /    \
     c      d
     |      |
     a      o
     |      |
     n(end) g(end)
     |
     t(end)
     |
     r(end)

我们可以看到,"ca" 是 "cat" 和 "car" 的共同前缀,所以它们共享从 (root) -> c -> a 这条路径。

2. 前缀树的应用场景

前缀树的核心优势在于利用字符串的公共前缀来节约存储空间和查询时间。它的应用场景非常广泛:

  1. 字符串检索 / 自动补全:这是最经典的应用。输入一个前缀,就能快速找出所有以此为前缀的单词。比如搜索引擎的搜索建议、输入法联想词等。

  2. 词频统计:遍历一个文章,把每个单词插入前缀树。节点上可以增加一个计数器,insert 一次就加一,可以高效统计词频。

  3. 字符串排序:把所有字符串插入前缀树,然后对树进行一次深度优先遍历(DFS),就能得到按字典序排序的结果。

  4. 解决特定算法问题

    • 最大异或对问题(LeetCode 421):把数字的二进制形式看作成01字符串,在前缀树上进行贪心查找。

    • 涉及大量单词匹配的问题(LeetCode 212):先将所有目标单词构建成一棵前缀树,然后在二维网格上进行DFS搜索时,可以沿着前缀树的路径进行,极大地提高剪枝效率。


Part 2: 前缀树的实现方式

咱们来看怎么用代码实现它。主要有三种方式:类描述、静态数组和哈希表优化。

1. 类描述方式实现(最常用,最直观)

这是最符合面向对象思想的实现方式,结构清晰。我们首先需要一个节点类 TrieNode

一个节点需要包含什么信息?

  1. pass 计数:有多少个单词经过了这个节点。

  2. end 计数:有多少个单词以这个节点结尾。

  3. 指向下一个节点的引用:如果只考虑26个小写字母,我们可以用一个长度为26的数组 nexts 来存储。nexts[0] 指向代表 'a' 的子节点,nexts[1] 指向 'b',以此类推。

代码实现:

/**
 * 前缀树节点类(左程云风格)
 */
class TrieNode {
    // pass 表示有多少个单词经过了这个节点
    public int pass;
    // end 表示有多少个单词以这个节点结尾
    public int end;
    // nexts[i] == null 表示没有走向字符 i 的路径
    // nexts[i] != null 表示存在走向字符 i 的路径
    // 这里的 i 是通过 'c' - 'a' 计算出来的
    public TrieNode[] nexts;
​
    public TrieNode() {
        pass = 0;
        end = 0;
        // 假设只包含 a-z 这26个小写字母
        // 如果字符种类很多,这里可以换成哈希表
        nexts = new TrieNode[26];
    }
}
​
/**
 * 前缀树主体类
 */
public class Trie {
    private TrieNode root;
​
    public Trie() {
        root = new TrieNode();
    }
​
    /**
     * 向前缀树中插入一个单词
     * @param word 待插入的单词
     */
    public void insert(String word) {
        if (word == null) {
            return;
        }
        char[] chars = word.toCharArray();
        TrieNode node = root;
        // 根节点的 pass 值每次都增加
        node.pass++;
        for (char ch : chars) {
            int path = ch - 'a'; // 计算字符对应的索引
            // 如果该路径不存在,则创建新节点
            if (node.nexts[path] == null) {
                node.nexts[path] = new TrieNode();
            }
            // 移动到下一个节点
            node = node.nexts[path];
            // 沿途节点的 pass 值都增加
            node.pass++;
        }
        // 单词遍历结束,在最后一个节点的 end 值上加一
        node.end++;
    }
​
    /**
     * 查询一个单词在树中出现了多少次
     * @param word 待查询的单词
     * @return 出现的次数
     */
    public int search(String word) {
        if (word == null) {
            return 0;
        }
        char[] chars = word.toCharArray();
        TrieNode node = root;
        for (char ch : chars) {
            int path = ch - 'a';
            // 如果中途路径断了,说明这个单词肯定不存在
            if (node.nexts[path] == null) {
                return 0;
            }
            node = node.nexts[path];
        }
        // 走到头了,返回这个节点的 end 值
        return node.end;
    }
​
    /**
     * 查询以某个前缀开头的单词有多少个
     * @param pre 待查询的前缀
     * @return 以 pre 为前缀的单词数量
     */
    public int startsWith(String pre) {
        if (pre == null) {
            return 0;
        }
        char[] chars = pre.toCharArray();
        TrieNode node = root;
        for (char ch : chars) {
            int path = ch - 'a';
            // 如果中途路径断了,说明没有以此为前缀的单词
            if (node.nexts[path] == null) {
                return 0;
            }
            node = node.nexts[path];
        }
        // 走到头了,返回这个节点的 pass 值
        return node.pass;
    }
​
    /**
     * 从树中删除一个单词
     * @param word 待删除的单词
     */
    public void delete(String word) {
        // 首先要确保这个单词存在
        if (search(word) > 0) {
            char[] chars = word.toCharArray();
            TrieNode node = root;
            node.pass--; // 根节点的pass值先减
            for (char ch : chars) {
                int path = ch - 'a';
                // 沿途节点的 pass 值减一
                // 如果某个节点的 pass 值减到0,说明没有单词再经过它了
                // 那么这个节点和它后面的所有子节点都可以被垃圾回收器回收了
                // Java中直接断开引用即可
                if (--node.nexts[path].pass == 0) {
                    node.nexts[path] = null;
                    return; // 后面的都断了,直接返回
                }
                node = node.nexts[path];
            }
            // 最后一个节点的 end 值减一
            node.end--;
        }
    }
}

2. 静态数组方式实现(OI/ACM 风格,性能极致)

在一些算法竞赛中,为了追求极致的性能(避免new对象带来的开销和内存碎片),会用静态数组来模拟前缀树。

思路:

  • 用一个二维数组 trie[MAX_NODES][26] 来模拟节点间的连接关系。trie[u][c] 表示节点 u 通过字符 c 到达的子节点的编号。

  • 用一个 end[]pass[] 数组来存储每个节点的信息。

  • 用一个全局变量 idx 来动态分配新的节点编号。

public class TrieArray {
    // 假设最多有 100000 个节点,每个节点最多26个分支
    private static final int MAX_NODES = 100001; 
    private int[][] trie;
    private int[] end;
    private int[] pass;
    private int idx; // 当前分配到的节点编号,0号是根节点
​
    public TrieArray() {
        trie = new int[MAX_NODES][26];
        end = new int[MAX_NODES];
        pass = new int[MAX_NODES];
        // 0是根节点,所以从1开始分配
        idx = 1; 
    }

    // 清空,以便复用
    public void clear() {
        for (int i = 0; i < idx; i++) {
            end[i] = 0;
            pass[i] = 0;
            for(int j=0; j<26; j++){
                trie[i][j] = 0;
            }
        }
        idx = 1;
    }
​
    public void insert(String s) {
        int p = 0; // 从根节点出发
        pass[p]++;
        for (int i = 0; i < s.length(); i++) {
            int c = s.charAt(i) - 'a';
            if (trie[p][c] == 0) {
                // 如果没有路,就新建一个节点
                trie[p][c] = idx++;
            }
            p = trie[p][c]; // 移动到子节点
            pass[p]++;
        }
        end[p]++;
    }
​
    public int search(String s) {
        int p = 0;
        for (int i = 0; i < s.length(); i++) {
            int c = s.charAt(i) - 'a';
            if (trie[p][c] == 0) {
                return 0;
            }
            p = trie[p][c];
        }
        return end[p];
    }
}

优点:速度快,内存连续。 缺点:空间需要预估,不够灵活,代码可读性稍差。

3. 字符种类很多时的优化(哈希表)

如果字符集不是固定的26个小写字母,而是包含大写字母、数字、符号,甚至是Unicode字符,用长度为26的数组就不行了。即便用长度为256(ASCII)或65536(Unicode)的数组,也会造成巨大的空间浪费,因为大部分路径都是空的。

解决方案:用哈希表(HashMap)代替数组来存储子节点。

import java.util.HashMap;
​
class TrieNodeWithMap {
    public int pass;
    public int end;
    // 使用 HashMap 存储子节点,key是字符,value是子节点
    public HashMap<Character, TrieNodeWithMap> nexts;
​
    public TrieNodeWithMap() {
        pass = 0;
        end = 0;
        nexts = new HashMap<>();
    }
}
​
// 主体类的 insert, search 等方法的逻辑几乎不变
// 只是把 nexts[path] 的操作换成 nexts.get(ch) 和 nexts.put(ch, newNode)
public class TrieWithMap {
    private TrieNodeWithMap root;
    // ...
    public void insert(String word) {
        // ...
        TrieNodeWithMap node = root;
        // ...
        for (char ch : chars) {
            if (!node.nexts.containsKey(ch)) {
                node.nexts.put(ch, new TrieNodeWithMap());
            }
            node = node.nexts.get(ch);
            // ...
        }
        // ...
    }
    // search 和 startsWith 类似修改
}

Part 3: 实战题目演练

光说不练假把式,我们来看几个经典的例题,感受一下前缀树的威力。

题目1:实现 Trie (前缀树) (LeetCode 208)

问题重述 实现一个 Trie 类,包含 insert, search, 和 startsWith 三个方法。

分析与思路 这道题就是前缀树的模板题,是“裸”的前缀树。它不要求我们统计单词出现的次数,只需要知道“有”还是“没有”。所以我们可以简化节点信息,只需要一个 isEnd 的布尔值标记即可。

  • insert(word): 遍历单词字符,沿着树向下走。如果路径不存在,就创建新节点。走完后,将最后一个节点标记为 isEnd = true

  • search(word): 遍历单词字符,沿着树向下走。如果中途路径断了,返回 false。如果走完了,必须检查最后一个节点是否是 isEnd。比如树里有 "apple",查 "app" 应该返回 false

  • startsWith(prefix): 遍历前缀字符,沿着树向下走。只要路径不断,就能走完,返回 true

OI/ACM 风格 Java 代码

class Trie {
    private TrieNode root;
​
    private class TrieNode {
        private TrieNode[] children;
        private boolean isEnd;
​
        public TrieNode() {
            children = new TrieNode[26]; // 26个英文小写字母
            isEnd = false;
        }
    }
​
    public Trie() {
        root = new TrieNode();
    }
​
    public void insert(String word) {
        TrieNode node = root;
        for (char c : word.toCharArray()) {
            int index = c - 'a';
            if (node.children[index] == null) {
                node.children[index] = new TrieNode();
            }
            node = node.children[index];
        }
        node.isEnd = true;
    }
​
    public boolean search(String word) {
        TrieNode node = searchPrefix(word);
        return node != null && node.isEnd;
    }
​
    public boolean startsWith(String prefix) {
        return searchPrefix(prefix) != null;
    }
​
    private TrieNode searchPrefix(String prefix) {
        TrieNode node = root;
        for (char c : prefix.toCharArray()) {
            int index = c - 'a';
            if (node.children[index] == null) {
                return null;
            }
            node = node.children[index];
        }
        return node;
    }
}

题目2:数组中两个数的最大异或值 (LeetCode 421)

问题重述 给定一个非空整数数组 nums,找到 nums[i] XOR nums[j] 的最大结果,其中 0 ≤ i ≤ j < n

分析与思路演进

  1. 暴力解法:双重循环,计算每对数的异或值,更新最大值。时间复杂度 O(N²),对于 10^5 级别的数据量,肯定超时。

  2. 前缀树优化:这是前缀树的一个非常巧妙的应用!我们不是在存字符串,而是在存数字的二进制表示。一个 32 位整数可以看作是一个长度为 32 的 "01" 字符串。

    核心思想:贪心! 为了让异或值最大,我们希望结果的高位尽可能为 1

    假设我们当前在处理数组中的数字 num,我们希望在已经处理过的数字中,找到一个 x,使得 num XOR x 最大。 我们从 num 的最高位(第 31 位)开始看:

    • 如果 num 的第 i 位是 1,我们最希望 x 的第 i 位是 0,这样异或结果的第 i 位就是 1

    • 如果 num 的第 i 位是 0,我们最希望 x 的第 i 位是 1,这样异或结果的第 i 位也是 1

这个过程可以用前缀树完美实现:

  1. 建树:遍历数组,将每个数字的 32 位二进制表示插入到一棵前缀树中。这棵树的边只有 0 和 1 两种。

  2. 查询:再次遍历数组,对于每个数字 num,在前缀树中为它寻找最佳匹配。从高位到低位遍历 num 的每一位:

    • 获取 num 的当前位 b

    • 我们期望的路径是 1-b。检查前缀树中,当前节点是否有走向 1-b 的路。

    • 如果存在:太好了!我们贪心成功。走这条路,并且把当前这一位的贡献 (1 << i) 加到当前结果里。

    • 如果不存在:没办法,只能走 b 这条路,这一位对结果的贡献是 0。

  3. 对每个 num 都找到其最大异或对,然后取所有结果中的最大值。

OI/ACM 风格 Java 代码

class Trie {
    private Trie[] children = new Trie[2];
​
    public Trie() {
​
    }
​
    public void insert(int x) {
        Trie node = this;
        for (int i = 30; i >= 0; --i) {
            int v = x >> i & 1;
            if (node.children[v] == null) {
                node.children[v] = new Trie();
            }
            node = node.children[v];
        }
    }
​
    public int search(int x) {
        Trie node = this;
        int ans = 0;
        for (int i = 30; i >= 0; --i) {
            int v = x >> i & 1;
            if (node.children[v ^ 1] != null) {
                ans |= 1 << i;
                node = node.children[v ^ 1];
            } else {
                node = node.children[v];
            }
        }
        return ans;
    }
}
​
class Solution {
    public int findMaximumXOR(int[] nums) {
        Trie trie = new Trie();
        int ans = 0;
        for (int x : nums) {
            trie.insert(x);
            ans = Math.max(ans, trie.search(x));
        }
        return ans;
    }
}