1use scirs2_core::parallel_ops::*; use std::sync::Arc;
3use trustformers_core::errors::Result;
4use trustformers_core::traits::{TokenizedInput, Tokenizer};
5
/// Wraps a [`Tokenizer`] and runs batch encode/decode operations in
/// parallel via `par_chunks` (imported from `scirs2_core::parallel_ops`,
/// presumably a rayon-style parallel iterator — confirm against that crate).
pub struct ParallelTokenizer<T: Tokenizer + Sync> {
    // Shared across worker threads — hence `Arc` and the `Sync` bound.
    tokenizer: Arc<T>,
    // Number of items handed to each parallel worker (1000 by default, see `new`).
    chunk_size: usize,
}
11
12impl<T: Tokenizer + Sync> ParallelTokenizer<T> {
13 pub fn new(tokenizer: T) -> Self {
15 Self {
16 tokenizer: Arc::new(tokenizer),
17 chunk_size: 1000, }
19 }
20
21 pub fn with_chunk_size(tokenizer: T, chunk_size: usize) -> Self {
23 Self {
24 tokenizer: Arc::new(tokenizer),
25 chunk_size,
26 }
27 }
28
29 pub fn encode_batch(&self, texts: &[&str]) -> Result<Vec<TokenizedInput>> {
31 texts
32 .par_chunks(self.chunk_size)
33 .map(|chunk| {
34 chunk.iter().map(|text| self.tokenizer.encode(text)).collect::<Result<Vec<_>>>()
35 })
36 .collect::<Result<Vec<Vec<_>>>>()
37 .map(|batches| batches.into_iter().flatten().collect())
38 }
39
40 pub fn encode_pair_batch(&self, text_pairs: &[(&str, &str)]) -> Result<Vec<TokenizedInput>> {
42 text_pairs
43 .par_chunks(self.chunk_size)
44 .map(|chunk| {
45 chunk
46 .iter()
47 .map(|(text1, text2)| self.tokenizer.encode_pair(text1, text2))
48 .collect::<Result<Vec<_>>>()
49 })
50 .collect::<Result<Vec<Vec<_>>>>()
51 .map(|batches| batches.into_iter().flatten().collect())
52 }
53
54 pub fn decode_batch(&self, ids_batch: &[&[u32]]) -> Result<Vec<String>> {
56 ids_batch
57 .par_chunks(self.chunk_size)
58 .map(|chunk| {
59 chunk.iter().map(|ids| self.tokenizer.decode(ids)).collect::<Result<Vec<_>>>()
60 })
61 .collect::<Result<Vec<Vec<_>>>>()
62 .map(|batches| batches.into_iter().flatten().collect())
63 }
64
65 pub fn tokenizer(&self) -> &T {
67 &self.tokenizer
68 }
69
70 pub fn chunk_size(&self) -> usize {
72 self.chunk_size
73 }
74
75 pub fn set_chunk_size(&mut self, chunk_size: usize) {
77 self.chunk_size = chunk_size;
78 }
79}
80
/// Batch tokenizer with optional truncation and padding, configured via
/// builder-style `with_*` methods.
#[derive(Debug, Clone)]
pub struct BatchTokenizer<T: Tokenizer + Sync> {
    // Shared tokenizer; batch encoding runs over a parallel iterator.
    tokenizer: Arc<T>,
    // Target length for truncation/padding. `None` means "pad to the
    // longest sequence in the batch" and disables truncation.
    max_length: Option<usize>,
    // When true, shorter sequences are padded with `pad_token_id`.
    padding: bool,
    // When true (and `max_length` is set), longer sequences are truncated.
    truncation: bool,
    // Id appended to `input_ids` when padding; attention-mask padding is
    // always 0. Defaults to 0 until `with_padding` overrides it.
    pad_token_id: u32,
}
90
91impl<T: Tokenizer + Sync> BatchTokenizer<T> {
92 pub fn new(tokenizer: T) -> Self {
94 Self {
95 tokenizer: Arc::new(tokenizer),
96 max_length: None,
97 padding: false,
98 truncation: false,
99 pad_token_id: 0, }
101 }
102
103 pub fn with_max_length(mut self, max_length: usize) -> Self {
105 self.max_length = Some(max_length);
106 self
107 }
108
109 pub fn with_padding(mut self, pad_token_id: u32) -> Self {
111 self.padding = true;
112 self.pad_token_id = pad_token_id;
113 self
114 }
115
116 pub fn with_truncation(mut self) -> Self {
118 self.truncation = true;
119 self
120 }
121
122 pub fn encode_batch_padded(&self, texts: &[&str]) -> Result<BatchedTokenizedInput> {
124 let encoded: Vec<TokenizedInput> = texts
126 .par_iter()
127 .map(|text| self.tokenizer.encode(text))
128 .collect::<Result<Vec<_>>>()?;
129
130 let mut processed = if let (true, Some(max_len)) = (self.truncation, self.max_length) {
132 encoded
133 .into_iter()
134 .map(|mut input| {
135 if input.input_ids.len() > max_len {
136 input.input_ids.truncate(max_len);
137 input.attention_mask.truncate(max_len);
138 if let Some(ref mut type_ids) = input.token_type_ids {
139 type_ids.truncate(max_len);
140 }
141 }
142 input
143 })
144 .collect()
145 } else {
146 encoded
147 };
148
149 if self.padding {
151 let max_len = if let Some(max_len) = self.max_length {
152 max_len
153 } else {
154 processed.iter().map(|input| input.input_ids.len()).max().unwrap_or(0)
155 };
156
157 for input in &mut processed {
158 let current_len = input.input_ids.len();
159 if current_len < max_len {
160 let pad_len = max_len - current_len;
161 input.input_ids.extend(vec![self.pad_token_id; pad_len]);
162 input.attention_mask.extend(vec![0u8; pad_len]);
163 if let Some(ref mut type_ids) = input.token_type_ids {
164 type_ids.extend(vec![0u32; pad_len]);
165 }
166 }
167 }
168 }
169
170 Ok(BatchedTokenizedInput::from_batch(processed))
171 }
172
173 pub fn tokenizer(&self) -> &T {
175 &self.tokenizer
176 }
177}
178
/// Column-wise view of a batch of tokenized sequences: one outer `Vec`
/// per field, all indexed by position in the batch.
#[derive(Debug, Clone)]
pub struct BatchedTokenizedInput {
    // Token ids per sequence; rows may differ in length unless padded.
    pub input_ids: Vec<Vec<u32>>,
    // Padding positions are 0 (padding appends 0s); presumably 1 marks
    // real tokens — verify against the tokenizer that produces the masks.
    pub attention_mask: Vec<Vec<u8>>,
    // Present iff any input in the batch carried token type ids.
    pub token_type_ids: Option<Vec<Vec<u32>>>,
}
186
187impl BatchedTokenizedInput {
188 pub fn from_batch(batch: Vec<TokenizedInput>) -> Self {
190 let mut input_ids = Vec::with_capacity(batch.len());
191 let mut attention_mask = Vec::with_capacity(batch.len());
192 let mut token_type_ids = Vec::with_capacity(batch.len());
193
194 let has_token_type_ids = batch.iter().any(|input| input.token_type_ids.is_some());
195
196 for input in batch {
197 input_ids.push(input.input_ids);
198 attention_mask.push(input.attention_mask);
199 if has_token_type_ids {
200 token_type_ids.push(input.token_type_ids.unwrap_or_default());
201 }
202 }
203
204 Self {
205 input_ids,
206 attention_mask,
207 token_type_ids: if has_token_type_ids { Some(token_type_ids) } else { None },
208 }
209 }
210
211 pub fn batch_size(&self) -> usize {
213 self.input_ids.len()
214 }
215
216 pub fn sequence_lengths(&self) -> Vec<usize> {
218 self.input_ids.iter().map(|ids| ids.len()).collect()
219 }
220
221 pub fn to_individual(self) -> Vec<TokenizedInput> {
223 let mut result = Vec::with_capacity(self.input_ids.len());
224
225 for i in 0..self.input_ids.len() {
226 let token_type_ids = self.token_type_ids.as_ref().map(|types| types[i].clone());
227
228 result.push(TokenizedInput {
229 input_ids: self.input_ids[i].clone(),
230 attention_mask: self.attention_mask[i].clone(),
231 token_type_ids,
232 special_tokens_mask: None,
233 offset_mapping: None,
234 overflowing_tokens: None,
235 });
236 }
237
238 result
239 }
240
241 pub fn input_ids_tensor(&self) -> &Vec<Vec<u32>> {
243 &self.input_ids
244 }
245
246 pub fn attention_mask_tensor(&self) -> &Vec<Vec<u8>> {
248 &self.attention_mask
249 }
250
251 pub fn token_type_ids_tensor(&self) -> Option<&Vec<Vec<u32>>> {
253 self.token_type_ids.as_ref()
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use crate::char::CharTokenizer;
    use std::collections::HashMap;

    /// Builds a tiny character tokenizer over `a`, `b`, `c`, space, and
    /// the usual special tokens.
    fn create_test_tokenizer() -> CharTokenizer {
        let mut vocab = HashMap::new();
        vocab.extend([
            ("a".to_string(), 0),
            ("b".to_string(), 1),
            ("c".to_string(), 2),
            (" ".to_string(), 3),
            ("[UNK]".to_string(), 4),
            ("[PAD]".to_string(), 5),
            ("[CLS]".to_string(), 6),
            ("[SEP]".to_string(), 7),
        ]);
        CharTokenizer::new(vocab)
    }

    #[test]
    fn test_parallel_tokenizer() {
        let parallel = ParallelTokenizer::new(create_test_tokenizer());

        let inputs = ["hello world", "goodbye world", "test text"];
        let encoded = parallel.encode_batch(&inputs).expect("Operation failed in test");

        assert_eq!(encoded.len(), 3);
        for item in &encoded {
            assert!(!item.input_ids.is_empty());
            assert!(!item.attention_mask.is_empty());
        }
    }

    #[test]
    fn test_parallel_encode_pairs() {
        let parallel = ParallelTokenizer::new(create_test_tokenizer());

        let inputs = [("hello", "world"), ("good", "bye"), ("test", "text")];
        let encoded = parallel.encode_pair_batch(&inputs).expect("Operation failed in test");

        assert_eq!(encoded.len(), 3);
        for item in &encoded {
            assert!(!item.input_ids.is_empty());
            assert!(!item.attention_mask.is_empty());
        }
    }

    #[test]
    fn test_batch_tokenizer_with_padding() {
        let batcher = BatchTokenizer::new(create_test_tokenizer())
            .with_max_length(10)
            .with_padding(0)
            .with_truncation();

        let inputs = ["short", "this is a longer text", "medium"];
        let batched = batcher.encode_batch_padded(&inputs).expect("Operation failed in test");

        assert_eq!(batched.batch_size(), 3);
        // Truncation + padding to max_length makes every row exactly 10.
        assert!(batched.sequence_lengths().into_iter().all(|len| len == 10));
    }

    #[test]
    fn test_batched_tokenized_input() {
        let first = TokenizedInput {
            input_ids: vec![1, 2, 3],
            attention_mask: vec![1, 1, 1],
            token_type_ids: Some(vec![0, 0, 0]),
            special_tokens_mask: None,
            offset_mapping: None,
            overflowing_tokens: None,
        };
        let second = TokenizedInput {
            input_ids: vec![4, 5],
            attention_mask: vec![1, 1],
            token_type_ids: Some(vec![1, 1]),
            special_tokens_mask: None,
            offset_mapping: None,
            overflowing_tokens: None,
        };

        let batched = BatchedTokenizedInput::from_batch(vec![first, second]);

        assert_eq!(batched.batch_size(), 2);
        assert_eq!(batched.sequence_lengths(), vec![3, 2]);
        assert!(batched.token_type_ids.is_some());

        let rows = batched.to_individual();
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0].input_ids, vec![1, 2, 3]);
        assert_eq!(rows[1].input_ids, vec![4, 5]);
    }

    #[test]
    fn test_parallel_decode_batch() {
        let parallel = ParallelTokenizer::new(create_test_tokenizer());

        let first_ids = vec![0, 1, 2];
        let second_ids = vec![3, 0];
        let id_rows = [&first_ids[..], &second_ids[..]];

        let decoded = parallel.decode_batch(&id_rows).expect("Operation failed in test");

        assert_eq!(decoded.len(), 2);
        assert!(!decoded[0].is_empty());
        assert!(!decoded[1].is_empty());
    }

    #[test]
    fn test_chunk_size_configuration() {
        let mut parallel = ParallelTokenizer::with_chunk_size(create_test_tokenizer(), 500);
        assert_eq!(parallel.chunk_size(), 500);

        parallel.set_chunk_size(1000);
        assert_eq!(parallel.chunk_size(), 1000);
    }
}