use std::{
    collections::HashMap,
    fs::{read_to_string, File},
    io::{prelude::*, BufReader},
    ops::{Deref, DerefMut},
    path::{Path, PathBuf},
};

use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};

use crate::utils::iter::ResultShunt;
use crate::utils::parallelism::*;
use crate::utils::progress::{ProgressBar, ProgressStyle};

mod added_vocabulary;
mod encoding;
pub mod normalizer;
pub mod pattern;
pub mod pre_tokenizer;
mod serialization;

pub use crate::decoders::DecoderWrapper;
pub use crate::models::ModelWrapper;
pub use crate::normalizers::NormalizerWrapper;
pub use crate::pre_tokenizers::PreTokenizerWrapper;
pub use crate::processors::PostProcessorWrapper;
pub use crate::utils::iter::LinesWithEnding;
pub use crate::utils::padding::{pad_encodings, PaddingDirection, PaddingParams, PaddingStrategy};
pub use crate::utils::truncation::{
    truncate_encodings, TruncationDirection, TruncationParams, TruncationStrategy,
};
pub use added_vocabulary::*;
pub use encoding::*;
pub use normalizer::{NormalizedString, OffsetReferential, SplitDelimiterBehavior};
pub use pre_tokenizer::*;

pub type Error = Box<dyn std::error::Error + Send + Sync>;
pub type Result<T> = std::result::Result<T, Error>;
pub type Offsets = (usize, usize);

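/// A `Normalizer` takes care of pre-processing strings by modifying a
/// `NormalizedString` in place (e.g. lowercasing or unicode normalization)
/// before pre-tokenization happens.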
pub trait Normalizer {
    fn normalize(&self, normalized: &mut NormalizedString) -> Result<()>;
}

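/// A `PreTokenizer` splits a `PreTokenizedString` in place into smaller, word-like
/// parts; the model then tokenizes each of these parts independently.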
pub trait PreTokenizer {
    fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()>;
}

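/// The core tokenization model (vocabulary plus tokenization algorithm): it turns a
/// pre-tokenized sub-part into `Token`s, maps tokens to ids and back, exposes its
/// vocabulary, can be saved to disk, and provides the `Trainer` able to train it.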
pub trait Model {
    type Trainer: Trainer + Sync;
    fn tokenize(&self, sequence: &str) -> Result<Vec<Token>>;
    fn token_to_id(&self, token: &str) -> Option<u32>;
    fn id_to_token(&self, id: u32) -> Option<String>;
    fn get_vocab(&self) -> HashMap<String, u32>;
    fn get_vocab_size(&self) -> usize;
    fn save(&self, folder: &Path, prefix: Option<&str>) -> Result<Vec<PathBuf>>;
    fn get_trainer(&self) -> <Self as Model>::Trainer;
}

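/// A `PostProcessor` post-processes the encoding(s) produced by the model, typically to
/// add special tokens; `added_tokens` reports how many tokens it will add, so that
/// truncation can account for them beforehand.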
pub trait PostProcessor {
    fn added_tokens(&self, is_pair: bool) -> usize;
    fn process(
        &self,
        encoding: Encoding,
        pair_encoding: Option<Encoding>,
        add_special_tokens: bool,
    ) -> Result<Encoding> {
        let mut encodings = if let Some(pair_encoding) = pair_encoding {
            vec![encoding, pair_encoding]
        } else {
            vec![encoding]
        };
        encodings.iter_mut().enumerate().for_each(|(i, encoding)| {
            encoding.set_sequence_id(i);
            encoding
                .get_overflowing_mut()
                .iter_mut()
                .for_each(|encoding| encoding.set_sequence_id(i));
            encoding.set_type_ids(vec![i as u32; encoding.len()]);
        });

        let encodings = self.process_encodings(encodings, add_special_tokens)?;
        Ok(Encoding::merge(encodings, false))
    }

    fn process_encodings(
        &self,
        encodings: Vec<Encoding>,
        add_special_tokens: bool,
    ) -> Result<Vec<Encoding>>;
}
impl dyn PostProcessor {
    pub fn default_process(
        encodings: Vec<Encoding>,
        _add_special_tokens: bool,
    ) -> Result<Vec<Encoding>> {
        match encodings.len() {
            1 => Ok(encodings),
            _ => {
                let mut final_encoding = Encoding::default();
                for (i, mut encoding) in encodings.into_iter().enumerate() {
                    encoding.set_sequence_id(i);
                    final_encoding.merge_with(encoding, false);
                }
                Ok(vec![final_encoding])
            }
        }
    }
}

#[derive(thiserror::Error, Debug)]
pub enum ProcessorError {
    #[error("encodings vector length must be either 1 or 2")]
    InvalidEncodingsVecLength,
}

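/// A `Decoder` changes the raw tokens back into a more readable string; the default
/// `decode` simply concatenates the chunks produced by `decode_chain`.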
pub trait Decoder {
    fn decode(&self, tokens: Vec<String>) -> Result<String> {
        let results = self.decode_chain(tokens)?;
        Ok(results.join(""))
    }
    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>>;
}

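/// A `Trainer` trains a `Model` from data: `feed` receives an iterator of sequences
/// together with a pre-processing callback, and `train` updates the model and returns
/// the special tokens that should be added to the tokenizer.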
pub trait Trainer {
    type Model: Model + Sized;
    fn should_show_progress(&self) -> bool;
    fn train(&self, model: &mut Self::Model) -> Result<Vec<AddedToken>>;
    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
    where
        I: Iterator<Item = S> + Send,
        S: AsRef<str> + Send,
        F: Fn(&str) -> Result<Vec<String>> + Sync;
}

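/// A single token produced by a `Model`: its id in the vocabulary, its string value,
/// and its `(start, end)` offsets in the original sequence.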
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Token {
    pub id: u32,
    pub value: String,
    pub offsets: (usize, usize),
}
impl Token {
    pub fn new(id: u32, value: String, offsets: (usize, usize)) -> Self {
        Self { id, value, offsets }
    }
}

use std::borrow::Cow;
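/// A single input sequence for the encoding methods: either a raw string, or an already
/// pre-tokenized sequence of string-like parts (borrowed, owned, or `Cow`).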
#[derive(Debug, Clone)]
pub enum InputSequence<'s> {
    Raw(Cow<'s, str>),
    PreTokenized(Cow<'s, [&'s str]>),
    PreTokenizedOwned(Cow<'s, [String]>),
    PreTokenizedCow(Cow<'s, [Cow<'s, str>]>),
}

impl<'s> From<Cow<'s, str>> for InputSequence<'s> {
    fn from(input: Cow<'s, str>) -> Self {
        Self::Raw(input)
    }
}

impl<'s> From<&'s str> for InputSequence<'s> {
    fn from(input: &'s str) -> Self {
        Self::Raw(Cow::Borrowed(input))
    }
}

impl From<String> for InputSequence<'_> {
    fn from(input: String) -> Self {
        Self::Raw(Cow::Owned(input))
    }
}

impl<'s> From<&'s [&'s str]> for InputSequence<'s> {
    fn from(input: &'s [&'s str]) -> Self {
        Self::PreTokenized(Cow::Borrowed(input))
    }
}

impl<'s> From<Vec<&'s str>> for InputSequence<'s> {
    fn from(input: Vec<&'s str>) -> Self {
        Self::PreTokenized(Cow::Owned(input))
    }
}

impl<'s> From<&'s [String]> for InputSequence<'s> {
    fn from(input: &'s [String]) -> Self {
        Self::PreTokenizedOwned(Cow::Borrowed(input))
    }
}

impl<'s> From<Vec<String>> for InputSequence<'s> {
    fn from(input: Vec<String>) -> Self {
        Self::PreTokenizedOwned(Cow::Owned(input))
    }
}

impl<'s> From<Vec<Cow<'s, str>>> for InputSequence<'s> {
    fn from(input: Vec<Cow<'s, str>>) -> Self {
        Self::PreTokenizedCow(Cow::Owned(input))
    }
}

impl<'s> From<&'s [Cow<'s, str>]> for InputSequence<'s> {
    fn from(input: &'s [Cow<'s, str>]) -> Self {
        Self::PreTokenizedCow(Cow::Borrowed(input))
    }
}

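/// The input accepted by the `encode*` methods: a single sequence, or a pair of
/// sequences (e.g. question/context) to be encoded together.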
#[derive(Debug, Clone)]
pub enum EncodeInput<'s> {
    Single(InputSequence<'s>),
    Dual(InputSequence<'s>, InputSequence<'s>),
}

impl<'s, I: Into<InputSequence<'s>>> From<I> for EncodeInput<'s> {
    fn from(input: I) -> Self {
        Self::Single(input.into())
    }
}

impl<'s, I1, I2> From<(I1, I2)> for EncodeInput<'s>
where
    I1: Into<InputSequence<'s>>,
    I2: Into<InputSequence<'s>>,
{
    fn from(input: (I1, I2)) -> Self {
        Self::Dual(input.0.into(), input.1.into())
    }
}

#[derive(thiserror::Error, Debug)]
#[error("{0}")]
pub struct BuilderError(String);

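/// Builder for a `TokenizerImpl`: collects the optional pipeline components
/// (normalizer, pre-tokenizer, post-processor, decoder) plus truncation/padding
/// parameters around a mandatory model.
///
/// A minimal usage sketch (assuming `model` and `normalizer` are values implementing
/// the `Model` and `Normalizer` traits):
///
/// ```ignore
/// let tokenizer = TokenizerBuilder::new()
///     .with_model(model)
///     .with_normalizer(Some(normalizer))
///     .build()?;
/// ```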
pub struct TokenizerBuilder<M, N, PT, PP, D> {
    model: Option<M>,
    normalizer: Option<N>,
    pre_tokenizer: Option<PT>,
    post_processor: Option<PP>,
    decoder: Option<D>,

    added_vocabulary: AddedVocabulary,

    truncation: Option<TruncationParams>,
    padding: Option<PaddingParams>,
}

impl<M, N, PT, PP, D> Default for TokenizerBuilder<M, N, PT, PP, D>
where
    M: Model,
    N: Normalizer,
    PT: PreTokenizer,
    PP: PostProcessor,
    D: Decoder,
{
    fn default() -> Self {
        Self::new()
    }
}

impl<M, N, PT, PP, D> TokenizerBuilder<M, N, PT, PP, D>
where
    M: Model,
    N: Normalizer,
    PT: PreTokenizer,
    PP: PostProcessor,
    D: Decoder,
{
    pub fn new() -> Self {
        Self {
            model: None,
            normalizer: None,
            pre_tokenizer: None,
            post_processor: None,
            decoder: None,
            added_vocabulary: AddedVocabulary::new(),
            truncation: None,
            padding: None,
        }
    }

    pub fn build(self) -> Result<TokenizerImpl<M, N, PT, PP, D>> {
        let model = self
            .model
            .ok_or_else(|| Box::new(BuilderError("Model missing.".into())))?;
        Ok(TokenizerImpl {
            normalizer: self.normalizer,
            pre_tokenizer: self.pre_tokenizer,
            model,

            post_processor: self.post_processor,
            decoder: self.decoder,
            added_vocabulary: self.added_vocabulary,
            truncation: self.truncation,
            padding: self.padding,
        })
    }

    #[must_use]
    pub fn with_model(mut self, model: M) -> Self {
        self.model = Some(model);
        self
    }

    #[must_use]
    pub fn with_normalizer(mut self, normalizer: Option<N>) -> Self {
        self.normalizer = normalizer;
        self
    }

    #[must_use]
    pub fn with_pre_tokenizer(mut self, pretokenizer: Option<PT>) -> Self {
        self.pre_tokenizer = pretokenizer;
        self
    }

    #[must_use]
    pub fn with_post_processor(mut self, post_processor: Option<PP>) -> Self {
        self.post_processor = post_processor;
        self
    }

    #[must_use]
    pub fn with_decoder(mut self, decoder: Option<D>) -> Self {
        self.decoder = decoder;
        self
    }

    pub fn with_added_vocabulary(mut self, added_vocabulary: AddedVocabulary) -> Self {
        self.added_vocabulary = added_vocabulary;
        self
    }

    #[must_use]
    pub fn with_truncation(mut self, trunc: Option<TruncationParams>) -> Self {
        self.truncation = trunc;
        self
    }

    #[must_use]
    pub fn with_padding(mut self, padding: Option<PaddingParams>) -> Self {
        self.padding = padding;
        self
    }
}

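/// The non-generic, serializable tokenizer: a `TokenizerImpl` whose components are the
/// concrete wrapper enums, which is what gets loaded from and saved to JSON
/// (typically a `tokenizer.json` file).
///
/// A minimal usage sketch (assuming a serialized tokenizer file exists on disk):
///
/// ```ignore
/// let tokenizer = Tokenizer::from_file("tokenizer.json")?;
/// let encoding = tokenizer.encode("Hello, world!", true)?;
/// println!("{:?}", encoding.get_tokens());
/// ```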
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Tokenizer(
    TokenizerImpl<
        ModelWrapper,
        NormalizerWrapper,
        PreTokenizerWrapper,
        PostProcessorWrapper,
        DecoderWrapper,
    >,
);

impl Tokenizer {
    pub fn new(model: impl Into<ModelWrapper>) -> Self {
        Self(TokenizerImpl::new(model.into()))
    }

    pub fn into_inner(
        self,
    ) -> TokenizerImpl<
        ModelWrapper,
        NormalizerWrapper,
        PreTokenizerWrapper,
        PostProcessorWrapper,
        DecoderWrapper,
    > {
        self.0
    }
    pub fn from_file<P: AsRef<Path>>(file: P) -> Result<Self> {
        let content = read_to_string(file)?;
        let tokenizer = serde_json::from_str(&content)?;
        Ok(tokenizer)
    }
    pub fn from_bytes<P: AsRef<[u8]>>(bytes: P) -> Result<Self> {
        let tokenizer = serde_json::from_slice(bytes.as_ref())?;
        Ok(tokenizer)
    }
    #[cfg(feature = "http")]
    pub fn from_pretrained<S: AsRef<str>>(
        identifier: S,
        params: Option<crate::utils::from_pretrained::FromPretrainedParameters>,
    ) -> Result<Self> {
        let tokenizer_file = crate::utils::from_pretrained::from_pretrained(identifier, params)?;
        Tokenizer::from_file(tokenizer_file)
    }
}

impl std::str::FromStr for Tokenizer {
    type Err = Box<dyn std::error::Error + Send + Sync>;

    fn from_str(s: &str) -> Result<Self> {
        Ok(serde_json::from_str(s)?)
    }
}

impl<M, N, PT, PP, D> From<TokenizerImpl<M, N, PT, PP, D>> for Tokenizer
where
    M: Into<ModelWrapper>,
    N: Into<NormalizerWrapper>,
    PT: Into<PreTokenizerWrapper>,
    PP: Into<PostProcessorWrapper>,
    D: Into<DecoderWrapper>,
{
    fn from(t: TokenizerImpl<M, N, PT, PP, D>) -> Self {
        Self(TokenizerImpl {
            model: t.model.into(),
            normalizer: t.normalizer.map(Into::into),
            pre_tokenizer: t.pre_tokenizer.map(Into::into),
            post_processor: t.post_processor.map(Into::into),
            decoder: t.decoder.map(Into::into),
            added_vocabulary: t.added_vocabulary,
            padding: t.padding,
            truncation: t.truncation,
        })
    }
}

impl Deref for Tokenizer {
    type Target = TokenizerImpl<
        ModelWrapper,
        NormalizerWrapper,
        PreTokenizerWrapper,
        PostProcessorWrapper,
        DecoderWrapper,
    >;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl DerefMut for Tokenizer {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

#[derive(thiserror::Error, Debug)]
#[error("{0}")]
pub struct TruncationParamError(String);

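/// The core tokenization pipeline, generic over its components. It runs, in order:
/// normalization, pre-tokenization, the model, and post-processing, and additionally
/// handles added tokens, truncation and padding.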
#[derive(Clone, Debug)]
pub struct TokenizerImpl<M, N, PT, PP, D> {
    normalizer: Option<N>,
    pre_tokenizer: Option<PT>,
    model: M,
    post_processor: Option<PP>,
    decoder: Option<D>,

    added_vocabulary: AddedVocabulary,

    truncation: Option<TruncationParams>,
    padding: Option<PaddingParams>,
}

impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
where
    M: Model,
    N: Normalizer,
    PT: PreTokenizer,
    PP: PostProcessor,
    D: Decoder,
{
    pub fn new(model: M) -> Self {
        Self {
            normalizer: None,
            pre_tokenizer: None,
            model,
            post_processor: None,
            decoder: None,

            added_vocabulary: AddedVocabulary::new(),

            truncation: None,
            padding: None,
        }
    }

    pub fn with_normalizer(&mut self, normalizer: Option<impl Into<N>>) -> &mut Self {
        self.normalizer = normalizer.map(|norm| norm.into());
        self
    }
    pub fn get_normalizer(&self) -> Option<&N> {
        self.normalizer.as_ref()
    }

    pub fn with_pre_tokenizer(&mut self, pre_tokenizer: Option<impl Into<PT>>) -> &mut Self {
        self.pre_tokenizer = pre_tokenizer.map(|tok| tok.into());
        self
    }

    pub fn get_pre_tokenizer(&self) -> Option<&PT> {
        self.pre_tokenizer.as_ref()
    }

    pub fn with_post_processor(&mut self, post_processor: Option<impl Into<PP>>) -> &mut Self {
        self.post_processor = post_processor.map(|post_proc| post_proc.into());
        self
    }

    pub fn get_post_processor(&self) -> Option<&PP> {
        self.post_processor.as_ref()
    }

    pub fn with_decoder(&mut self, decoder: Option<impl Into<D>>) -> &mut Self {
        self.decoder = decoder.map(|dec| dec.into());
        self
    }

    pub fn get_decoder(&self) -> Option<&D> {
        self.decoder.as_ref()
    }

    pub fn with_model(&mut self, model: impl Into<M>) -> &mut Self {
        self.model = model.into();
        self
    }

    pub fn get_model(&self) -> &M {
        &self.model
    }

    pub fn with_added_vocabulary(&mut self, added_vocabulary: AddedVocabulary) -> &mut Self {
        self.added_vocabulary = added_vocabulary;
        self
    }

    pub fn get_added_vocabulary(&self) -> &AddedVocabulary {
        &self.added_vocabulary
    }

    pub fn with_truncation(&mut self, trunc: Option<TruncationParams>) -> Result<&mut Self> {
        if let Some(trunc_params) = &trunc {
            let n_added_tokens = self.get_n_added_tokens(false);
            let effective_max_length = trunc_params.max_length - n_added_tokens;
            if effective_max_length < trunc_params.stride {
                return Err(Box::new(TruncationParamError(format!(
                    "tokenizer stride set to {}, which is greater than or equal to its effective max length of {} (= {} original max length - {} added special tokens), ",
                    trunc_params.stride, effective_max_length, trunc_params.max_length, n_added_tokens
                ))));
            }
        }
        self.truncation = trunc;
        Ok(self)
    }

    pub fn get_truncation(&self) -> Option<&TruncationParams> {
        self.truncation.as_ref()
    }

    pub fn get_truncation_mut(&mut self) -> Option<&mut TruncationParams> {
        self.truncation.as_mut()
    }

    pub fn with_padding(&mut self, padding: Option<PaddingParams>) -> &mut Self {
        self.padding = padding;
        self
    }

    pub fn get_padding(&self) -> Option<&PaddingParams> {
        self.padding.as_ref()
    }

    pub fn get_padding_mut(&mut self) -> Option<&mut PaddingParams> {
        self.padding.as_mut()
    }

    pub fn get_vocab(&self, with_added_tokens: bool) -> HashMap<String, u32> {
        let mut final_vocab = self.model.get_vocab();

        if with_added_tokens {
            let added_vocab = self.added_vocabulary.get_vocab();
            if !added_vocab.is_empty() {
                final_vocab.reserve(added_vocab.len());
                for (token, id) in added_vocab {
                    final_vocab.insert(token.clone(), *id);
                }
            }
        }

        final_vocab
    }

    pub fn get_added_tokens_decoder(&self) -> HashMap<u32, AddedToken> {
        self.added_vocabulary.get_added_tokens_decoder().clone()
    }

    pub fn get_vocab_size(&self, with_added_tokens: bool) -> usize {
        if with_added_tokens {
            self.get_vocab(true).len()
        } else {
            self.model.get_vocab_size()
        }
    }

    pub fn token_to_id(&self, token: &str) -> Option<u32> {
        self.added_vocabulary.token_to_id(token, &self.model)
    }

    pub fn id_to_token(&self, id: u32) -> Option<String> {
        self.added_vocabulary
            .simple_id_to_token(id)
            .or_else(|| self.model.id_to_token(id))
    }

    pub fn set_encode_special_tokens(&mut self, value: bool) {
        self.added_vocabulary.set_encode_special_tokens(value);
    }

    pub fn get_encode_special_tokens(&self) -> bool {
        self.added_vocabulary.get_encode_special_tokens()
    }

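    /// Encode a single, possibly pre-tokenized, sequence: split out added tokens and
    /// normalize the remaining text, pre-tokenize it, then let the model tokenize each
    /// resulting sub-part with the given `type_id` and offset convention.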
    fn encode_single_sequence(
        &self,
        sequence: InputSequence,
        type_id: u32,
        offsets_type: OffsetType,
    ) -> Result<Encoding> {
        let encode = |is_pre_tokenized, subseq_idx, subseq| -> Result<Encoding> {
            let normalized = self
                .added_vocabulary
                .extract_and_normalize(self.normalizer.as_ref(), subseq);
            let pre_tokenized = self.do_pre_tokenize(normalized)?;
            let subseq_encoding = self.do_tokenize(
                pre_tokenized,
                type_id,
                if is_pre_tokenized {
                    Some(subseq_idx as u32)
                } else {
                    None
                },
                offsets_type,
            )?;

            Ok(subseq_encoding)
        };

        match sequence {
            InputSequence::PreTokenized(seq) => seq
                .iter()
                .enumerate()
                .map(|(i, sequence)| encode(true, i, sequence))
                .collect(),
            InputSequence::PreTokenizedOwned(seq) => seq
                .iter()
                .enumerate()
                .map(|(i, sequence)| encode(true, i, sequence))
                .collect(),
            InputSequence::PreTokenizedCow(seq) => seq
                .iter()
                .enumerate()
                .map(|(i, sequence)| encode(true, i, sequence))
                .collect(),
            InputSequence::Raw(seq) => encode(false, 0, seq.as_ref()),
        }
    }

    pub fn encode_fast<'s, E>(&self, input: E, add_special_tokens: bool) -> Result<Encoding>
    where
        E: Into<EncodeInput<'s>>,
    {
        let (sequence, pair) = match input.into() {
            EncodeInput::Single(s1) => (s1, None),
            EncodeInput::Dual(s1, s2) => (s1, Some(s2)),
        };

        let encoding = self.encode_single_sequence(sequence, 0, OffsetType::None)?;
        let pair_encoding = pair
            .map(|sequence| self.encode_single_sequence(sequence, 1, OffsetType::None))
            .transpose()?;

        self.post_process(encoding, pair_encoding, add_special_tokens)
    }

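    /// Encode the given input with byte-level offsets. Accepts single sequences as well
    /// as pairs, raw or already pre-tokenized, and runs the whole pipeline:
    /// normalization, pre-tokenization, model, then post-processing (including
    /// truncation and padding when configured).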
    pub fn encode<'s, E>(&self, input: E, add_special_tokens: bool) -> Result<Encoding>
    where
        E: Into<EncodeInput<'s>>,
    {
        let (sequence, pair) = match input.into() {
            EncodeInput::Single(s1) => (s1, None),
            EncodeInput::Dual(s1, s2) => (s1, Some(s2)),
        };

        let encoding = self.encode_single_sequence(sequence, 0, OffsetType::Byte)?;
        let pair_encoding = pair
            .map(|sequence| self.encode_single_sequence(sequence, 1, OffsetType::Byte))
            .transpose()?;

        self.post_process(encoding, pair_encoding, add_special_tokens)
    }

    pub fn encode_char_offsets<'s, E>(&self, input: E, add_special_tokens: bool) -> Result<Encoding>
    where
        E: Into<EncodeInput<'s>>,
    {
        let (sequence, pair) = match input.into() {
            EncodeInput::Single(s1) => (s1, None),
            EncodeInput::Dual(s1, s2) => (s1, Some(s2)),
        };

        let encoding = self.encode_single_sequence(sequence, 0, OffsetType::Char)?;
        let pair_encoding = pair
            .map(|sequence| self.encode_single_sequence(sequence, 1, OffsetType::Char))
            .transpose()?;

        self.post_process(encoding, pair_encoding, add_special_tokens)
    }

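    /// Decode the given ids back into a string, optionally skipping special tokens.
    /// Added tokens are resolved first, then the model vocabulary; the result goes
    /// through the decoder when one is set, otherwise tokens are joined with spaces.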
    pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result<String> {
        let tokens = ids
            .iter()
            .filter_map(|id| {
                self.added_vocabulary
                    .simple_id_to_token(*id)
                    .or_else(|| self.model.id_to_token(*id))
                    .filter(|token| {
                        !skip_special_tokens || !self.added_vocabulary.is_special_token(token)
                    })
            })
            .collect::<Vec<_>>();

        if let Some(decoder) = &self.decoder {
            decoder.decode(tokens)
        } else {
            Ok(tokens.join(" "))
        }
    }

    pub fn decode_stream(&self, skip_special_tokens: bool) -> DecodeStream<'_, M, N, PT, PP, D> {
        DecodeStream::new(self, skip_special_tokens)
    }
}

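/// Incremental decoding state for streaming use cases (e.g. printing text as ids arrive
/// one at a time). It keeps a small window of ids and the decoded prefix so that text is
/// only emitted once it is unambiguous, e.g. a multi-byte UTF-8 character spread over
/// several ids.
///
/// A minimal usage sketch (assuming `tokenizer` is a built tokenizer and `generated_ids`
/// is an iterator of ids produced elsewhere):
///
/// ```ignore
/// let mut stream = tokenizer.decode_stream(false);
/// for id in generated_ids {
///     if let Some(chunk) = stream.step(id)? {
///         print!("{chunk}");
///     }
/// }
/// ```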
pub struct DecodeStream<'tok, M, N, PT, PP, D> {
    tokenizer: &'tok TokenizerImpl<M, N, PT, PP, D>,
    skip_special_tokens: bool,
    ids: Vec<u32>,
    prefix: String,
    prefix_index: usize,
    read_index: usize,
}

#[derive(thiserror::Error, Debug)]
pub enum DecodeStreamError {
    #[error("Invalid prefix encountered")]
    InvalidPrefix,
}

impl<'tok, M, N, PT, PP, D> DecodeStream<'tok, M, N, PT, PP, D>
where
    M: Model,
    N: Normalizer,
    PT: PreTokenizer,
    PP: PostProcessor,
    D: Decoder,
{
    fn new(tokenizer: &'tok TokenizerImpl<M, N, PT, PP, D>, skip_special_tokens: bool) -> Self {
        Self {
            tokenizer,
            ids: vec![],
            skip_special_tokens,
            prefix: "".to_string(),
            prefix_index: 0,
            read_index: 0,
        }
    }

    pub fn step(&mut self, id: u32) -> Result<Option<String>> {
        step_decode_stream(
            self.tokenizer,
            id,
            self.skip_special_tokens,
            &mut self.ids,
            &mut self.prefix,
            &mut self.prefix_index,
            &mut self.read_index,
        )
    }
}

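/// One step of streaming decoding: push `id` onto the pending ids and re-decode them.
/// If the decoded string grows beyond the previous prefix without ending in the
/// replacement character (i.e. no partially decoded multi-byte sequence), the newly
/// produced text is returned and the window advances; otherwise `None` is returned and
/// more ids are awaited. Fails with `DecodeStreamError::InvalidPrefix` if the new string
/// no longer starts with the previous prefix.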
pub fn step_decode_stream<M, N, PT, PP, D>(
    tokenizer: &TokenizerImpl<M, N, PT, PP, D>,
    id: u32,
    skip_special_tokens: bool,
    ids: &mut Vec<u32>,
    prefix: &mut String,
    prefix_index: &mut usize,
    read_index: &mut usize,
) -> Result<Option<String>>
where
    M: Model,
    N: Normalizer,
    PT: PreTokenizer,
    PP: PostProcessor,
    D: Decoder,
{
    ids.push(id);
    let string = tokenizer.decode(ids.as_slice(), skip_special_tokens)?;
    if string.len() > prefix.len() && !string.ends_with('�') {
        if !(string.starts_with(&*prefix)) {
            return Err(Box::new(DecodeStreamError::InvalidPrefix));
        }
        let new_text = &string[prefix.len()..].to_string();
        let new_prefix_index = ids.len() - *prefix_index;
        *ids = ids.drain(*read_index..).collect();
        *prefix = tokenizer.decode(ids, skip_special_tokens)?;
        *read_index = *prefix_index;
        *prefix_index = new_prefix_index;
        Ok(Some(new_text.to_string()))
    } else {
        Ok(None)
    }
}

impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
where
    M: Model,
{
    fn do_tokenize<P: Into<PreTokenizedString>>(
        &self,
        pretokenized: P,
        type_id: u32,
        word_idx: Option<u32>,
        offsets_type: OffsetType,
    ) -> Result<Encoding> {
        let mut pretokenized: PreTokenizedString = pretokenized.into();
        pretokenized.tokenize(|normalized| self.model.tokenize(normalized.get()))?;
        pretokenized.into_encoding(word_idx, type_id, offsets_type)
    }
}

impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
where
    N: Normalizer,
{
    fn do_normalize<V: Into<NormalizedString>>(&self, normalized: V) -> Result<NormalizedString> {
        let mut normalized: NormalizedString = normalized.into();

        if let Some(ref normalizer) = self.normalizer {
            normalizer.normalize(&mut normalized)?;
        }

        Ok(normalized)
    }
}

impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
where
    N: Normalizer,
    M: Model,
{
    pub fn add_special_tokens(&mut self, tokens: &[AddedToken]) -> usize {
        self.added_vocabulary
            .add_special_tokens(tokens, &self.model, self.normalizer.as_ref())
    }

    pub fn add_tokens(&mut self, tokens: &[AddedToken]) -> usize {
        self.added_vocabulary
            .add_tokens(tokens, &self.model, self.normalizer.as_ref())
    }
}

impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
where
    PT: PreTokenizer,
{
    fn do_pre_tokenize<P: Into<PreTokenizedString>>(
        &self,
        pretokenized: P,
    ) -> Result<PreTokenizedString> {
        let mut pretokenized: PreTokenizedString = pretokenized.into();
        if let Some(ref pretok) = self.pre_tokenizer {
            pretok.pre_tokenize(&mut pretokenized)?;
        }

        Ok(pretokenized)
    }
}

impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
where
    PP: PostProcessor,
{
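    /// Post-processing step of the pipeline: apply truncation (reserving room for the
    /// special tokens the post-processor will add when `add_special_tokens` is true),
    /// run the post-processor (or the default merge when none is set), and finally apply
    /// padding when configured.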
    pub fn post_process(
        &self,
        encoding: Encoding,
        pair_encoding: Option<Encoding>,
        add_special_tokens: bool,
    ) -> Result<Encoding> {
        let (encoding, pair_encoding) = {
            if let Some(trunc) = &self.truncation {
                let n_added_tokens = self.get_n_added_tokens(pair_encoding.is_some());

                if add_special_tokens && n_added_tokens > 0 {
                    let params = TruncationParams {
                        max_length: trunc.max_length - n_added_tokens,
                        ..*trunc
                    };
                    truncate_encodings(encoding, pair_encoding, &params)?
                } else {
                    truncate_encodings(encoding, pair_encoding, trunc)?
                }
            } else {
                (encoding, pair_encoding)
            }
        };

        let final_encoding = if let Some(processor) = &self.post_processor {
            processor.process(encoding, pair_encoding, add_special_tokens)?
        } else {
            let encodings = if let Some(pair_encoding) = pair_encoding {
                vec![encoding, pair_encoding]
            } else {
                vec![encoding]
            };
            let mut encodings =
                <dyn PostProcessor>::default_process(encodings, add_special_tokens)?;
            if encodings.len() != 1 {
                panic!("We haven't reduced the encodings like we should have");
            }
            encodings.pop().unwrap()
        };

        let [final_encoding] = if let Some(params) = &self.padding {
            let mut arr = [final_encoding];
            pad_encodings(&mut arr, params)?;
            arr
        } else {
            [final_encoding]
        };

        Ok(final_encoding)
    }

    fn get_n_added_tokens(&self, is_pair: bool) -> usize {
        if let Some(processor) = &self.post_processor {
            processor.added_tokens(is_pair)
        } else {
            0
        }
    }
}

impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
where
    M: Model + Send + Sync,
    N: Normalizer + Send + Sync,
    PT: PreTokenizer + Send + Sync,
    PP: PostProcessor + Send + Sync,
    D: Decoder + Send + Sync,
{
    pub fn encode_batch<'s, E>(
        &self,
        inputs: Vec<E>,
        add_special_tokens: bool,
    ) -> Result<Vec<Encoding>>
    where
        E: Into<EncodeInput<'s>> + Send,
    {
        let mut encodings = inputs
            .into_maybe_par_iter()
            .map(|input| self.encode(input, add_special_tokens))
            .collect::<Result<Vec<Encoding>>>()?;

        if let Some(params) = &self.padding {
            pad_encodings(&mut encodings, params)?;
        }

        Ok(encodings)
    }

    pub fn encode_batch_char_offsets<'s, E>(
        &self,
        inputs: Vec<E>,
        add_special_tokens: bool,
    ) -> Result<Vec<Encoding>>
    where
        E: Into<EncodeInput<'s>> + Send,
    {
        let mut encodings = inputs
            .into_maybe_par_iter()
            .map(|input| self.encode_char_offsets(input, add_special_tokens))
            .collect::<Result<Vec<Encoding>>>()?;

        if let Some(params) = &self.padding {
            pad_encodings(&mut encodings, params)?;
        }

        Ok(encodings)
    }

    pub fn encode_batch_fast<'s, E>(
        &self,
        inputs: Vec<E>,
        add_special_tokens: bool,
    ) -> Result<Vec<Encoding>>
    where
        E: Into<EncodeInput<'s>> + Send,
    {
        let mut encodings = inputs
            .into_maybe_par_iter()
            .map(|input| self.encode_fast(input, add_special_tokens))
            .collect::<Result<Vec<Encoding>>>()?;

        if let Some(params) = &self.padding {
            pad_encodings(&mut encodings, params)?;
        }

        Ok(encodings)
    }

    pub fn decode_batch(
        &self,
        sentences: &[&[u32]],
        skip_special_tokens: bool,
    ) -> Result<Vec<String>>
    where
        M: Send + Sync,
    {
        sentences
            .into_maybe_par_iter()
            .map(|sentence| self.decode(sentence, skip_special_tokens))
            .collect()
    }

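    /// Train the current model on the given files: read them line by line (with a
    /// progress bar when the trainer asks for one), normalize and pre-tokenize each
    /// line, feed the resulting splits to the trainer, then train the model and register
    /// any special tokens it returns.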
    pub fn train_from_files<T>(&mut self, trainer: &mut T, files: Vec<String>) -> Result<&mut Self>
    where
        T: Trainer<Model = M> + Sync,
    {
        let mut len = 0;
        for file in files.iter() {
            len += File::open(file)
                .and_then(|f| f.metadata())
                .map(|m| m.len())?;
        }

        let max_read = 1_000_000;

        ResultShunt::process(
            files.into_iter().flat_map(|filename| {
                match File::open(filename) {
                    Ok(file) => {
                        let file = BufReader::with_capacity(max_read, file);
                        itertools::Either::Left(file.lines_with_ending())
                    }
                    Err(e) => itertools::Either::Right(std::iter::once(Err(e))),
                }
            }),
            |sequences| -> Result<()> {
                let progress = if trainer.should_show_progress() {
                    let progress = ProgressBar::new(len);
                    progress.set_style(
                        ProgressStyle::default_bar()
                            .template("[{elapsed_precise}] {msg:<30!} {wide_bar} {percent:>18!}%")
                            .expect("Invalid progress template"),
                    );
                    progress
                        .set_message(format!("Pre-processing files ({:.2} Mo)", len / 1_000_000));
                    Some(progress)
                } else {
                    None
                };

                trainer.feed(
                    sequences.inspect(|s| {
                        if let Some(progress) = &progress {
                            progress.inc(s.len() as u64)
                        }
                    }),
                    |seq| {
                        let normalized = self.do_normalize(seq.as_ref())?;
                        let pre_tokenized = self.do_pre_tokenize(normalized)?;
                        Ok(pre_tokenized
                            .get_splits(OffsetReferential::Original, OffsetType::Byte)
                            .into_iter()
                            .map(|(s, _, _)| s.to_owned())
                            .collect())
                    },
                )?;

                if let Some(pbar) = progress {
                    pbar.finish();
                }
                let special_tokens = trainer.train(&mut self.model)?;
                self.add_special_tokens(&special_tokens);

                Ok(())
            },
        )??;
        Ok(self)
    }

    pub fn train<T, I, S>(&mut self, trainer: &mut T, sequences: I) -> Result<&mut Self>
    where
        T: Trainer<Model = M> + Sync,
        I: Iterator<Item = S> + Send,
        S: AsRef<str> + Send,
    {
        let (lower, upper) = sequences.size_hint();
        let len = upper.unwrap_or(lower) as u64;
        let progress = if trainer.should_show_progress() {
            let progress = ProgressBar::new(len);
            progress.set_style(
                ProgressStyle::default_bar()
                    .template("[{elapsed_precise}] {msg:<30!} {wide_bar} {pos:<9!}/{len:>9!}")
                    .expect("Invalid progress template"),
            );
            progress.set_message("Pre-processing sequences");
            Some(progress)
        } else {
            None
        };

        trainer.feed(
            sequences.inspect(|_s| {
                if let Some(progress) = &progress {
                    progress.inc(1)
                }
            }),
            |seq| {
                let normalized = self.do_normalize(seq.as_ref())?;
                let pre_tokenized = self.do_pre_tokenize(normalized)?;
                Ok(pre_tokenized
                    .get_splits(OffsetReferential::Original, OffsetType::Byte)
                    .into_iter()
                    .map(|(s, _, _)| s.to_owned())
                    .collect())
            },
        )?;
        if let Some(pbar) = progress {
            pbar.finish();
        }

        let special_tokens = trainer.train(&mut self.model)?;
        self.add_special_tokens(&special_tokens);

        Ok(self)
    }
}

impl<M, N, PT, PP, D> std::str::FromStr for TokenizerImpl<M, N, PT, PP, D>
where
    M: for<'de> Deserialize<'de> + Model,
    N: for<'de> Deserialize<'de> + Normalizer,
    PT: for<'de> Deserialize<'de> + PreTokenizer,
    PP: for<'de> Deserialize<'de> + PostProcessor,
    D: for<'de> Deserialize<'de> + Decoder,
{
    type Err = Error;

    fn from_str(s: &str) -> Result<Self> {
        Ok(serde_json::from_str(s)?)
    }
}

impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
where
    M: DeserializeOwned + Model,
    N: DeserializeOwned + Normalizer,
    PT: DeserializeOwned + PreTokenizer,
    PP: DeserializeOwned + PostProcessor,
    D: DeserializeOwned + Decoder,
{
    pub fn from_file<P: AsRef<Path>>(file: P) -> Result<Self> {
        let content = read_to_string(file)?;
        let tokenizer = serde_json::from_str(&content)?;
        Ok(tokenizer)
    }
}

impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
where
    M: DeserializeOwned + Model,
    N: DeserializeOwned + Normalizer,
    PT: DeserializeOwned + PreTokenizer,
    PP: DeserializeOwned + PostProcessor,
    D: DeserializeOwned + Decoder,
{
    pub fn from_bytes<P: AsRef<[u8]>>(bytes: P) -> Result<Self> {
        let tokenizer = serde_json::from_slice(bytes.as_ref())?;
        Ok(tokenizer)
    }
}

impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
where
    M: DeserializeOwned + Model,
    N: DeserializeOwned + Normalizer,
    PT: DeserializeOwned + PreTokenizer,
    PP: DeserializeOwned + PostProcessor,
    D: DeserializeOwned + Decoder,
{
    #[deprecated(
        since = "0.14.0",
        note = "Users should download the file separately using https://github.com/huggingface/hf-hub instead, which splits concerns of accessing the web, and should use the new cache layout"
    )]
    #[cfg(feature = "http")]
    pub fn from_pretrained<S: AsRef<str>>(
        identifier: S,
        params: Option<crate::utils::from_pretrained::FromPretrainedParameters>,
    ) -> Result<Self> {
        let tokenizer_file = crate::utils::from_pretrained::from_pretrained(identifier, params)?;
        TokenizerImpl::from_file(tokenizer_file)
    }
}

impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
where
    M: Serialize,
    N: Serialize,
    PT: Serialize,
    PP: Serialize,
    D: Serialize,
{
    pub fn to_string(&self, pretty: bool) -> Result<String> {
        Ok(if pretty {
            serde_json::to_string_pretty(self)?
        } else {
            serde_json::to_string(self)?
        })
    }

    pub fn save<P: AsRef<Path>>(&self, path: P, pretty: bool) -> Result<()> {
        let serialized = self.to_string(pretty)?;

        let mut file = File::create(path)?;
        file.write_all(serialized.as_bytes())?;

        Ok(())
    }
}

#[cfg(test)]
mod test {
    #[cfg(feature = "http")]
    #[test]
    fn test_decoding_with_added_bpe() {
        use crate::{
            normalizers,
            pre_tokenizers::split::{Split, SplitPattern},
            AddedToken, NormalizerWrapper, PreTokenizerWrapper, SplitDelimiterBehavior, Tokenizer,
        };

        let mut tokenizer = Tokenizer::from_pretrained("meta-llama/Meta-Llama-3-8B", None).unwrap();
        tokenizer.normalizer = Some(NormalizerWrapper::from(normalizers::ByteLevel::new()));
        tokenizer.pre_tokenizer = Some(PreTokenizerWrapper::Split(
            Split::new(
                SplitPattern::Regex(r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+".into()),
                SplitDelimiterBehavior::Isolated,
                false,
            )
            .unwrap(),
        ));
        tokenizer.add_tokens(&[AddedToken::from("嗎", false).normalized(false)]);
        let encoded = tokenizer
            .encode("Hey! how is this token: 嗎", false)
            .unwrap();
        assert_eq!(
            encoded.get_ids(),
            [19182, 0, 1268, 602, 82, 62428, 82, 4037, 25, 220, 128256]
        );
        assert_eq!(
            encoded.get_tokens(),
            ["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", "嗎"]
        );

        let decoded = tokenizer.decode(encoded.get_ids(), false);
        assert_eq!(decoded.unwrap(), "Hey! how is this token: 嗎");

        tokenizer.add_tokens(&[AddedToken::from("д", false).normalized(true)]);
        let encoded = tokenizer
            .encode("Hey! how is this token: д", false)
            .unwrap();
        assert_eq!(
            encoded.get_ids(),
            [19182, 0, 1268, 602, 82, 62428, 82, 4037, 25, 220, 128257]
        );
        assert_eq!(
            encoded.get_tokens(),
            ["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", "д"]
        );
        let decoded = tokenizer.decode(encoded.get_ids(), false);
        assert_eq!(decoded.unwrap(), "Hey! how is this token: д")
    }
}