Skip to main content

scirs2_text/
batch_tokenizer.rs

1//! Batch Tokenization with Padding, Truncation, and Attention Masks
2//!
3//! This module provides utilities for tokenizing multiple texts efficiently,
4//! producing padded/truncated sequences with attention masks suitable for
5//! batch inference with transformer models.
6//!
7//! # Example
8//!
9//! ```rust
10//! use scirs2_text::batch_tokenizer::{batch_encode, PaddingStrategy, TruncationStrategy, BatchConfig};
11//! use scirs2_text::tokenizer::{BPETokenizer, TransformerTokenizer};
12//!
13//! let corpus = &["the cat sat on the mat", "the dog sat on the log"];
14//! let tokenizer = BPETokenizer::train(corpus, 100).expect("train failed");
15//!
16//! let texts = &["the cat", "the dog sat"];
17//! let config = BatchConfig {
18//!     max_length: Some(10),
19//!     padding: PaddingStrategy::LongestInBatch,
20//!     truncation: TruncationStrategy::Right,
21//!     pad_token_id: 0,
22//! };
23//! let batch = batch_encode(texts, &tokenizer, &config);
24//! assert_eq!(batch.input_ids.len(), 2);
25//! assert_eq!(batch.attention_mask.len(), 2);
26//! ```
27
28use crate::tokenizer::TransformerTokenizer;
29
30// ---------------------------------------------------------------------------
31// Configuration types
32// ---------------------------------------------------------------------------
33
/// How the sequences of a batch are brought to a common length.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PaddingStrategy {
    /// Leave each sequence at its own length; the batch may be ragged.
    NoPadding,
    /// Pad every sequence up to `max_length` (meant to be used with
    /// `max_length` set).
    MaxLength,
    /// Pad every sequence up to the longest sequence in the batch.
    LongestInBatch,
}
44
/// How a sequence longer than `max_length` is shortened.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TruncationStrategy {
    /// Never shorten; sequences keep their full encoded length.
    NoTruncation,
    /// Drop tokens from the end, keeping the first `max_length`.
    Right,
    /// Drop tokens from the start, keeping the last `max_length`.
    Left,
}
55
/// Side on which to add padding tokens.
///
/// [`PaddingSide::Right`] is the [`Default`], as the variant docs describe.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PaddingSide {
    /// Add padding tokens on the right (default, used by most models).
    #[default]
    Right,
    /// Add padding tokens on the left (used by some decoder-only models).
    Left,
}
64
65/// Configuration for batch encoding.
66#[derive(Debug, Clone)]
67pub struct BatchConfig {
68    /// Maximum sequence length. When `None`, no length limit is imposed
69    /// (but padding to longest may still apply).
70    pub max_length: Option<usize>,
71    /// Padding strategy.
72    pub padding: PaddingStrategy,
73    /// Truncation strategy.
74    pub truncation: TruncationStrategy,
75    /// Token ID to use for padding.
76    pub pad_token_id: u32,
77}
78
79impl Default for BatchConfig {
80    fn default() -> Self {
81        Self {
82            max_length: None,
83            padding: PaddingStrategy::LongestInBatch,
84            truncation: TruncationStrategy::Right,
85            pad_token_id: 0,
86        }
87    }
88}
89
90/// Extended configuration with padding side.
91#[derive(Debug, Clone)]
92pub struct BatchConfigExt {
93    /// Base configuration.
94    pub base: BatchConfig,
95    /// Side on which to add padding.
96    pub padding_side: PaddingSide,
97}
98
99impl Default for BatchConfigExt {
100    fn default() -> Self {
101        Self {
102            base: BatchConfig::default(),
103            padding_side: PaddingSide::Right,
104        }
105    }
106}
107
108// ---------------------------------------------------------------------------
109// BatchEncoding output
110// ---------------------------------------------------------------------------
111
/// Output of batch-encoding a set of texts.
///
/// `input_ids`, `attention_mask` and `lengths` each hold one entry per input
/// text. When padding is enabled, every row of `input_ids` (and of
/// `attention_mask`) has the same length.
#[derive(Clone, Debug)]
pub struct BatchEncoding {
    /// Per-text token ids, possibly truncated and/or padded.
    pub input_ids: Vec<Vec<u32>>,
    /// Per-text mask: `1` marks a real token, `0` a padding position.
    pub attention_mask: Vec<Vec<u32>>,
    /// Per-text token count before any truncation or padding.
    pub lengths: Vec<usize>,
}

