如何用 Trie 加快文本查找

最近做个小练习,需要做实现文本查找。最简单的方案当然是列表遍历筛选了,这也是我在日常开发中的首选方案,毕竟前端通常处理的数据量非常少。但是当处理的数据量过大的时候,列表遍历就太低效了,此时需要更高效的算法。

这次我选了 Trie 来实现文本查找。

Trie (读 try) 取自 retrieval 这个单词,即获取,提取的意思。它是一种树形数据结构,常用来对文本进行快速查找。

原理

Trie 会把所有文本按照字符顺序进行树形层级排列,例如,一组单词 ['abet', 'abode', 'abort'] 表示成树形层级就是

🌲🌲🌲🌲🌲🌲
    a
    |
    b
  /   \
 e     o
 |    / \
 t   d   r
     |   |
     e   t

我们在查找 abo 开头的单词的时候就可以直接跳过 e 这个分支了,而不是像列表遍历那样一个个找。

实现一个简单的 Trie

一,节点信息

首先用一个 class 来表示一个节点。每个节点带的信息见以下代码注释:

class TrieNode {
  constructor(char) {
    this.char = char; // 当前字符
    this.validWord = false; // 当前字符是否是完整单词的末尾
    this.parent = null; // 是否有父节点
    this.children = []; // 所有子节点
  }
}

二,添加单词

接着我们可以基于以上节点定义来构建 Trie 了,首先要实现添加一个单词到 Trie 的方法:

class Trie {
  constructor() {
    this.root = new TrieNode('');
  }

  add(word) {
    let current = this.root; // current 用来表示查找指针,从根节点开始
    // 遍历单词
    for (let i = 0; i < word.length; i += 1) {
      const ch = word[i];
      let found = false;

      // 遍历当前节点的子节点
      // 下面 for 循环的写法是微观性能优化 hack
      for (let j = current.children.length; j--; ) {
        const child = current.children[j];
        if (child.char === ch) {
          found = true;
          // 若找到了匹配的字符,则把查找指针指向匹配到的子节点
          current = child;
          break;
        }
      }

      // 若当前要匹配的字符不在子节点列表内,则创建新的节点,并加入列表
      if (!found) {
        current.children.push(new TrieNode(ch));

        const newNode = current.children[current.children.length - 1];

        newNode.parent = current;
        // 指针移到新节点
        current = newNode;
      }
    }
    // 操作完成后,此时指针应当指向单词最后一个字符所在的节点
    // 将此节点标记为完整单词
    current.validWord = true;
  }
}

三,删除单词

我们还需要提供方法将某个单词从 Trie 里面删除掉。

  delete(word) {
    let current = this.root;

    for (let i = 0; i < word.length; i += 1) {
      const ch = word[i];
      let found = false;

      for (let j = current.children.length; j--; ) {
        const child = current.children[j];

        if (child.char === ch) {
          found = true;
          current = child;
          break;
        }
      }

      if (!found) {
        // 只要有一个字符不匹配,则当前单词不在 Trie 里面,终止操作
        return;
      }
    }

    // 上面 for 循环完成后,我们就遍历到待删除单词最后一个字符所在节点了
    current.validWord = false;

    let stop = false;
    while (!stop) {
      if (
        current.children.length === 0 &&
        !current.validWord &&
        current.parent
      ) {
        // 层层往上操作父节点
        const { parent } = current;
        const childIndex = parent.children.indexOf(current);
        const end = parent.children.length - 1;

        // 找到父节点,取到父节点的所有子节点
        // 把当前节点和子节点列表的最后一项位置互换
        [parent.children[childIndex], parent.children[end]] = [
          parent.children[end],
          parent.children[childIndex]
        ];

        // 此时待删除节点已处于列表末尾,pop 出去即可
        parent.children.pop();

        // 指针往上走
        current = parent;
      } else {
        // 条件不满足时,终止 while 循环
        stop = true;
      }
    }
  }

四,查找单词

