text_splitter/chunk_size/rust_tokenizers.rs

use rust_tokenizers::{
    tokenizer::{
        AlbertTokenizer, BaseTokenizer, BertTokenizer, CtrlTokenizer, DeBERTaTokenizer,
        DeBERTaV2Tokenizer, FNetTokenizer, Gpt2Tokenizer, M2M100Tokenizer, MBart50Tokenizer,
        MarianTokenizer, NLLBTokenizer, OpenAiGptTokenizer, PegasusTokenizer, ProphetNetTokenizer,
        ReformerTokenizer, RobertaTokenizer, SentencePieceBpeTokenizer, SentencePieceTokenizer,
        T5Tokenizer, Tokenizer, XLMRobertaTokenizer, XLNetTokenizer,
    },
    vocab::Vocab,
};

use crate::ChunkSizer;

/// Number of tokens the given tokenizer produces for `chunk`.
fn chunk_size_from_offsets<V: Vocab, T: Tokenizer<V>>(tokenizer: &T, chunk: &str) -> usize {
    tokenizer.tokenize(chunk).len()
}

// `BaseTokenizer` is generic over its vocab type, so it gets its own generic
// impls rather than going through the macro below.
impl<V> ChunkSizer for &BaseTokenizer<V>
where
    V: Vocab + Sync + Send,
{
    fn size(&self, chunk: &str) -> usize {
        chunk_size_from_offsets(*self, chunk)
    }
}

impl<V> ChunkSizer for BaseTokenizer<V>
where
    V: Vocab + Sync + Send,
{
    fn size(&self, chunk: &str) -> usize {
        (&self).size(chunk)
    }
}

// Implements `ChunkSizer` for each listed concrete tokenizer type, both by
// reference and by value.
macro_rules! impl_chunk_sizer {
    ($($t:ty),+) => {
        $(impl ChunkSizer for &$t {
            fn size(&self, chunk: &str) -> usize {
                chunk_size_from_offsets(*self, chunk)
            }
        }

        impl ChunkSizer for $t {
            fn size(&self, chunk: &str) -> usize {
                (&self).size(chunk)
            }
        })+
    }
}
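
// For reference (not in the upstream file): `impl_chunk_sizer!(BertTokenizer)`
// expands to roughly the following pair of impls, mirroring the hand-written
// `BaseTokenizer` impls above:
//
//     impl ChunkSizer for &BertTokenizer {
//         fn size(&self, chunk: &str) -> usize {
//             chunk_size_from_offsets(*self, chunk)
//         }
//     }
//
//     impl ChunkSizer for BertTokenizer {
//         fn size(&self, chunk: &str) -> usize {
//             (&self).size(chunk)
//         }
//     }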

impl_chunk_sizer!(
    AlbertTokenizer,
    BertTokenizer,
    CtrlTokenizer,
    DeBERTaTokenizer,
    DeBERTaV2Tokenizer,
    FNetTokenizer,
    Gpt2Tokenizer,
    M2M100Tokenizer,
    MBart50Tokenizer,
    MarianTokenizer,
    NLLBTokenizer,
    OpenAiGptTokenizer,
    PegasusTokenizer,
    ProphetNetTokenizer,
    ReformerTokenizer,
    RobertaTokenizer,
    SentencePieceBpeTokenizer,
    SentencePieceTokenizer,
    T5Tokenizer,
    XLMRobertaTokenizer,
    XLNetTokenizer
);
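
// Usage sketch (illustrative, not part of the upstream file): thanks to the
// impls generated above, any of these tokenizers can be asked for the token
// count of a candidate chunk directly, which is the measurement text-splitter
// uses when sizing chunks. This mirrors the `returns_offsets` test below:
//
//     let tokenizer = BertTokenizer::from_file(vocab_path, false, false)?;
//     let tokens = tokenizer.size(" An apple a"); // 3 tokens with bert-base-uncased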

#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use cached_path::Cache;
    use rayon::prelude::*;
    use rust_tokenizers::vocab::{BertVocab, BpePairVocab, Gpt2Vocab, ProphetNetVocab};
    use strum::{EnumIter, IntoEnumIterator};

    use super::*;

    /// Downloads the file at `src` into `~/.cache/.text-splitter` (or reuses a
    /// previously cached copy) and returns its local path.
    fn download_file_to_cache(src: &str) -> PathBuf {
        let mut cache_dir = dirs::home_dir().unwrap();
        cache_dir.push(".cache");
        cache_dir.push(".text-splitter");

        Cache::builder()
            .dir(cache_dir)
            .build()
            .unwrap()
            .cached_path(src)
            .unwrap()
    }

    #[test]
    fn returns_offsets() {
        let vocab_path = download_file_to_cache(
            "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
        );
        let tokenizer = BertTokenizer::from_file(vocab_path, false, false).unwrap();
        let size = tokenizer.size(" An apple a");
        assert_eq!(size, 3);
    }

    #[test]
    fn smoke_test() {
        let sizes = TokenizerOption::iter()
            .collect::<Vec<_>>()
            .into_par_iter()
            .map(|tokenizer| tokenizer.tokenizer().size(" An apple a"));
        assert!(sizes.all(|size| size > 0));
    }

    #[derive(EnumIter)]
    enum TokenizerOption {
        Albert,
        Base,
        Bert,
        Ctrl,
        DeBERTa,
        DeBERTaV2,
        FNet,
        Gpt2,
        M2M100,
        MBart50,
        Nllb,
        OpenAiGpt,
        Pegasus,
        ProphetNet,
        Reformer,
        Roberta,
        SentencePieceBpe,
        SentencePiece,
        T5,
        XLMRoberta,
        XLNet,
    }

    impl TokenizerOption {
        /// Instantiates the tokenizer for this variant, downloading any vocab
        /// or merges files it needs into the local cache.
        #[allow(clippy::too_many_lines)]
        fn tokenizer(&self) -> Box<dyn ChunkSizer> {
            match self {
                Self::Albert => {
                    let vocab_path = download_file_to_cache(
                        "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-spiece.model",
                    );
                    Box::new(AlbertTokenizer::from_file(vocab_path, false, false).unwrap())
                }
                Self::Base => {
                    let vocab_path = download_file_to_cache(
                        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
                    );
                    let vocab = BertVocab::from_file(vocab_path).unwrap();
                    Box::new(BaseTokenizer::from_existing_vocab(vocab, false, false))
                }
                Self::Bert => {
                    let vocab_path = download_file_to_cache(
                        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
                    );
                    Box::new(BertTokenizer::from_file(vocab_path, false, false).unwrap())
                }
                Self::Ctrl => {
                    let vocab_path = download_file_to_cache(
                        "https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-vocab.json",
                    );
                    let merges_path = download_file_to_cache(
                        "https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-merges.txt",
                    );
                    Box::new(CtrlTokenizer::from_file(vocab_path, merges_path, false).unwrap())
                }
                Self::DeBERTa => {
                    let vocab_path = download_file_to_cache(
                        "https://huggingface.co/microsoft/deberta-base/resolve/main/vocab.json",
                    );
                    let merges_path = download_file_to_cache(
                        "https://huggingface.co/microsoft/deberta-base/resolve/main/merges.txt",
                    );
                    Box::new(DeBERTaTokenizer::from_file(vocab_path, merges_path, false).unwrap())
                }
                Self::DeBERTaV2 => {
                    let vocab_path = download_file_to_cache(
                        "https://huggingface.co/microsoft/deberta-v3-base/resolve/main/spm.model",
                    );
                    Box::new(
                        DeBERTaV2Tokenizer::from_file(vocab_path, false, false, false).unwrap(),
                    )
                }
                Self::FNet => {
                    let vocab_path = download_file_to_cache(
                        "https://huggingface.co/google/fnet-base/resolve/main/spiece.model",
                    );
                    Box::new(FNetTokenizer::from_file(vocab_path, false, false).unwrap())
                }
                Self::Gpt2 => {
                    let vocab_path = download_file_to_cache(
                        "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
                    );
                    let merges_path = download_file_to_cache(
                        "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
                    );
                    let vocab = Gpt2Vocab::from_file(vocab_path.as_path()).unwrap();
                    let merges = BpePairVocab::from_file(merges_path.as_path()).unwrap();

                    Box::new(Gpt2Tokenizer::from_existing_vocab_and_merges(
                        vocab, merges, false,
                    ))
                }
                Self::M2M100 => {
                    let vocab_path = download_file_to_cache(
                        "https://huggingface.co/facebook/m2m100_418M/resolve/main/vocab.json",
                    );
                    let merges_path = download_file_to_cache(
                        "https://huggingface.co/facebook/m2m100_418M/resolve/main/sentencepiece.bpe.model",
                    );

                    Box::new(M2M100Tokenizer::from_files(vocab_path, merges_path, false).unwrap())
                }
                Self::MBart50 => {
                    let vocab_path = download_file_to_cache(
                        "https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/sentencepiece.bpe.model",
                    );

                    Box::new(MBart50Tokenizer::from_file(vocab_path, false).unwrap())
                }
                Self::Nllb => {
                    let vocab_path = download_file_to_cache(
                        "https://huggingface.co/facebook/nllb-200-distilled-600M/resolve/main/tokenizer.json",
                    );
                    let merges_path = download_file_to_cache(
                        "https://huggingface.co/facebook/nllb-200-distilled-600M/resolve/main/sentencepiece.bpe.model",
                    );
                    let special_path = download_file_to_cache(
                        "https://huggingface.co/facebook/nllb-200-distilled-600M/raw/main/special_tokens_map.json",
                    );

                    Box::new(
                        NLLBTokenizer::from_files_with_special_token_map(
                            vocab_path,
                            merges_path,
                            special_path,
                        )
                        .unwrap(),
                    )
                }
                Self::OpenAiGpt => {
                    let vocab_path = download_file_to_cache(
                        "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
                    );
                    let merges_path = download_file_to_cache(
                        "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
                    );

                    Box::new(OpenAiGptTokenizer::from_file(vocab_path, merges_path, true).unwrap())
                }
                Self::Pegasus => {
                    let vocab_path = download_file_to_cache(
                        "https://cdn.huggingface.co/google/pegasus-cnn_dailymail/spiece.model",
                    );

                    Box::new(PegasusTokenizer::from_file(vocab_path, false).unwrap())
                }
                Self::ProphetNet => {
                    let vocab_path = download_file_to_cache(
                        "https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/prophetnet.tokenizer",
                    );
                    let vocab = ProphetNetVocab::from_file(vocab_path).unwrap();

                    Box::new(ProphetNetTokenizer::from_existing_vocab(vocab, true, true))
                }
                Self::Reformer => {
                    let vocab_path = download_file_to_cache(
                        "https://cdn.huggingface.co/google/reformer-crime-and-punishment/spiece.model",
                    );

                    Box::new(ReformerTokenizer::from_file(vocab_path, false).unwrap())
                }
                Self::Roberta => {
                    let vocab_path = download_file_to_cache(
                        "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json",
                    );
                    let merges_path = download_file_to_cache(
                        "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt",
                    );

                    Box::new(
                        RobertaTokenizer::from_file(vocab_path, merges_path, false, true).unwrap(),
                    )
                }
                Self::SentencePieceBpe => {
                    let vocab_path = download_file_to_cache(
                        "https://huggingface.co/facebook/m2m100_418M/resolve/main/sentencepiece.bpe.model",
                    );

                    Box::new(SentencePieceBpeTokenizer::from_file(vocab_path, false).unwrap())
                }
                Self::SentencePiece => {
                    let vocab_path = download_file_to_cache(
                        "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-spiece.model",
                    );

                    Box::new(SentencePieceTokenizer::from_file(vocab_path, false).unwrap())
                }
                Self::T5 => {
                    let vocab_path = download_file_to_cache(
                        "https://huggingface.co/t5-base/resolve/main/spiece.model",
                    );

                    Box::new(T5Tokenizer::from_file(vocab_path, false).unwrap())
                }
                Self::XLMRoberta => {
                    let vocab_path = download_file_to_cache("https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll03-english-sentencepiece.bpe.model");

                    Box::new(XLMRobertaTokenizer::from_file(vocab_path, false).unwrap())
                }
                Self::XLNet => {
                    let vocab_path = download_file_to_cache(
                        "https://cdn.huggingface.co/xlnet-base-cased-spiece.model",
                    );

                    Box::new(XLNetTokenizer::from_file(vocab_path, false, true).unwrap())
                }
            }
        }
    }
}