impl BatchEncoding {
    /// Number of rows (input texts) in the batch.
    pub fn batch_size(&self) -> usize {
        self.input_ids.len()
    }

    /// Length of the first row of `input_ids`, or `0` for an empty batch.
    ///
    /// With padding enabled every row has this length. Without padding, rows
    /// may differ and only the first row is measured.
    pub fn seq_length(&self) -> usize {
        match self.input_ids.first() {
            Some(row) => row.len(),
            None => 0,
        }
    }

    /// Number of attention-mask entries equal to `1`, summed over all rows.
    pub fn total_real_tokens(&self) -> usize {
        self.attention_mask
            .iter()
            .map(|row| row.iter().filter(|&&bit| bit == 1).count())
            .sum()
    }
}
147
148// ---------------------------------------------------------------------------
149// Batch encoding functions
150// ---------------------------------------------------------------------------
151
152/// Truncate a sequence according to the given strategy and max_length.
153fn truncate(ids: &[u32], strategy: TruncationStrategy, max_length: usize) -> Vec<u32> {
154    if ids.len() <= max_length {
155        return ids.to_vec();
156    }
157    match strategy {
158        TruncationStrategy::NoTruncation => ids.to_vec(),
159        TruncationStrategy::Right => ids[..max_length].to_vec(),
160        TruncationStrategy::Left => ids[ids.len() - max_length..].to_vec(),
161    }
162}
163
/// Right-pad `ids` with `pad_id` to exactly `target_length` tokens.
///
/// Returns `(padded_ids, attention_mask)`, where the mask is `1` for real
/// tokens and `0` for padding. Input longer than `target_length` is cut down
/// to its first `target_length` tokens, all marked real.
fn pad_right(ids: &[u32], target_length: usize, pad_id: u32) -> (Vec<u32>, Vec<u32>) {
    let kept = &ids[..ids.len().min(target_length)];

    let mut padded = kept.to_vec();
    padded.resize(target_length, pad_id);

    let mut mask = vec![1u32; kept.len()];
    mask.resize(target_length, 0);

    (padded, mask)
}
179
/// Left-pad `ids` with `pad_id` to exactly `target_length` tokens.
///
/// Returns `(padded_ids, attention_mask)`, where the mask is `0` for the
/// leading padding and `1` for real tokens. Input longer than
/// `target_length` is cut down to its last `target_length` tokens, all
/// marked real.
fn pad_left(ids: &[u32], target_length: usize, pad_id: u32) -> (Vec<u32>, Vec<u32>) {
    let kept = &ids[ids.len().saturating_sub(target_length)..];
    let pad_count = target_length - kept.len();

    let mut padded = Vec::with_capacity(target_length);
    padded.resize(pad_count, pad_id);
    padded.extend_from_slice(kept);

    let mut mask = vec![0u32; pad_count];
    mask.resize(target_length, 1);

    (padded, mask)
}
196
197/// Batch-encode multiple texts using a tokenizer.
198///
199/// Each text is independently encoded, then optionally truncated and padded
200/// according to the provided [`BatchConfig`].
201///
202/// Padding is added on the **right** side (standard for most models).
203/// For left-padding, use [`batch_encode_ext`].
204///
205/// # Arguments
206/// - `texts`: slice of input strings
207/// - `tokenizer`: any tokenizer implementing [`TransformerTokenizer`]
208/// - `config`: batch configuration (max_length, padding, truncation, pad token)
209pub fn batch_encode<T: TransformerTokenizer>(
210    texts: &[&str],
211    tokenizer: &T,
212    config: &BatchConfig,
213) -> BatchEncoding {
214    if texts.is_empty() {
215        return BatchEncoding {
216            input_ids: Vec::new(),
217            attention_mask: Vec::new(),
218            lengths: Vec::new(),
219        };
220    }
221
222    // Step 1: Encode all texts
223    let mut encoded: Vec<Vec<u32>> = texts.iter().map(|t| tokenizer.encode(t)).collect();
224    let original_lengths: Vec<usize> = encoded.iter().map(|v| v.len()).collect();
225
226    // Step 2: Truncation
227    if let Some(max_len) = config.max_length {
228        if config.truncation != TruncationStrategy::NoTruncation {
229            for seq in &mut encoded {
230                *seq = truncate(seq, config.truncation, max_len);
231            }
232        }
233    }
234
235    // Step 3: Determine target length for padding
236    let target_length = match config.padding {
237        PaddingStrategy::NoPadding => {
238            // No padding: return as-is with per-sequence masks
239            let attention_mask: Vec<Vec<u32>> =
240                encoded.iter().map(|seq| vec![1u32; seq.len()]).collect();
241            return BatchEncoding {
242                input_ids: encoded,
243                attention_mask,
244                lengths: original_lengths,
245            };
246        }
247        PaddingStrategy::MaxLength => config
248            .max_length
249            .unwrap_or_else(|| encoded.iter().map(|s| s.len()).max().unwrap_or(0)),
250        PaddingStrategy::LongestInBatch => {
251            let longest = encoded.iter().map(|s| s.len()).max().unwrap_or(0);
252            // If max_length is set, cap at max_length
253            match config.max_length {
254                Some(ml) => longest.min(ml),
255                None => longest,
256            }
257        }
258    };
259
260    // Step 4: Pad sequences
261    let mut input_ids = Vec::with_capacity(encoded.len());
262    let mut attention_mask = Vec::with_capacity(encoded.len());
263
264    for seq in &encoded {
265        let (padded, mask) = pad_right(seq, target_length, config.pad_token_id);
266        input_ids.push(padded);
267        attention_mask.push(mask);
268    }
269
270    BatchEncoding {
271        input_ids,
272        attention_mask,
273        lengths: original_lengths,
274    }
275}
276
277/// Batch-encode with extended configuration (including padding side).
278///
279/// Same as [`batch_encode`] but allows choosing left or right padding.
280pub fn batch_encode_ext<T: TransformerTokenizer>(
281    texts: &[&str],
282    tokenizer: &T,
283    config: &BatchConfigExt,
284) -> BatchEncoding {
285    if texts.is_empty() {
286        return BatchEncoding {
287            input_ids: Vec::new(),
288            attention_mask: Vec::new(),
289            lengths: Vec::new(),
290        };
291    }
292
293    // Step 1: Encode all texts
294    let mut encoded: Vec<Vec<u32>> = texts.iter().map(|t| tokenizer.encode(t)).collect();
295    let original_lengths: Vec<usize> = encoded.iter().map(|v| v.len()).collect();
296
297    // Step 2: Truncation
298    if let Some(max_len) = config.base.max_length {
299        if config.base.truncation != TruncationStrategy::NoTruncation {
300            for seq in &mut encoded {
301                *seq = truncate(seq, config.base.truncation, max_len);
302            }
303        }
304    }
305
306    // Step 3: Determine target length
307    let target_length = match config.base.padding {
308        PaddingStrategy::NoPadding => {
309            let attention_mask: Vec<Vec<u32>> =
310                encoded.iter().map(|seq| vec![1u32; seq.len()]).collect();
311            return BatchEncoding {
312                input_ids: encoded,
313                attention_mask,
314                lengths: original_lengths,
315            };
316        }
317        PaddingStrategy::MaxLength => config
318            .base
319            .max_length
320            .unwrap_or_else(|| encoded.iter().map(|s| s.len()).max().unwrap_or(0)),
321        PaddingStrategy::LongestInBatch => {
322            let longest = encoded.iter().map(|s| s.len()).max().unwrap_or(0);
323            match config.base.max_length {
324                Some(ml) => longest.min(ml),
325                None => longest,
326            }
327        }
328    };
329
330    // Step 4: Pad sequences
331    let pad_fn = match config.padding_side {
332        PaddingSide::Right => pad_right,
333        PaddingSide::Left => pad_left,
334    };
335
336    let mut input_ids = Vec::with_capacity(encoded.len());
337    let mut attention_mask = Vec::with_capacity(encoded.len());
338
339    for seq in &encoded {
340        let (padded, mask) = pad_fn(seq, target_length, config.base.pad_token_id);
341        input_ids.push(padded);
342        attention_mask.push(mask);
343    }
344
345    BatchEncoding {
346        input_ids,
347        attention_mask,
348        lengths: original_lengths,
349    }
350}
351
352// ---------------------------------------------------------------------------
353// Tests
354// ---------------------------------------------------------------------------
355
#[cfg(test)]
mod tests {
    use super::*;
    // Imported explicitly so the trait method can be called with
    // fully-qualified syntax (avoids clashing with any inherent `encode`).
    use crate::tokenizer::{BPETokenizer, TransformerTokenizer};