我在准备实现 Trie 时看了很多入门教程,发现大部分 Trie 的所谓搜索仅仅是判断某个单词是否包含在 Trie 里面,比如 GeeksForGeeks 这篇,而我要实现的搜索是完整匹配。比如,我要搜索数据 ['facade', 'face', 'fabric', 'fetch', 'female', 'fear'] 里面 fa 开头的单词,输入 fa,应该返回 ['facade', 'face', 'fabric']。下面是实现:

 search(input) {
    // 记录 inputMirror 这一步其实非必要,我只是想让匹配到的单词和输入的字符能完全一致
    // 搜索匹配我用了大小写不敏感方案,如果不记录输入,会出现输入Fa,返回 fa 的情况
    const inputMirror = [];
    let current = this.root;

    for (let i = 0; i < input.length; i += 1) {
      const ch = input.charAt(i);
      let found = false;

      for (let j = current.children.length; j--;) {
        const child = current.children[j];

        if (child.char.toLowerCase() === ch.toLowerCase()) {
          found = true;
          current = child;
          inputMirror.push(child.char);
          break;
        }
      }

      if (!found) {
        return [];
      }
    }

    // 上面操作结束后,指针应当指向输入字符串的最后一个字符所在节点
    // 比如输入 fa,此时指针指向 a

    const match = []; // 用来储存匹配到的单词
    const tracker = []; // 追踪查找到的节点字符

    function traverse(node) {
      tracker.push(node.char);

      if (node.validWord) {
        // 如果到了匹配单词末尾,则从输入源取到字符串,再拼接起来
        const temp = inputMirror.slice(0, input.length - 1);
        // 此时 tracker 里面应该包含剩余的匹配字符
        // 比如 输入 fa,到 b 节点分支往下,会生成 bric
        temp.push(...tracker);
        // 放到匹配结果中
        match.push(temp.join(''));
      }

      // 对子节点层层递归同样的操作
      node.children.forEach(traverse);

      // 最后进递归栈的函数最先执行下面命令
      // 此时已经到达末尾节点,接着层层往上,一个个清空 tracker
      // 例如输入 fa,匹配到 fabric 之后,tracker 里面是 brick
      // 从 k 开始一层一层往上清空字符,这样匹配到下一个单词 face 时,tracker 是空的
      tracker.pop();
    }

    traverse(current);

    return match;
  }

局限

本文只能算极简的实现,还有很大优化空间。字符的储存可以用二分查找树 (Binary Search Tree),这样可以进一步提升字符查找的效率。理想情况下 Trie 搜索的时间复杂度是 M * log N,log N 是二分查找的时间复杂度,N 是 Trie 里面所有字符的数量,M 是输入字符串的长度。

另外,Trie 有个缺陷就是 Trie 本身构建的数据比较占内存,因为它要构建很多节点的引用关系。在对内存占用敏感的场景里,可以用 Ternary Search Tree 来替代 Trie

完整代码

class TrieNode {
  constructor(char) {
    this.char = char;
    this.validWord = false;
    this.parent = null;
    this.children = [];
  }
}

class Trie {
  // 加个辅助方法偷个懒,一次性加入整个列表
  static addAll(list) {
    const trie = new Trie();
    list.forEach(trie.add.bind(trie));
    return trie;
  }

  constructor() {
    this.root = new TrieNode('');
  }

  add(word) {
    let current = this.root;

    for (let i = 0; i < word.length; i += 1) {
      const ch = word[i];
      let found = false;

      for (let j = current.children.length; j--; ) {
        const child = current.children[j];
        if (child.char === ch) {
          found = true;
          current = child;
          break;
        }
      }

      if (!found) {
        current.children.push(new TrieNode(ch));

        const newNode = current.children[current.children.length - 1];

        newNode.parent = current;

        current = newNode;
      }
    }

    current.validWord = true;
  }

  contains(word) {
    let current = this.root;

    for (let i = 0; i < word.length; i += 1) {
      const ch = word[i];
      let found = false;

      for (let j = current.children.length; j--; ) {
        const child = current.children[j];

        if (child.char.toLowerCase() === ch.toLowerCase()) {
          found = true;
          current = child;
          break;
        }
      }

      if (!found) {
        return false;
      }
    }

    return current.validWord;
  }

  delete(word) {
    let current = this.root;

    for (let i = 0; i < word.length; i += 1) {
      const ch = word[i];
      let found = false;

      for (let j = current.children.length; j--; ) {
        const child = current.children[j];

        if (child.char === ch) {
          found = true;
          current = child;
          break;
        }
      }

      if (!found) {
        return;
      }
    }

    current.validWord = false;

    let stop = false;
    while (!stop) {
      if (
        current.children.length === 0 &&
        !current.validWord &&
        current.parent
      ) {
        const { parent } = current;
        const childIndex = parent.children.indexOf(current);
        const end = parent.children.length - 1;

        [parent.children[childIndex], parent.children[end]] = [
          parent.children[end],
          parent.children[childIndex],
        ];

        parent.children.pop();

        current = parent;
      } else {
        stop = true;
      }
    }
  }

  search(input) {
    const inputMirror = [];
    let current = this.root;

    for (let i = 0; i < input.length; i += 1) {
      const ch = input.charAt(i);
      let found = false;

      for (let j = current.children.length; j--; ) {
        const child = current.children[j];

        if (child.char.toLowerCase() === ch.toLowerCase()) {
          found = true;
          current = child;
          inputMirror.push(child.char);
          break;
        }
      }

      if (!found) {
        return [];
      }
    }

    const match = [];
    const tracker = [];

    function traverse(node) {
      tracker.push(node.char);

      if (node.validWord) {
        const temp = inputMirror.slice(0, input.length - 1);
        temp.push(...tracker);
        match.push(temp.join(''));
      }

      node.children.forEach(traverse);

      tracker.pop();
    }

    traverse(current);

    return match;
  }
}