toktrie_hf_tokenizers/
lib.rs

use anyhow::{anyhow, bail, Result};
use std::{
    collections::{HashMap, HashSet},
    path::Path,
    sync::Arc,
};
use tokenizers::{normalizers::Sequence, NormalizerWrapper, Tokenizer};
use toktrie::{TokEnv, TokRxInfo, TokTrie, TokenId, TokenizerEnv};

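/// Wraps a Hugging Face `Tokenizer`, exposing every token as a concrete byte
/// sequence so that a `TokTrie` can be built over the vocabulary.
///
/// A minimal usage sketch (the `tokenizer.json` path is illustrative):
///
/// ```ignore
/// let tok_env = ByteTokenizer::from_file("tokenizer.json")?.into_tok_env(None)?;
/// let tokens = tok_env.tokenize_bytes(b"hello world");
/// ```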
pub struct ByteTokenizer {
    pub hf_model: String,
    pub hf_tokenizer: Tokenizer,
    info: TokRxInfo,
    token_bytes: Vec<Vec<u8>>,
}

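// GPT-2-style byte-level BPE keeps printable ASCII and most of Latin-1 as-is
// and remaps the remaining byte values to other unicode characters; this
// predicate identifies the self-mapped range.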
fn is_self_mapped(c: char) -> bool {
    matches!(c, '!'..='~' | '\u{00A1}'..='\u{00AC}' | '\u{00AE}'..='\u{00FF}')
}

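// Build the inverse of the byte-level character encoding: for each of the 256
// byte values, record which character stands for it in byte-level token names.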
fn build_char_map() -> HashMap<char, u8> {
    let mut res = HashMap::default();
    let mut k = 0x100u32;
    for byte in 0..=255u8 {
        let c = byte as char;
        if is_self_mapped(c) {
            res.insert(c, byte);
        } else {
            res.insert(char::from_u32(k).unwrap(), byte);
            k += 1;
        }
    }
    res
}

impl ByteTokenizer {
    pub fn from_file(name: impl AsRef<Path>) -> Result<ByteTokenizer> {
        let name_str = name.as_ref().display().to_string();
        let tok = Tokenizer::from_file(name)
            .map_err(|e| anyhow!("error loading tokenizer: {}: {}", name_str, e))?;
        ByteTokenizer::from_tokenizer(tok)
    }

    pub fn from_json_bytes(bytes: &[u8]) -> Result<ByteTokenizer> {
        let tok =
            Tokenizer::from_bytes(bytes).map_err(|e| anyhow!("error loading tokenizer: {}", e))?;
        ByteTokenizer::from_tokenizer(tok)
    }

    pub fn from_tokenizer(mut hft: Tokenizer) -> Result<ByteTokenizer> {
        let mut is_byte_level = false;
        let mut is_byte_fallback = false;
        let mut space_ch = ' ';

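        // Strip any Prepend normalizer (e.g. one that inserts a leading
        // space), keeping the rest of the normalizer sequence intact.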
        if let Some(n) = hft.get_normalizer() {
            let n = match n {
                NormalizerWrapper::Sequence(x) => NormalizerWrapper::Sequence(Sequence::new(
                    x.get_normalizers()
                        .iter()
                        .filter_map(|n| match n {
                            NormalizerWrapper::Prepend(_) => None,
                            _ => Some(n.clone()),
                        })
                        .collect(),
                )),
                _ => n.clone(),
            };
            hft.with_normalizer(Some(n));
        }

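        // The decoder structure isn't directly accessible through the
        // tokenizers API, so inspect its serde serialization: ByteLevel marks
        // a GPT-2-style tokenizer, while a Sequence containing ByteFallback
        // marks a SentencePiece-style one (where we also pick up the space
        // placeholder character from the Replace decoder).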
        if let Some(d) = hft.get_decoder() {
            let v = serde_json::to_value(d).unwrap();
            if v["type"].as_str() == Some("ByteLevel") {
                is_byte_level = true;
            } else if v["type"].as_str() == Some("Sequence") {
                if let Some(decoders) = v["decoders"].as_array() {
                    for decoder in decoders {
                        if decoder["type"].as_str() == Some("ByteFallback") {
                            is_byte_fallback = true;
                        } else if decoder["type"].as_str() == Some("Replace")
                            && decoder["content"].as_str() == Some(" ")
                        {
                            if let Some(s) = decoder["pattern"]["String"].as_str() {
                                let s: Vec<char> = s.chars().collect();
                                if s.len() == 1 {
                                    space_ch = s[0];
                                }
                            }
                        }
                    }
                }
            }
        }

        if !is_byte_fallback && !is_byte_level {
            bail!("can't determine decoder type: {:?}", hft.get_decoder());
        }

        let vocab_size = hft.get_vocab_size(true) as u32;
        let added = hft.get_added_tokens_decoder();

        let mut res = ByteTokenizer {
            hf_model: "foobar".to_string(), // placeholder name
            info: TokRxInfo::new(vocab_size, 0),
            token_bytes: (0..vocab_size).map(|_| Vec::new()).collect(),
            hf_tokenizer: hft,
        };

        let mut specials = HashSet::new();

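        // Scan added tokens: record well-known special tokens (EOS,
        // end-of-turn, UNK, PAD) in `info`, and copy the raw bytes of
        // non-special added tokens directly.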
115
116 for (id, info) in added.iter() {
117 if info.special {
118 match info.content.as_str() {
119 "</s>"
120 | "<|endoftext|>"
121 | "<|end_of_text|>"
122 | "<|end▁of▁sentence|>" | "<eos>" => res.info.tok_eos = *id,
124
125 "<|end|>" | "<|eot_id|>" | "<|im_end|>" => res.info.tok_end_of_turn = Some(*id),
126 "<unk>" | "<|unk|>" => res.info.tok_unk = Some(*id),
127 "<pad>" | "<|pad|>" => res.info.tok_pad = Some(*id),
128 _ => {}
129 }
130 specials.insert(*id);
131 } else {
132 res.token_bytes[*id as usize] = info.content.clone().into_bytes();
133 }
134 }
135
        let char_map = build_char_map();

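        // Resolve every vocabulary entry to its byte representation: special
        // tokens get a marker prefix, <0xNN> fallback tokens become single
        // bytes, and byte-level names are decoded through `char_map`.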
        for tok_id in 0..vocab_size {
            if let Some(tok_name) = res.hf_tokenizer.id_to_token(tok_id) {
                let bytes = if specials.contains(&tok_id) {
                    let mut bytes = tok_name.as_bytes().to_vec();
                    bytes.insert(0, TokTrie::SPECIAL_TOKEN_MARKER);
                    bytes
                } else if is_byte_fallback {
                    if tok_name.len() == 6 && tok_name.starts_with("<0x") && tok_name.ends_with(">")
                    {
                        // parse the NN hex digits of a <0xNN> byte token
                        let hex_str = &tok_name[3..5];
                        let byte = u8::from_str_radix(hex_str, 16).unwrap();
                        vec![byte]
                    } else {
                        assert!(!tok_name.starts_with("<0x"));
                        let tok_name = tok_name.replace(space_ch, " ");
                        tok_name.as_bytes().to_vec()
                    }
                } else if is_byte_level {
                    let bytes: Result<Vec<u8>> = tok_name
                        .chars()
                        .map(|c| {
                            char_map
                                .get(&c)
                                .copied()
                                .ok_or_else(|| anyhow!("missing char: {}", c))
                        })
                        .collect();
                    match bytes {
                        Ok(b) => b,
                        Err(e) => {
                            log::warn!("error: {} for {:?}", e, tok_name);
                            continue;
                        }
                    }
                } else {
                    // one of the two flags is guaranteed by the bail!() above
                    unreachable!()
                };
                res.token_bytes[tok_id as usize] = bytes;
            } else {
                log::warn!("missing token: {}", tok_id);
            }
        }

        Ok(res)
    }

    pub fn tokrx_info(&self) -> TokRxInfo {
        self.info
    }
    pub fn token_bytes(&self) -> Vec<Vec<u8>> {
        self.token_bytes.clone()
    }

    pub fn set_eos_token(&mut self, tok_id: u32) {
        self.info.tok_eos = tok_id;
    }

    pub fn into_tok_env(self, n_vocab: Option<usize>) -> Result<TokEnv> {
        let b = ByteTokenizerEnv::new(self, n_vocab)?;
        Ok(b.to_env())
    }
}

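/// A `TokenizerEnv` backed by a `ByteTokenizer` and the `TokTrie` built from
/// its vocabulary.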
pub struct ByteTokenizerEnv {
    pub tokenizer: ByteTokenizer,
    pub tok_trie: TokTrie,
}

impl ByteTokenizerEnv {
    pub fn new(tokenizer: ByteTokenizer, n_vocab: Option<usize>) -> Result<ByteTokenizerEnv> {
        let mut info = tokenizer.tokrx_info();
        let mut token_bytes = tokenizer.token_bytes();
        if let Some(n_vocab) = n_vocab {
            if n_vocab < token_bytes.len() {
                bail!("vocab size too small; {} vs {}", n_vocab, token_bytes.len());
            }
            // pad with empty entries up to the requested vocabulary size
            while n_vocab > token_bytes.len() {
                token_bytes.push(Vec::new());
            }
            info.vocab_size = n_vocab as u32;
        }
        let tok_trie = TokTrie::from(&info, &token_bytes);
        Ok(ByteTokenizerEnv {
            tokenizer,
            tok_trie,
        })
    }

    pub fn to_env(self) -> TokEnv {
        Arc::new(self)
    }
}

impl TokenizerEnv for ByteTokenizerEnv {
    fn tok_trie(&self) -> &TokTrie {
        &self.tok_trie
    }

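    // Tokenize valid UTF-8 spans with the HF tokenizer; byte sequences it
    // cannot handle fall back to greedy matching against the trie.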
    fn tokenize_bytes(&self, s: &[u8]) -> Vec<TokenId> {
        self.tok_trie.tokenize_with_greedy_fallback(s, |s| {
            self.tokenizer
                .hf_tokenizer
                .encode(s, false)
                .expect("tokenizer error")
                .get_ids()
                .to_vec()
        })
    }
}