scirs2_text/huggingface_compat/tokenizer.rs

use super::config::HfTokenizerConfig;
use crate::error::Result;
use crate::tokenize::Tokenizer;
use std::collections::HashMap;

/// Hugging Face-compatible wrapper that pairs an underlying `Tokenizer`
/// with its configuration and a token/id vocabulary.
pub struct HfTokenizer {
    /// Underlying tokenizer implementation.
    tokenizer: Box<dyn Tokenizer>,
    /// Tokenizer configuration (special tokens, etc.).
    config: HfTokenizerConfig,
    /// Token-to-id vocabulary.
    vocab: HashMap<String, usize>,
    /// Id-to-token mapping used for decoding.
    reverse_vocab: HashMap<usize, String>,
}

impl HfTokenizer {
    /// Creates a new `HfTokenizer`, seeding the vocabulary with the
    /// configured special tokens.
    pub fn new(tokenizer: Box<dyn Tokenizer>, config: HfTokenizerConfig) -> Self {
        let mut vocab = HashMap::new();
        let mut reverse_vocab = HashMap::new();

        for (token_id, token) in config.special_tokens.keys().enumerate() {
            vocab.insert(token.clone(), token_id);
            reverse_vocab.insert(token_id, token.clone());
        }

        Self {
            tokenizer,
            config,
            vocab,
            reverse_vocab,
        }
    }

    /// Tokenizes `text` and maps each token to its id, optionally wrapping
    /// the sequence in the configured BOS/EOS special tokens.
    pub fn encode(&self, text: &str, add_special_tokens: bool) -> Result<HfEncodedInput> {
        let mut tokens = self.tokenizer.tokenize(text)?;

        if add_special_tokens {
            if let Some(bos_token) = &self.config.bos_token {
                tokens.insert(0, bos_token.clone());
            }
            if let Some(eos_token) = &self.config.eos_token {
                tokens.push(eos_token.clone());
            }
        }

        // Tokens missing from the vocabulary fall back to the unk token's id,
        // or to 0 if the unk token itself is not in the vocabulary.
        let input_ids: Vec<usize> = tokens
            .iter()
            .map(|token| {
                self.vocab.get(token).copied().unwrap_or_else(|| {
                    self.vocab
                        .get(&self.config.unk_token)
                        .copied()
                        .unwrap_or(0)
                })
            })
            .collect();

        // Every token is attended to; a single sequence gets segment id 0.
        let attention_mask = vec![1; input_ids.len()];
        let token_type_ids = vec![0; input_ids.len()];

        Ok(HfEncodedInput {
            input_ids,
            attention_mask,
            token_type_ids: Some(token_type_ids),
            tokens,
        })
    }

    /// Encodes a batch of texts, returning one `HfEncodedInput` per text.
    pub fn encode_batch(
        &self,
        texts: &[&str],
        add_special_tokens: bool,
    ) -> Result<Vec<HfEncodedInput>> {
        texts
            .iter()
            .map(|text| self.encode(text, add_special_tokens))
            .collect()
    }

    /// Converts token ids back into a space-joined string, optionally
    /// dropping special tokens from the output.
    pub fn decode(&self, token_ids: &[usize], skip_special_tokens: bool) -> Result<String> {
        let tokens: Vec<String> = token_ids
            .iter()
            .filter_map(|&id| self.reverse_vocab.get(&id))
            .filter(|token| {
                if skip_special_tokens {
                    !self.config.special_tokens.contains_key(*token)
                } else {
                    true
                }
            })
            .cloned()
            .collect();

        Ok(tokens.join(" "))
    }

    /// Returns the number of entries in the vocabulary.
    pub fn vocab_size(&self) -> usize {
        self.vocab.len()
    }
}

/// Output of `HfTokenizer::encode`.
#[derive(Debug, Clone)]
pub struct HfEncodedInput {
    /// Token ids for the input text.
    pub input_ids: Vec<usize>,
    /// Attention mask (1 for every real token).
    pub attention_mask: Vec<i32>,
    /// Segment ids (all zeros for a single sequence), if produced.
    pub token_type_ids: Option<Vec<usize>>,
    /// The string tokens corresponding to `input_ids`.
    pub tokens: Vec<String>,
}
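
// A minimal usage sketch, not part of the original module: it shows the
// intended encode flow with a toy tokenizer. It assumes the `Tokenizer`
// trait can be implemented by supplying only
// `tokenize(&self, &str) -> Result<Vec<String>>` and that
// `HfTokenizerConfig` implements `Default`; both are assumptions about
// `crate::tokenize` and `super::config`, so the sketch may need adjusting
// to the real trait and config definitions.
#[cfg(test)]
mod usage_sketch {
    use super::*;

    // Hypothetical whitespace tokenizer used only for this sketch.
    struct WhitespaceTok;

    impl Tokenizer for WhitespaceTok {
        fn tokenize(&self, text: &str) -> Result<Vec<String>> {
            Ok(text.split_whitespace().map(str::to_string).collect())
        }
    }

    #[test]
    fn encode_outputs_are_aligned() {
        // Assumed: the default config provides usable unk/bos/eos settings.
        let config = HfTokenizerConfig::default();
        let tokenizer = HfTokenizer::new(Box::new(WhitespaceTok), config);

        let encoded = tokenizer.encode("hello world", true).unwrap();

        // One id, one mask entry, and one token-type id per token.
        assert_eq!(encoded.input_ids.len(), encoded.tokens.len());
        assert_eq!(encoded.input_ids.len(), encoded.attention_mask.len());
        assert_eq!(
            encoded.token_type_ids.as_ref().map(|ids| ids.len()),
            Some(encoded.input_ids.len())
        );
    }
}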