Skip to main content

xore_ai/
tokenizer.rs

1//! 分词器封装
2//!
3//! 提供统一的分词接口
4
5use anyhow::Result;
6use std::path::Path;
7use tokenizers::{Encoding, Tokenizer as HFTokenizer};
8
9/// 分词器
10pub struct Tokenizer {
11    tokenizer: HFTokenizer,
12}
13
14impl Tokenizer {
15    /// 从文件加载分词器
16    ///
17    /// # 参数
18    /// - `path`: tokenizer.json 文件路径
19    ///
20    /// # 返回
21    /// 加载好的分词器
22    pub fn load(path: &Path) -> Result<Self> {
23        let tokenizer =
24            HFTokenizer::from_file(path).map_err(|e| anyhow::anyhow!("加载分词器失败: {}", e))?;
25        Ok(Self { tokenizer })
26    }
27
28    /// 分词并编码文本
29    ///
30    /// # 参数
31    /// - `text`: 输入文本
32    /// - `add_special_tokens`: 是否添加特殊 token(如 [CLS], [SEP])
33    ///
34    /// # 返回
35    /// 编码结果
36    pub fn encode(&self, text: &str, add_special_tokens: bool) -> Result<Encoding> {
37        self.tokenizer
38            .encode(text, add_special_tokens)
39            .map_err(|e| anyhow::anyhow!("编码失败: {}", e))
40    }
41
42    /// 批量编码
43    ///
44    /// # 参数
45    /// - `texts`: 文本列表
46    /// - `add_special_tokens`: 是否添加特殊 token
47    ///
48    /// # 返回
49    /// 编码结果列表
50    pub fn encode_batch(
51        &self,
52        texts: Vec<&str>,
53        add_special_tokens: bool,
54    ) -> Result<Vec<Encoding>> {
55        self.tokenizer
56            .encode_batch(texts, add_special_tokens)
57            .map_err(|e| anyhow::anyhow!("批量编码失败: {}", e))
58    }
59
60    /// 解码 token IDs 为文本
61    ///
62    /// # 参数
63    /// - `ids`: token ID 列表
64    /// - `skip_special_tokens`: 是否跳过特殊 token
65    ///
66    /// # 返回
67    /// 解码后的文本
68    pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result<String> {
69        self.tokenizer
70            .decode(ids, skip_special_tokens)
71            .map_err(|e| anyhow::anyhow!("解码失败: {}", e))
72    }
73
74    /// 获取词汇表大小
75    pub fn vocab_size(&self) -> usize {
76        self.tokenizer.get_vocab_size(false)
77    }
78}
79
80#[cfg(test)]
81mod tests {
82    use super::*;
83
84    // 注意:这些测试需要实际的 tokenizer.json 文件才能运行
85    // 在 CI 环境中可能需要跳过或使用 mock
86
87    #[test]
88    #[ignore] // 需要实际的 tokenizer 文件
89    fn test_encode() {
90        let tokenizer = Tokenizer::load(Path::new("assets/tokenizer.json")).unwrap();
91        let encoding = tokenizer.encode("Hello, world!", true).unwrap();
92        assert!(!encoding.get_ids().is_empty());
93    }
94
95    #[test]
96    #[ignore]
97    fn test_decode() {
98        let tokenizer = Tokenizer::load(Path::new("assets/tokenizer.json")).unwrap();
99        let encoding = tokenizer.encode("Hello, world!", true).unwrap();
100        let decoded = tokenizer.decode(encoding.get_ids(), true).unwrap();
101        assert!(decoded.contains("Hello"));
102    }
103}