//! Parallel and batched tokenization utilities (`trustformers_tokenizers/parallel.rs`).
1use scirs2_core::parallel_ops::*; // SciRS2 Integration Policy - replaces rayon
2use std::sync::Arc;
3use trustformers_core::errors::Result;
4use trustformers_core::traits::{TokenizedInput, Tokenizer};
5
/// Parallel batch tokenization utilities for improved throughput.
///
/// Wraps a tokenizer in an [`Arc`] so batches can be split into chunks and
/// encoded on worker threads. Trait bounds live on the `impl` block rather
/// than the struct definition (idiomatic Rust; keeps the type itself
/// unconstrained and is strictly backward-compatible).
pub struct ParallelTokenizer<T> {
    /// Shared tokenizer; `Arc` lets parallel workers borrow it cheaply.
    tokenizer: Arc<T>,
    /// Number of inputs handed to each parallel worker per chunk.
    chunk_size: usize,
}
11
12impl<T: Tokenizer + Sync> ParallelTokenizer<T> {
13    /// Create a new parallel tokenizer wrapper
14    pub fn new(tokenizer: T) -> Self {
15        Self {
16            tokenizer: Arc::new(tokenizer),
17            chunk_size: 1000, // Default chunk size
18        }
19    }
20
21    /// Create a new parallel tokenizer wrapper with custom chunk size
22    pub fn with_chunk_size(tokenizer: T, chunk_size: usize) -> Self {
23        Self {
24            tokenizer: Arc::new(tokenizer),
25            chunk_size,
26        }
27    }
28
29    /// Encode a batch of texts in parallel
30    pub fn encode_batch(&self, texts: &[&str]) -> Result<Vec<TokenizedInput>> {
31        texts
32            .par_chunks(self.chunk_size)
33            .map(|chunk| {
34                chunk.iter().map(|text| self.tokenizer.encode(text)).collect::<Result<Vec<_>>>()
35            })
36            .collect::<Result<Vec<Vec<_>>>>()
37            .map(|batches| batches.into_iter().flatten().collect())
38    }
39
40    /// Encode pairs of texts in parallel
41    pub fn encode_pair_batch(&self, text_pairs: &[(&str, &str)]) -> Result<Vec<TokenizedInput>> {
42        text_pairs
43            .par_chunks(self.chunk_size)
44            .map(|chunk| {
45                chunk
46                    .iter()
47                    .map(|(text1, text2)| self.tokenizer.encode_pair(text1, text2))
48                    .collect::<Result<Vec<_>>>()
49            })
50            .collect::<Result<Vec<Vec<_>>>>()
51            .map(|batches| batches.into_iter().flatten().collect())
52    }
53
54    /// Decode a batch of token IDs in parallel
55    pub fn decode_batch(&self, ids_batch: &[&[u32]]) -> Result<Vec<String>> {
56        ids_batch
57            .par_chunks(self.chunk_size)
58            .map(|chunk| {
59                chunk.iter().map(|ids| self.tokenizer.decode(ids)).collect::<Result<Vec<_>>>()
60            })
61            .collect::<Result<Vec<Vec<_>>>>()
62            .map(|batches| batches.into_iter().flatten().collect())
63    }
64
65    /// Get the underlying tokenizer
66    pub fn tokenizer(&self) -> &T {
67        &self.tokenizer
68    }
69
70    /// Get the chunk size
71    pub fn chunk_size(&self) -> usize {
72        self.chunk_size
73    }
74
75    /// Set the chunk size for batching
76    pub fn set_chunk_size(&mut self, chunk_size: usize) {
77        self.chunk_size = chunk_size;
78    }
79}
80
/// Batch tokenization with padding and truncation support.
///
/// Bounds are left off the struct definition (they belong on the impls), and
/// `Clone` is implemented manually: a derived `Clone` would demand `T: Clone`
/// even though cloning only needs to bump the `Arc` refcount.
#[derive(Debug)]
pub struct BatchTokenizer<T> {
    /// Shared underlying tokenizer.
    tokenizer: Arc<T>,
    /// Optional maximum sequence length for truncation/padding.
    max_length: Option<usize>,
    /// Whether padding is applied in `encode_batch_padded`.
    padding: bool,
    /// Whether truncation is applied (only effective with `max_length`).
    truncation: bool,
    /// Token ID used to pad `input_ids`.
    pad_token_id: u32,
}