    /// Train a small BPE tokenizer on a fixed toy corpus shared by the tests.
    fn train_tokenizer() -> BPETokenizer {
        let corpus = &[
            "the cat sat on the mat",
            "the dog sat on the log",
            "cats and dogs",
            "the quick brown fox",
        ];
        BPETokenizer::train(corpus, 100).expect("training should succeed")
    }

    // -- Pure helper tests: exact expectations, no tokenizer involved --------

    #[test]
    fn test_truncate_helper() {
        let ids = [1u32, 2, 3, 4, 5];
        assert_eq!(truncate(&ids, TruncationStrategy::Right, 3), vec![1, 2, 3]);
        assert_eq!(truncate(&ids, TruncationStrategy::Left, 3), vec![3, 4, 5]);
        assert_eq!(
            truncate(&ids, TruncationStrategy::NoTruncation, 3),
            vec![1, 2, 3, 4, 5]
        );
        assert_eq!(truncate(&ids, TruncationStrategy::Right, 10), vec![1, 2, 3, 4, 5]);
        assert_eq!(truncate(&ids, TruncationStrategy::Left, 0), Vec::<u32>::new());
    }

    #[test]
    fn test_pad_right_helper() {
        assert_eq!(pad_right(&[7, 8], 4, 0), (vec![7, 8, 0, 0], vec![1, 1, 0, 0]));
        assert_eq!(pad_right(&[7, 8, 9], 2, 0), (vec![7, 8], vec![1, 1]));
        assert_eq!(pad_right(&[], 2, 5), (vec![5, 5], vec![0, 0]));
    }

