text_splitter/chunk_size/rust_tokenizers.rs
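//! `ChunkSizer` implementations for the tokenizers provided by the
//! `rust_tokenizers` crate, so chunk sizes can be measured as the number of
//! tokens a chosen tokenizer produces for a piece of text.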

use rust_tokenizers::{
    tokenizer::{
        AlbertTokenizer, BaseTokenizer, BertTokenizer, CtrlTokenizer, DeBERTaTokenizer,
        DeBERTaV2Tokenizer, FNetTokenizer, Gpt2Tokenizer, M2M100Tokenizer, MBart50Tokenizer,
        MarianTokenizer, NLLBTokenizer, OpenAiGptTokenizer, PegasusTokenizer, ProphetNetTokenizer,
        ReformerTokenizer, RobertaTokenizer, SentencePieceBpeTokenizer, SentencePieceTokenizer,
        T5Tokenizer, Tokenizer, XLMRobertaTokenizer, XLNetTokenizer,
    },
    vocab::Vocab,
};

use crate::ChunkSizer;

/// Returns the number of tokens `tokenizer` produces for `chunk`.
fn chunk_size_from_offsets<V: Vocab, T: Tokenizer<V>>(tokenizer: &T, chunk: &str) -> usize {
    tokenizer.tokenize(chunk).len()
}

impl<V> ChunkSizer for &BaseTokenizer<V>
where
    V: Vocab + Sync + Send,
{
    fn size(&self, chunk: &str) -> usize {
        chunk_size_from_offsets(*self, chunk)
    }
}

impl<V> ChunkSizer for BaseTokenizer<V>
where
    V: Vocab + Sync + Send,
{
    fn size(&self, chunk: &str) -> usize {
        (&self).size(chunk)
    }
}

macro_rules! impl_chunk_sizer {
    ($($t:ty),+) => {
        $(impl ChunkSizer for &$t {
            fn size(&self, chunk: &str) -> usize {
                chunk_size_from_offsets(*self, chunk)
            }
        }

        impl ChunkSizer for $t {
            fn size(&self, chunk: &str) -> usize {
                (&self).size(chunk)
            }
        })+
    }
}

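// For example, `impl_chunk_sizer!(BertTokenizer)` expands to the same pair of
// impls written out above for `BaseTokenizer`, just for the concrete type:
//
//     impl ChunkSizer for &BertTokenizer {
//         fn size(&self, chunk: &str) -> usize {
//             chunk_size_from_offsets(*self, chunk)
//         }
//     }
//
//     impl ChunkSizer for BertTokenizer {
//         fn size(&self, chunk: &str) -> usize {
//             (&self).size(chunk)
//         }
//     }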
impl_chunk_sizer!(
    AlbertTokenizer,
    BertTokenizer,
    CtrlTokenizer,
    DeBERTaTokenizer,
    DeBERTaV2Tokenizer,
    FNetTokenizer,
    Gpt2Tokenizer,
    M2M100Tokenizer,
    MBart50Tokenizer,
    MarianTokenizer,
    NLLBTokenizer,
    OpenAiGptTokenizer,
    PegasusTokenizer,
    ProphetNetTokenizer,
    ReformerTokenizer,
    RobertaTokenizer,
    SentencePieceBpeTokenizer,
    SentencePieceTokenizer,
    T5Tokenizer,
    XLMRobertaTokenizer,
    XLNetTokenizer
);
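
// With these impls in place, any of the tokenizers above can be used wherever a
// `ChunkSizer` is expected: `size` reports how many tokens the tokenizer produces
// for a chunk of text. A minimal sketch (the vocab path is a placeholder):
//
//     let tokenizer = BertTokenizer::from_file("path/to/vocab.txt", false, false)?;
//     let tokens_in_chunk = tokenizer.size(" An apple a");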

#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use cached_path::Cache;
    use rayon::prelude::*;
    use rust_tokenizers::vocab::{BertVocab, BpePairVocab, Gpt2Vocab, ProphetNetVocab};
    use strum::{EnumIter, IntoEnumIterator};

    use super::*;

    /// Downloads a remote file to the cache directory if it doesn't already exist,
    /// and returns the path to the cached file.
    fn download_file_to_cache(src: &str) -> PathBuf {
        let mut cache_dir = dirs::home_dir().unwrap();
        cache_dir.push(".cache");
        cache_dir.push(".text-splitter");

        Cache::builder()
            .dir(cache_dir)
            .build()
            .unwrap()
            .cached_path(src)
            .unwrap()
    }

    #[test]
    fn returns_offsets() {
        let vocab_path = download_file_to_cache(
            "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
        );
        let tokenizer = BertTokenizer::from_file(vocab_path, false, false).unwrap();
        let size = tokenizer.size(" An apple a");
        assert_eq!(size, 3);
    }
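
    // Illustrative sanity check using the same BERT vocab as above:
    // `ChunkSizer::size` should agree with the raw token count from
    // `Tokenizer::tokenize`, since the impl simply counts tokens.
    #[test]
    fn size_matches_token_count() {
        let vocab_path = download_file_to_cache(
            "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
        );
        let tokenizer = BertTokenizer::from_file(vocab_path, false, false).unwrap();
        let chunk = " An apple a";
        assert_eq!(tokenizer.size(chunk), tokenizer.tokenize(chunk).len());
    }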

    #[test]
    fn smoke_test() {
        let sizes = TokenizerOption::iter()
            .collect::<Vec<_>>()
            .into_par_iter()
            .map(|tokenizer| tokenizer.tokenizer().size(" An apple a"));
        assert!(sizes.all(|size| size > 0));
    }
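
    // Illustrative check that the blanket `&T` impl agrees with the owned impl.
    // It is called explicitly via fully-qualified syntax because plain method
    // resolution on `&tokenizer` would pick the owned impl first.
    #[test]
    fn reference_impl_matches_owned_impl() {
        let vocab_path = download_file_to_cache(
            "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
        );
        let tokenizer = BertTokenizer::from_file(vocab_path, false, false).unwrap();
        let chunk = " An apple a";
        assert_eq!(
            <&BertTokenizer as ChunkSizer>::size(&&tokenizer, chunk),
            tokenizer.size(chunk)
        );
    }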

    /// One variant for each tokenizer type exercised by `smoke_test`.
    #[derive(EnumIter)]
    enum TokenizerOption {
        Albert,
        Base,
        Bert,
        Ctrl,
        DeBERTa,
        DeBERTaV2,
        FNet,
        Gpt2,
        M2M100,
        MBart50,
        // Marian, // No example source vocab at the moment
        Nllb,
        OpenAiGpt,
        Pegasus,
        ProphetNet,
        Reformer,
        Roberta,
        SentencePieceBpe,
        SentencePiece,
        T5,
        XLMRoberta,
        XLNet,
    }

    impl TokenizerOption {
        /// Downloads any vocab/merges files needed for this variant and builds
        /// the corresponding tokenizer as a boxed `ChunkSizer`.
        #[allow(clippy::too_many_lines)]
        fn tokenizer(&self) -> Box<dyn ChunkSizer> {
            match self {
                Self::Albert => {
                    let vocab_path = download_file_to_cache(
                        "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-spiece.model",
                    );
                    Box::new(AlbertTokenizer::from_file(vocab_path, false, false).unwrap())
                }
                Self::Base => {
                    let vocab_path = download_file_to_cache(
                        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
                    );
                    let vocab = BertVocab::from_file(vocab_path).unwrap();
                    Box::new(BaseTokenizer::from_existing_vocab(vocab, false, false))
                }
                Self::Bert => {
                    let vocab_path = download_file_to_cache(
                        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
                    );
                    Box::new(BertTokenizer::from_file(vocab_path, false, false).unwrap())
                }
                Self::Ctrl => {
                    let vocab_path = download_file_to_cache(
                        "https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-vocab.json",
                    );
                    let merges_path = download_file_to_cache(
                        "https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-merges.txt",
                    );
                    Box::new(CtrlTokenizer::from_file(vocab_path, merges_path, false).unwrap())
                }
                Self::DeBERTa => {
                    let vocab_path = download_file_to_cache(
                        "https://huggingface.co/microsoft/deberta-base/resolve/main/vocab.json",
                    );
                    let merges_path = download_file_to_cache(
                        "https://huggingface.co/microsoft/deberta-base/resolve/main/merges.txt",
                    );
                    Box::new(DeBERTaTokenizer::from_file(vocab_path, merges_path, false).unwrap())
                }
                Self::DeBERTaV2 => {
                    let vocab_path = download_file_to_cache(
                        "https://huggingface.co/microsoft/deberta-v3-base/resolve/main/spm.model",
                    );
                    Box::new(
                        DeBERTaV2Tokenizer::from_file(vocab_path, false, false, false).unwrap(),
                    )
                }
                Self::FNet => {
                    let vocab_path = download_file_to_cache(
                        "https://huggingface.co/google/fnet-base/resolve/main/spiece.model",
                    );
                    Box::new(FNetTokenizer::from_file(vocab_path, false, false).unwrap())
                }
                Self::Gpt2 => {
                    let vocab_path = download_file_to_cache(
                        "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
                    );
                    let merges_path = download_file_to_cache(
                        "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
                    );
                    let vocab = Gpt2Vocab::from_file(vocab_path.as_path()).unwrap();
                    let merges = BpePairVocab::from_file(merges_path.as_path()).unwrap();

                    Box::new(Gpt2Tokenizer::from_existing_vocab_and_merges(
                        vocab, merges, false,
                    ))
                }
                Self::M2M100 => {
                    let vocab_path = download_file_to_cache(
                        "https://huggingface.co/facebook/m2m100_418M/resolve/main/vocab.json",
                    );
                    let merges_path = download_file_to_cache(
                        "https://huggingface.co/facebook/m2m100_418M/resolve/main/sentencepiece.bpe.model",
                    );

                    Box::new(M2M100Tokenizer::from_files(vocab_path, merges_path, false).unwrap())
                }
                Self::MBart50 => {
                    let vocab_path = download_file_to_cache(
                        "https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/sentencepiece.bpe.model",
                    );

                    Box::new(MBart50Tokenizer::from_file(vocab_path, false).unwrap())
                }
                Self::Nllb => {
                    let vocab_path = download_file_to_cache(
                        "https://huggingface.co/facebook/nllb-200-distilled-600M/resolve/main/tokenizer.json",
                    );
                    let merges_path = download_file_to_cache(
                        "https://huggingface.co/facebook/nllb-200-distilled-600M/resolve/main/sentencepiece.bpe.model",
                    );
                    let special_path = download_file_to_cache(
                        "https://huggingface.co/facebook/nllb-200-distilled-600M/raw/main/special_tokens_map.json",
                    );

                    Box::new(
                        NLLBTokenizer::from_files_with_special_token_map(
                            vocab_path,
                            merges_path,
                            special_path,
                        )
                        .unwrap(),
                    )
                }
                Self::OpenAiGpt => {
                    let vocab_path = download_file_to_cache(
                        "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
                    );
                    let merges_path = download_file_to_cache(
                        "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
                    );

                    Box::new(OpenAiGptTokenizer::from_file(vocab_path, merges_path, true).unwrap())
                }
                Self::Pegasus => {
                    let vocab_path = download_file_to_cache(
                        "https://cdn.huggingface.co/google/pegasus-cnn_dailymail/spiece.model",
                    );

                    Box::new(PegasusTokenizer::from_file(vocab_path, false).unwrap())
                }
                Self::ProphetNet => {
                    let vocab_path = download_file_to_cache(
                        "https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/prophetnet.tokenizer",
                    );
                    let vocab = ProphetNetVocab::from_file(vocab_path).unwrap();

                    Box::new(ProphetNetTokenizer::from_existing_vocab(vocab, true, true))
                }
                Self::Reformer => {
                    let vocab_path = download_file_to_cache(
                        "https://cdn.huggingface.co/google/reformer-crime-and-punishment/spiece.model",
                    );

                    Box::new(ReformerTokenizer::from_file(vocab_path, false).unwrap())
                }
                Self::Roberta => {
                    let vocab_path = download_file_to_cache(
                        "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json",
                    );
                    let merges_path = download_file_to_cache(
                        "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt",
                    );

                    Box::new(
                        RobertaTokenizer::from_file(vocab_path, merges_path, false, true).unwrap(),
                    )
                }
                Self::SentencePieceBpe => {
                    let vocab_path = download_file_to_cache(
                        "https://huggingface.co/facebook/m2m100_418M/resolve/main/sentencepiece.bpe.model",
                    );

                    Box::new(SentencePieceBpeTokenizer::from_file(vocab_path, false).unwrap())
                }

                Self::SentencePiece => {
                    let vocab_path = download_file_to_cache(
                        "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-spiece.model",
                    );

                    Box::new(SentencePieceTokenizer::from_file(vocab_path, false).unwrap())
                }
                Self::T5 => {
                    let vocab_path = download_file_to_cache(
                        "https://huggingface.co/t5-base/resolve/main/spiece.model",
                    );

                    Box::new(T5Tokenizer::from_file(vocab_path, false).unwrap())
                }
                Self::XLMRoberta => {
                    let vocab_path = download_file_to_cache("https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll03-english-sentencepiece.bpe.model");

                    Box::new(XLMRobertaTokenizer::from_file(vocab_path, false).unwrap())
                }
                Self::XLNet => {
                    let vocab_path = download_file_to_cache(
                        "https://cdn.huggingface.co/xlnet-base-cased-spiece.model",
                    );

                    Box::new(XLNetTokenizer::from_file(vocab_path, false, true).unwrap())
                }
            }
        }
    }
}