impl<T> Clone for BatchTokenizer<T> {
    fn clone(&self) -> Self {
        Self {
            // Refcount bump only — no `T: Clone` required.
            tokenizer: Arc::clone(&self.tokenizer),
            max_length: self.max_length,
            padding: self.padding,
            truncation: self.truncation,
            pad_token_id: self.pad_token_id,
        }
    }
}
90
91impl<T: Tokenizer + Sync> BatchTokenizer<T> {
92    /// Create a new batch tokenizer
93    pub fn new(tokenizer: T) -> Self {
94        Self {
95            tokenizer: Arc::new(tokenizer),
96            max_length: None,
97            padding: false,
98            truncation: false,
99            pad_token_id: 0, // Default pad token ID
100        }
101    }
102
103    /// Set the maximum sequence length
104    pub fn with_max_length(mut self, max_length: usize) -> Self {
105        self.max_length = Some(max_length);
106        self
107    }
108
109    /// Enable padding to max length
110    pub fn with_padding(mut self, pad_token_id: u32) -> Self {
111        self.padding = true;
112        self.pad_token_id = pad_token_id;
113        self
114    }
115
116    /// Enable truncation to max length
117    pub fn with_truncation(mut self) -> Self {
118        self.truncation = true;
119        self
120    }
121
122    /// Encode a batch with padding and truncation
123    pub fn encode_batch_padded(&self, texts: &[&str]) -> Result<BatchedTokenizedInput> {
124        // First, encode all texts in parallel
125        let encoded: Vec<TokenizedInput> = texts
126            .par_iter()
127            .map(|text| self.tokenizer.encode(text))
128            .collect::<Result<Vec<_>>>()?;
129
130        // Apply truncation if enabled
131        let mut processed = if let (true, Some(max_len)) = (self.truncation, self.max_length) {
132            encoded
133                .into_iter()
134                .map(|mut input| {
135                    if input.input_ids.len() > max_len {
136                        input.input_ids.truncate(max_len);
137                        input.attention_mask.truncate(max_len);
138                        if let Some(ref mut type_ids) = input.token_type_ids {
139                            type_ids.truncate(max_len);
140                        }
141                    }
142                    input
143                })
144                .collect()
145        } else {
146            encoded
147        };
148
149        // Apply padding if enabled
150        if self.padding {
151            let max_len = if let Some(max_len) = self.max_length {
152                max_len
153            } else {
154                processed.iter().map(|input| input.input_ids.len()).max().unwrap_or(0)
155            };
156
157            for input in &mut processed {
158                let current_len = input.input_ids.len();
159                if current_len < max_len {
160                    let pad_len = max_len - current_len;
161                    input.input_ids.extend(vec![self.pad_token_id; pad_len]);
162                    input.attention_mask.extend(vec![0u8; pad_len]);
163                    if let Some(ref mut type_ids) = input.token_type_ids {
164                        type_ids.extend(vec![0u32; pad_len]);
165                    }
166                }
167            }
168        }
169
170        Ok(BatchedTokenizedInput::from_batch(processed))
171    }
172
173    /// Get the underlying tokenizer
174    pub fn tokenizer(&self) -> &T {
175        &self.tokenizer
176    }
177}
178
/// Batched tokenized input with convenient access methods.
///
/// Column-oriented view of a batch: each outer `Vec` has one entry per sample.
#[derive(Debug, Clone)]
pub struct BatchedTokenizedInput {
    /// Token IDs for each sample.
    pub input_ids: Vec<Vec<u32>>,
    /// Attention mask for each sample (padding positions are filled with 0).
    pub attention_mask: Vec<Vec<u8>>,
    /// Token type IDs per sample; `Some` when any input in the batch had them.
    pub token_type_ids: Option<Vec<Vec<u32>>>,
}
186
187impl BatchedTokenizedInput {
188    /// Create from a batch of TokenizedInput
189    pub fn from_batch(batch: Vec<TokenizedInput>) -> Self {
190        let mut input_ids = Vec::with_capacity(batch.len());
191        let mut attention_mask = Vec::with_capacity(batch.len());
192        let mut token_type_ids = Vec::with_capacity(batch.len());
193
194        let has_token_type_ids = batch.iter().any(|input| input.token_type_ids.is_some());
195
196        for input in batch {
197            input_ids.push(input.input_ids);
198            attention_mask.push(input.attention_mask);
199            if has_token_type_ids {
200                token_type_ids.push(input.token_type_ids.unwrap_or_default());
201            }
202        }
203
204        Self {
205            input_ids,
206            attention_mask,
207            token_type_ids: if has_token_type_ids { Some(token_type_ids) } else { None },
208        }
209    }
210
211    /// Get the batch size
212    pub fn batch_size(&self) -> usize {
213        self.input_ids.len()
214    }
215
216    /// Get the sequence length for each sample
217    pub fn sequence_lengths(&self) -> Vec<usize> {
218        self.input_ids.iter().map(|ids| ids.len()).collect()
219    }
220
221    /// Convert to individual TokenizedInput items
222    pub fn to_individual(self) -> Vec<TokenizedInput> {
223        let mut result = Vec::with_capacity(self.input_ids.len());
224
225        for i in 0..self.input_ids.len() {
226            let token_type_ids = self.token_type_ids.as_ref().map(|types| types[i].clone());
227
228            result.push(TokenizedInput {
229                input_ids: self.input_ids[i].clone(),
230                attention_mask: self.attention_mask[i].clone(),
231                token_type_ids,
232                special_tokens_mask: None,
233                offset_mapping: None,
234                overflowing_tokens: None,
235            });
236        }
237
238        result
239    }
240
241    /// Get input IDs as a flat tensor-like structure
242    pub fn input_ids_tensor(&self) -> &Vec<Vec<u32>> {
243        &self.input_ids
244    }
245
246    /// Get attention mask as a flat tensor-like structure
247    pub fn attention_mask_tensor(&self) -> &Vec<Vec<u8>> {
248        &self.attention_mask
249    }
250
251    /// Get token type IDs as a flat tensor-like structure
252    pub fn token_type_ids_tensor(&self) -> Option<&Vec<Vec<u32>>> {
253        self.token_type_ids.as_ref()
254    }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use crate::char::CharTokenizer;
    use std::collections::HashMap;

    /// Build a tiny character-level tokenizer shared by the tests below.
    fn create_test_tokenizer() -> CharTokenizer {
        let entries = [
            ("a", 0),
            ("b", 1),
            ("c", 2),
            (" ", 3),
            ("[UNK]", 4),
            ("[PAD]", 5),
            ("[CLS]", 6),
            ("[SEP]", 7),
        ];
        let mut vocab = HashMap::new();
        for (token, id) in entries {
            vocab.insert(token.to_string(), id);
        }
        CharTokenizer::new(vocab)
    }

    #[test]
    fn test_parallel_tokenizer() {
        let pt = ParallelTokenizer::new(create_test_tokenizer());
        let inputs = ["hello world", "goodbye world", "test text"];

        let encoded = pt.encode_batch(&inputs).expect("Operation failed in test");

        assert_eq!(encoded.len(), 3);
        for item in encoded {
            assert!(!item.input_ids.is_empty());
            assert!(!item.attention_mask.is_empty());
        }
    }

    #[test]
    fn test_parallel_encode_pairs() {
        let pt = ParallelTokenizer::new(create_test_tokenizer());
        let pairs = [("hello", "world"), ("good", "bye"), ("test", "text")];

        let encoded = pt.encode_pair_batch(&pairs).expect("Operation failed in test");

        assert_eq!(encoded.len(), 3);
        for item in encoded {
            assert!(!item.input_ids.is_empty());
            assert!(!item.attention_mask.is_empty());
        }
    }

    #[test]
    fn test_batch_tokenizer_with_padding() {
        let bt = BatchTokenizer::new(create_test_tokenizer())
            .with_max_length(10)
            .with_padding(0)
            .with_truncation();

        let samples = ["short", "this is a longer text", "medium"];
        let batch = bt.encode_batch_padded(&samples).expect("Operation failed in test");

        assert_eq!(batch.batch_size(), 3);
        // Padding/truncation should normalize every sequence to length 10.
        assert!(batch.sequence_lengths().into_iter().all(|len| len == 10));
    }

    #[test]
    fn test_batched_tokenized_input() {
        // Helper closure keeps the two fixture inputs concise.
        let make = |ids: Vec<u32>, types: Vec<u32>| TokenizedInput {
            attention_mask: vec![1; ids.len()],
            input_ids: ids,
            token_type_ids: Some(types),
            special_tokens_mask: None,
            offset_mapping: None,
            overflowing_tokens: None,
        };

        let batched = BatchedTokenizedInput::from_batch(vec![
            make(vec![1, 2, 3], vec![0, 0, 0]),
            make(vec![4, 5], vec![1, 1]),
        ]);

        assert_eq!(batched.batch_size(), 2);
        assert_eq!(batched.sequence_lengths(), vec![3, 2]);
        assert!(batched.token_type_ids.is_some());

        // Round-trip back to individual inputs.
        let individual = batched.to_individual();
        assert_eq!(individual.len(), 2);
        assert_eq!(individual[0].input_ids, vec![1, 2, 3]);
        assert_eq!(individual[1].input_ids, vec![4, 5]);
    }

    #[test]
    fn test_parallel_decode_batch() {
        let pt = ParallelTokenizer::new(create_test_tokenizer());
        let first = [0u32, 1, 2]; // a, b, c
        let second = [3u32, 0]; // space, a
        let batch: Vec<&[u32]> = vec![&first, &second];

        let decoded = pt.decode_batch(&batch).expect("Operation failed in test");

        assert_eq!(decoded.len(), 2);
        assert!(decoded.iter().all(|text| !text.is_empty()));
    }

    #[test]
    fn test_chunk_size_configuration() {
        let mut pt = ParallelTokenizer::with_chunk_size(create_test_tokenizer(), 500);
        assert_eq!(pt.chunk_size(), 500);

        pt.set_chunk_size(1000);
        assert_eq!(pt.chunk_size(), 1000);
    }
}