    #[test]
    fn test_pad_left_helper() {
        assert_eq!(pad_left(&[7, 8], 4, 0), (vec![0, 0, 7, 8], vec![0, 0, 1, 1]));
        assert_eq!(pad_left(&[7, 8, 9], 2, 0), (vec![8, 9], vec![1, 1]));
        assert_eq!(pad_left(&[], 2, 5), (vec![5, 5], vec![0, 0]));
    }

    // -- End-to-end tests with a trained tokenizer --------------------------

    #[test]
    fn test_batch_encode_basic() {
        let tok = train_tokenizer();
        let texts = &["the cat", "the dog sat"];
        let config = BatchConfig {
            padding: PaddingStrategy::LongestInBatch,
            ..Default::default()
        };
        let batch = batch_encode(texts, &tok, &config);
        assert_eq!(batch.batch_size(), 2);
        // Both sequences should be padded to the same length
        assert_eq!(batch.input_ids[0].len(), batch.input_ids[1].len());
        assert_eq!(batch.attention_mask[0].len(), batch.attention_mask[1].len());
    }

    #[test]
    fn test_padding_adds_correct_tokens() {
        let tok = train_tokenizer();
        let texts = &["the", "the cat sat on the mat"];
        let config = BatchConfig {
            padding: PaddingStrategy::LongestInBatch,
            pad_token_id: 0,
            ..Default::default()
        };
        let batch = batch_encode(texts, &tok, &config);

        // The shorter sequence should have padding
        let shorter_len = batch.lengths[0];
        let padded_len = batch.input_ids[0].len();
        if shorter_len < padded_len {
            // Padding tokens (0) should appear at the end
            for i in shorter_len..padded_len {
                assert_eq!(
                    batch.input_ids[0][i], 0,
                    "padding token should be 0 at position {i}"
                );
            }
        }
    }

