1use crate::model::*;
2use crate::openai_public::find_encoding_constructor;
3use crate::CoreBPE;
4use fancy_regex::Regex;
5use rayon::prelude::*;
6use rustc_hash::FxHashMap;
7use std::cmp::max;
8use std::collections::{HashMap, HashSet};
9use std::fmt::{Debug, Display, Formatter};
10use std::hash::Hash;
11
/// Crate-wide result alias; every fallible API in this module fails with `EncodeError`.
pub type Result<T> = std::result::Result<T, EncodeError>;
13
14pub fn get_encoding(encoding_name: &str) -> Result<Encoding> {
17 match find_encoding_constructor(encoding_name) {
18 Some(func) => Encoding::new(func()),
19 None => Err(EncodeError::EncodingNameError(encoding_name.to_string())),
20 }
21}
22
23pub fn encoding_for_model(model_name: &str) -> Result<Encoding> {
25 let encoding_opt = MODEL_TO_ENCODING
26 .get(model_name)
27 .map(|&encoding| get_encoding(encoding));
28 if let Some(encoding) = encoding_opt {
29 return encoding;
30 }
31
32 for (&model_prefix, &model_encoding_name) in MODEL_PREFIX_TO_ENCODING.iter() {
36 if model_name.starts_with(model_prefix) {
37 return get_encoding(model_encoding_name);
38 }
39 }
40
41 Err(EncodeError::ModelNameError(model_name.to_string()))
42}
43
/// Raw construction parameters for an [`Encoding`].
pub struct EncodingParam {
    // Canonical encoding name (e.g. "cl100k_base").
    name: String,
    // Regex pattern used to split text before BPE is applied.
    pat_str: String,
    // BPE vocabulary: mergeable byte sequence -> rank (token id).
    mergeable_ranks: HashMap<Vec<u8>, usize>,
    // Special token string -> token id.
    special_tokens: HashMap<String, usize>,
    // Expected total vocabulary size, if known; used for validation.
    explicit_n_vocab: Option<usize>,
}

impl EncodingParam {
    /// Bundles everything needed to construct an [`Encoding`].
    pub fn new(
        name: String,
        pat_str: String,
        mergeable_ranks: HashMap<Vec<u8>, usize>,
        special_tokens: HashMap<String, usize>,
        explicit_n_vocab: Option<usize>,
    ) -> Self {
        Self {
            name,
            pat_str,
            mergeable_ranks,
            special_tokens,
            explicit_n_vocab,
        }
    }
}
69
/// A tokenizer encoding: a named BPE vocabulary plus its special tokens.
pub struct Encoding {
    // Canonical encoding name (e.g. "cl100k_base").
    name: String,
    // Regex pattern the vocabulary was built with; retained for reference only.
    _pat_str: String,
    // Special token string -> token id.
    special_tokens: HashMap<String, usize>,

