1use anyhow::Result;
6use std::path::Path;
7use tokenizers::{Encoding, Tokenizer as HFTokenizer};
8
9pub struct Tokenizer {
11 tokenizer: HFTokenizer,
12}
13
14impl Tokenizer {
15 pub fn load(path: &Path) -> Result<Self> {
23 let tokenizer =
24 HFTokenizer::from_file(path).map_err(|e| anyhow::anyhow!("加载分词器失败: {}", e))?;
25 Ok(Self { tokenizer })
26 }
27
28 pub fn encode(&self, text: &str, add_special_tokens: bool) -> Result<Encoding> {
37 self.tokenizer
38 .encode(text, add_special_tokens)
39 .map_err(|e| anyhow::anyhow!("编码失败: {}", e))
40 }
41
42 pub fn encode_batch(
51 &self,
52 texts: Vec<&str>,
53 add_special_tokens: bool,
54 ) -> Result<Vec<Encoding>> {
55 self.tokenizer
56 .encode_batch(texts, add_special_tokens)
57 .map_err(|e| anyhow::anyhow!("批量编码失败: {}", e))
58 }
59
60 pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result<String> {
69 self.tokenizer
70 .decode(ids, skip_special_tokens)
71 .map_err(|e| anyhow::anyhow!("解码失败: {}", e))
72 }
73
74 pub fn vocab_size(&self) -> usize {
76 self.tokenizer.get_vocab_size(false)
77 }
78}
79
#[cfg(test)]
mod tests {
    use super::*;

    /// Tokenizer fixture path. It is resolved against the current working directory,
    /// which is the package root under `cargo test`.
    const TOKENIZER_FIXTURE: &str = "assets/tokenizer.json";

    /// Loads the fixture tokenizer, panicking with a descriptive message if it cannot be loaded.
    fn load_fixture() -> Tokenizer {
        Tokenizer::load(Path::new(TOKENIZER_FIXTURE))
            .expect("failed to load the assets/tokenizer.json test fixture")
    }

    #[test]
    #[ignore = "requires the assets/tokenizer.json fixture"]
    fn test_encode() {
        let tokenizer = load_fixture();
        let encoding = tokenizer
            .encode("Hello, world!", true)
            .expect("encoding a short ASCII string should succeed");
        assert!(!encoding.get_ids().is_empty());
    }

    #[test]
    #[ignore = "requires the assets/tokenizer.json fixture"]
    fn test_decode() {
        let tokenizer = load_fixture();
        let encoding = tokenizer
            .encode("Hello, world!", true)
            .expect("encoding a short ASCII string should succeed");
        let decoded = tokenizer
            .decode(encoding.get_ids(), true)
            .expect("decoding ids produced by encode should succeed");
        assert!(decoded.contains("Hello"));
    }
}