    #[test]
    fn test_attention_mask_correct() {
        let tok = train_tokenizer();
        let texts = &["the", "the cat sat"];
        let config = BatchConfig {
            padding: PaddingStrategy::LongestInBatch,
            pad_token_id: 0,
            ..Default::default()
        };
        let batch = batch_encode(texts, &tok, &config);

        // For the shorter sequence, mask should be 1 for real tokens, 0 for padding
        let shorter_len = batch.lengths[0];
        for i in 0..shorter_len.min(batch.attention_mask[0].len()) {
            assert_eq!(
                batch.attention_mask[0][i], 1,
                "real token at {i} should have mask 1"
            );
        }
        for i in shorter_len..batch.attention_mask[0].len() {
            assert_eq!(
                batch.attention_mask[0][i], 0,
                "padding at {i} should have mask 0"
            );
        }
    }

    #[test]
    fn test_truncation_right() {
        let tok = train_tokenizer();
        let texts = &["the cat sat on the mat"];
        let config = BatchConfig {
            max_length: Some(3),
            padding: PaddingStrategy::NoPadding,
            truncation: TruncationStrategy::Right,
            pad_token_id: 0,
        };
        let batch = batch_encode(texts, &tok, &config);
        // Right truncation must keep the *beginning* of the full encoding.
        let full = TransformerTokenizer::encode(&tok, texts[0]);
        let keep = full.len().min(3);
        assert_eq!(batch.input_ids[0], full[..keep].to_vec());
        assert_eq!(
            batch.lengths[0],
            full.len(),
            "lengths must record the pre-truncation size"
        );
    }

    #[test]
    fn test_truncation_left() {
        let tok = train_tokenizer();
        let texts = &["the cat sat on the mat"];
        let config = BatchConfig {
            max_length: Some(3),
            padding: PaddingStrategy::NoPadding,
            truncation: TruncationStrategy::Left,
            pad_token_id: 0,
        };
        let batch = batch_encode(texts, &tok, &config);
        // Left truncation must keep the *end* of the full encoding.
        let full = TransformerTokenizer::encode(&tok, texts[0]);
        let start = full.len().saturating_sub(3);
        assert_eq!(batch.input_ids[0], full[start..].to_vec());
    }

    #[test]
    fn test_no_padding_varying_lengths() {
        let tok = train_tokenizer();
        let texts = &["the", "the cat sat"];
        let config = BatchConfig {
            padding: PaddingStrategy::NoPadding,
            truncation: TruncationStrategy::NoTruncation,
            ..Default::default()
        };
        let batch = batch_encode(texts, &tok, &config);
        // Unpadded rows keep exactly their encoded length.
        assert_eq!(batch.input_ids[0].len(), batch.lengths[0]);
        assert_eq!(batch.input_ids[1].len(), batch.lengths[1]);
    }