    // Largest token id across both the mergeable ranks and the special tokens.
    max_token_value: usize,
    // Underlying BPE engine performing the actual encode/decode work.
    core_bpe: CoreBPE,
}
78
79impl Debug for Encoding {
80 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
81 write!(f, "<Encoding '{:?}'>", self.name)
82 }
83}
84
85impl Display for Encoding {
87 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
88 write!(f, "<Encoding '{:?}'>", self.name)
89 }
90}
91
92impl Encoding {
94 fn new(param: EncodingParam) -> Result<Self> {
111 let max_token_value = max(
112 param
113 .mergeable_ranks
114 .values()
115 .max()
116 .copied()
117 .unwrap_or_default(),
118 param
119 .special_tokens
120 .values()
121 .max()
122 .copied()
123 .unwrap_or_default(),
124 );
125 if let Some(n_vocab) = param.explicit_n_vocab {
126 assert_eq!(
127 param.mergeable_ranks.len() + param.special_tokens.len(),
128 n_vocab
129 );
130 assert_eq!(max_token_value, n_vocab - 1);
131 }
132
133 let core_bpe = CoreBPE::new(
134 convert_to_fx_hashmap(¶m.mergeable_ranks),
135 convert_to_fx_hashmap(¶m.special_tokens),
136 param.pat_str.as_str(),
137 )?;
138
139 Ok(Encoding {
140 name: param.name,
141 _pat_str: param.pat_str,
142 special_tokens: param.special_tokens,
143 max_token_value,
144 core_bpe,
145 })
146 }
147}
148
149impl Encoding {
151 pub fn encode_ordinary(&self, text: &str) -> Vec<usize> {
155 self.core_bpe._encode_ordinary_native(text)
156 }
157
158 pub fn encode_ordinary_batch(&self, texts: Vec<&str>) -> Vec<Vec<usize>> {
162 texts
163 .par_iter()
164 .map(|&txt| self.encode_ordinary(txt))
165 .collect()
166 }
167
168 pub fn encode(
180 &self,
181 text: &str,
182 allowed_special: AllowedSpecial,
183 disallowed_special: DisallowedSpecial,
184 ) -> Result<Vec<usize>> {
185 let allowed_special_set = match allowed_special {
186 AllowedSpecial::All => self.special_tokens_set(),
187 AllowedSpecial::Allowed(allowed) => allowed,
188 };
189 let disallowed_special_set = match disallowed_special {
190 DisallowedSpecial::All => self
191 .special_tokens_set()
192 .difference(&allowed_special_set)
193 .copied()
194 .collect(),
195 DisallowedSpecial::Disallowed(disallowed) => disallowed,
196 };
197
198 if !disallowed_special_set.is_empty() {
199 let re = special_token_regex(disallowed_special_set)?;
200 if let Ok(Some(cap)) = re.captures(text) {
201 return Err(EncodeError::SpecialTokenError(String::from(
202 cap.get(0).unwrap().as_str(),
203 )));
204 }
205 }
206
207 Ok(self.core_bpe._encode_native(text, &allowed_special_set).0)
208 }
209
210 pub fn encode_batch(
214 &self,
215 texts: Vec<&str>,
216 allowed_special: AllowedSpecial,
217 disallowed_special: DisallowedSpecial,
218 ) -> Result<Vec<Vec<usize>>> {
219 let data: Vec<Result<Vec<usize>>> = texts
220 .par_iter()
221 .map(|&txt| self.encode(txt, allowed_special.clone(), disallowed_special.clone()))
222 .collect();
223
224 let mut res = Vec::new();
225 for item in data {
226 res.push(item?);
227 }
228 Ok(res)
229 }
230
231 pub fn encode_with_unstable(
236 &self,
237 text: &str,
238 allowed_special: AllowedSpecial,
239 disallowed_special: DisallowedSpecial,
240 ) -> Result<(Vec<usize>, Vec<Vec<usize>>)> {
241 let allowed_special_set = match allowed_special {
242 AllowedSpecial::All => self.special_tokens_set(),
243 AllowedSpecial::Allowed(allowed) => allowed,
244 };
245 let disallowed_special_set = match disallowed_special {
246 DisallowedSpecial::All => self
247 .special_tokens_set()
248 .difference(&allowed_special_set)
249 .copied()
250 .collect(),
251 DisallowedSpecial::Disallowed(disallowed) => disallowed,
252 };
253
254 if !disallowed_special_set.is_empty() {
255 let re = special_token_regex(disallowed_special_set)?;
256 if let Ok(Some(cap)) = re.captures(text) {
257 return Err(EncodeError::SpecialTokenError(String::from(
258 cap.get(0).unwrap().as_str(),
259 )));
260 }
261 }
262
263 let (tokens, completions) = self
264 .core_bpe
265 ._encode_unstable_native(text, &allowed_special_set);
266 let completions = completions.into_iter().collect();
267 Ok((tokens, completions))
268 }
269
270 pub fn encode_single_token(&self, piece: &[u8]) -> Result<usize> {
274 if let Some(token) = self.core_bpe.encoder.get(piece).copied() {
275 return Ok(token);
276 }
277 if let Ok(piece_str) = std::str::from_utf8(piece) {
278 if let Some(token) = self.core_bpe.special_tokens_encoder.get(piece_str).copied() {
279 return Ok(token);
280 }
281 }
282 Err(EncodeError::TokenEncodeError(piece.to_owned()))
283 }
284}
285
286impl Encoding {
288 pub fn decode_bytes(&self, tokens: &[usize]) -> Vec<u8> {
290 self.core_bpe._decode_native(tokens)
291 }
292
293 pub fn decode_bytes_batch(self, batch: &[&[usize]]) -> Vec<Vec<u8>> {
295 batch
296 .par_iter()
297 .map(|tokens| self.decode_bytes(tokens))
298 .collect()
299 }
300
301 pub fn decode(&self, tokens: &[usize], mode: DecodeMode) -> Result<String> {
309 let bytes = self.decode_bytes(tokens);
310 match mode {
311 DecodeMode::Strict => String::from_utf8(bytes).map_err(EncodeError::ConvertStringError),
312 DecodeMode::Replace => Ok(String::from_utf8_lossy(&bytes).to_string()),
313 }
314 }
315
316 pub fn decode_batch(&self, batch: &[&[usize]], mode: DecodeMode) -> Vec<Result<String>> {
318 batch
319 .par_iter()
320 .map(|tokens| self.decode(tokens, mode.clone()))
321 .collect()
322 }
323
324 pub fn decode_single_token_bytes(&self, token: usize) -> Result<Vec<u8>> {
327 if let Some(bytes) = self.core_bpe.decoder.get(&token) {
328 return Ok(bytes.to_vec());
329 }
330 if let Some(bytes) = self.core_bpe.special_tokens_decoder.get(&token) {
331 return Ok(bytes.to_vec());
332 }
333 Err(EncodeError::TokenNotFoundError(token))
334 }
335
336 pub fn decode_tokens_bytes(&self, tokens: &Vec<usize>) -> Result<Vec<Vec<u8>>> {
339 let data: Vec<Result<Vec<u8>>> = tokens
340 .par_iter()
341 .map(|&token| self.decode_single_token_bytes(token))
342 .collect();
343
344 let mut res = Vec::new();
345 for item in data {
346 res.push(item?);
347 }
348 Ok(res)
349 }
350
351 pub fn decode_with_offsets(self, tokens: &Vec<usize>) -> Result<(String, Vec<usize>)> {
360 let token_bytes = self.decode_tokens_bytes(tokens)?;
361 let mut text_len = 0;
362 let mut offsets = vec![];
363
364 for token in token_bytes {
365 let offset = if token[0] >= 0x80 && token[0] < 0xC0 {
366 max(0, text_len - 1)
367 } else {
368 max(0, text_len)
369 };
370 offsets.push(offset);
371 text_len += token
372 .iter()
373 .map(|&c| if c < 0x80 || c >= 0xC0 { 1 } else { 0 })
374 .sum::<usize>();
375 }
376
377 let text = self.decode(tokens, DecodeMode::Strict)?;
378
379 Ok((text, offsets))
380 }
381}
382
383impl Encoding {
385 pub fn name(&self) -> &str {
387 self.name.as_str()
388 }
389
390 pub fn token_byte_values(&self) -> Vec<Vec<u8>> {
392 self.core_bpe
393 .sorted_token_bytes
394 .iter()
395 .map(|x| x.to_vec())
396 .collect()
397 }
398
399 pub fn eot_token(&self) -> Option<usize> {
400 self.special_tokens.get("<|endoftext|>").copied()
401 }
402
403 pub fn n_vocab(&self) -> usize {
405 self.max_token_value + 1
406 }
407
408 pub fn special_tokens_set(&self) -> HashSet<&str> {
410 HashSet::from_iter(self.special_tokens.keys().map(|k| k.as_str()))
411 }
412}
413
414fn special_token_regex(tokens: HashSet<&str>) -> Result<Regex> {
416 let inner: Vec<_> = tokens.iter().map(|&t| regex::escape(t)).collect();
417 let re = Regex::new(format!("({})", inner.join("|")).as_str())?;
418 Ok(re)
419}
420
421fn convert_to_fx_hashmap<K, V>(origin: &HashMap<K, V>) -> FxHashMap<K, V>
422where
423 K: Hash + Eq + PartialEq + Clone,
424 V: Clone,
425{
426 let mut res: FxHashMap<K, V> = FxHashMap::default();
427 origin
428 .iter()
429 .for_each(|(k, v)| _ = res.insert(k.clone(), v.clone()));
430 res
431}