    #[test]
    fn test_max_length_padding() {
        let tok = train_tokenizer();
        let texts = &["the"];
        let config = BatchConfig {
            max_length: Some(10),
            padding: PaddingStrategy::MaxLength,
            truncation: TruncationStrategy::Right,
            pad_token_id: 0,
        };
        let batch = batch_encode(texts, &tok, &config);
        assert_eq!(
            batch.input_ids[0].len(),
            10,
            "should be padded to max_length"
        );
        assert_eq!(batch.total_real_tokens(), batch.lengths[0].min(10));
    }

    #[test]
    fn test_empty_input() {
        let tok = train_tokenizer();
        let texts: &[&str] = &[];
        let config = BatchConfig::default();
        let batch = batch_encode(texts, &tok, &config);
        assert_eq!(batch.batch_size(), 0);
    }

    #[test]
    fn test_empty_string_in_batch() {
        let tok = train_tokenizer();
        let texts = &["", "the cat"];
        let config = BatchConfig {
            padding: PaddingStrategy::LongestInBatch,
            pad_token_id: 0,
            ..Default::default()
        };
        let batch = batch_encode(texts, &tok, &config);
        assert_eq!(batch.batch_size(), 2);
        // Empty string should produce length 0
        assert_eq!(batch.lengths[0], 0);
    }

    #[test]
    fn test_left_padding() {
        let tok = train_tokenizer();
        let texts = &["the", "the cat sat"];
        let config = BatchConfigExt {
            base: BatchConfig {
                padding: PaddingStrategy::LongestInBatch,
                pad_token_id: 0,
                ..Default::default()
            },
            padding_side: PaddingSide::Left,
        };
        let batch = batch_encode_ext(texts, &tok, &config);

        // For the shorter sequence, padding should be at the beginning
        let shorter_len = batch.lengths[0];
        let total_len = batch.input_ids[0].len();
        if shorter_len < total_len {
            let pad_count = total_len - shorter_len;
            for i in 0..pad_count {
                assert_eq!(
                    batch.attention_mask[0][i], 0,
                    "left padding mask at {i} should be 0"
                );
                assert_eq!(
                    batch.input_ids[0][i], 0,
                    "left padding token at {i} should be pad_id"
                );
            }
            for i in pad_count..total_len {
                assert_eq!(
                    batch.attention_mask[0][i], 1,
                    "real token mask at {i} should be 1"
                );
            }
        }
    }

    #[test]
    fn test_ext_right_padding_matches_batch_encode() {
        let tok = train_tokenizer();
        let texts = &["the", "the cat sat"];
        let base = BatchConfig::default();
        let plain = batch_encode(texts, &tok, &base);
        let ext = batch_encode_ext(
            texts,
            &tok,
            &BatchConfigExt {
                base: base.clone(),
                padding_side: PaddingSide::Right,
            },
        );
        assert_eq!(plain.input_ids, ext.input_ids);
        assert_eq!(plain.attention_mask, ext.attention_mask);
        assert_eq!(plain.lengths, ext.lengths);
    }

    #[test]
    fn test_total_real_tokens() {
        let tok = train_tokenizer();
        let texts = &["the cat", "the"];
        let config = BatchConfig {
            padding: PaddingStrategy::LongestInBatch,
            pad_token_id: 0,
            ..Default::default()
        };
        let batch = batch_encode(texts, &tok, &config);
        let total = batch.total_real_tokens();
        let expected: usize = batch.lengths.iter().sum();
        assert_eq!(
            total, expected,
            "total real tokens should equal sum of lengths"
        );
    }

    #[test]
    fn test_truncation_with_padding() {
        let tok = train_tokenizer();
        let texts = &["the cat sat on the mat", "the"];
        let config = BatchConfig {
            max_length: Some(4),
            padding: PaddingStrategy::MaxLength,
            truncation: TruncationStrategy::Right,
            pad_token_id: 0,
        };
        let batch = batch_encode(texts, &tok, &config);
        // All sequences should be exactly max_length
        for seq in &batch.input_ids {
            assert_eq!(seq.len(), 4, "all sequences should be padded to max_length");
        }
    }
}