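//! Sequence packing utilities for `trustformers_tokenizers` (`sequence_packing.rs`).
//!
//! Combines several short tokenized sequences into fixed-length rows to reduce padding,
//! and splits packed rows back into their original sequences.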

1use std::collections::HashMap;
2use trustformers_core::errors::{Result, TrustformersError};
3use trustformers_core::traits::TokenizedInput;
4
5/// Configuration for sequence packing
6#[derive(Debug, Clone)]
7pub struct PackingConfig {
8    /// Maximum sequence length after packing
9    pub max_packed_length: usize,
10
11    /// Padding token ID
12    pub pad_token_id: u32,
13
14    /// Separator token ID (used between packed sequences)
15    pub sep_token_id: Option<u32>,
16
17    /// Whether to add separator tokens between sequences
18    pub add_separators: bool,
19
20    /// Minimum sequence length to consider for packing
21    pub min_sequence_length: usize,
22
23    /// Maximum number of sequences to pack together
24    pub max_sequences_per_pack: usize,
25
26    /// Packing strategy to use
27    pub strategy: PackingStrategy,
28
29    /// Whether to preserve sequence boundaries in attention masks
30    pub preserve_boundaries: bool,
31}
32
33impl Default for PackingConfig {
34    fn default() -> Self {
35        Self {
36            max_packed_length: 512,
37            pad_token_id: 0,
38            sep_token_id: None,
39            add_separators: false,
40            min_sequence_length: 10,
41            max_sequences_per_pack: 4,
42            strategy: PackingStrategy::FirstFit,
43            preserve_boundaries: true,
44        }
45    }
46}
47
/// How sequences are ordered before the greedy packing pass.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PackingStrategy {
    /// Keep the input order untouched.
    FirstFit,

    /// Longest sequences first, which tends to fill packs more tightly.
    BestFit,

    /// Shortest sequences first, so neighbours in a pack have similar lengths.
    SimilarLength,

    /// Shuffle the sequences before packing.
    Random,
}
63
/// Bookkeeping describing which sequences ended up in a pack and where they sit.
#[derive(Clone, Debug)]
pub struct PackingInfo {
    /// Positions of the member sequences in the caller's original input slice.
    pub original_indices: Vec<usize>,

    /// Half-open `(start, end)` token ranges of each member inside the packed row.
    pub sequence_boundaries: Vec<(usize, usize)>,

    /// How many sequences share this pack.
    pub num_sequences: usize,

    /// Tokens occupied by members and separators, i.e. everything except padding.
    pub packed_length: usize,

    /// Fill ratio: `packed_length` divided by the configured maximum length.
    pub efficiency: f32,
}
82
83/// A packed sequence with metadata
84#[derive(Debug, Clone)]
85pub struct PackedSequence {
86    /// The packed tokenized input
87    pub tokenized_input: TokenizedInput,
88
89    /// Information about how this was packed
90    pub packing_info: PackingInfo,
91
92    /// Token type IDs for each token (0 for first sequence, 1 for second, etc.)
93    pub sequence_ids: Vec<u32>,
94}
95
/// Aggregate figures describing one packing run.
#[derive(Clone, Debug)]
pub struct PackingStats {
    /// Number of sequences handed to the packer.
    pub total_sequences: usize,

    /// Number of packed rows produced.
    pub num_packed_sequences: usize,

    /// Mean number of member sequences per packed row.
    pub avg_sequences_per_pack: f32,

    /// Mean fill ratio of the packed rows.
    pub avg_efficiency: f32,

    /// Input sequences that did not end up in any pack.
    pub unpacked_sequences: usize,

    /// Tokens avoided compared with padding each sequence to full length on its own.
    pub tokens_saved: usize,

    /// Padded-per-sequence token budget divided by the tokens the packed rows occupy
    /// (1.0 when nothing was packed).
    pub compression_ratio: f32,
}
120
121/// Main sequence packing utility
122pub struct SequencePacker {
123    config: PackingConfig,
124}
125
126impl SequencePacker {
127    /// Create a new sequence packer with the given configuration
128    pub fn new(config: PackingConfig) -> Self {
129        Self { config }
130    }
131
132    /// Pack a batch of tokenized inputs
133    pub fn pack_sequences(
134        &self,
135        sequences: &[TokenizedInput],
136    ) -> Result<(Vec<PackedSequence>, PackingStats)> {
137        if sequences.is_empty() {
138            return Ok((vec![], PackingStats::default()));
139        }
140
141        // Prepare sequences for packing
142        let mut seq_items: Vec<SequenceItem> = sequences
143            .iter()
144            .enumerate()
145            .map(|(idx, seq)| SequenceItem {
146                index: idx,
147                length: seq.input_ids.len(),
148                tokenized_input: seq.clone(),
149            })
150            .collect();
151
152        // Filter out sequences that are too long or too short
153        seq_items.retain(|item| {
154            item.length >= self.config.min_sequence_length
155                && item.length <= self.config.max_packed_length
156        });
157
158        // Apply packing strategy
159        self.apply_packing_strategy(&mut seq_items);
160
161        // Pack sequences
162        let packed_sequences = self.pack_sequences_greedy(&seq_items)?;
163
164        // Calculate statistics
165        let stats = self.calculate_stats(sequences.len(), &packed_sequences);
166
167        Ok((packed_sequences, stats))
168    }
169
170    /// Unpack a packed sequence back to individual sequences
171    pub fn unpack_sequence(&self, packed: &PackedSequence) -> Result<Vec<TokenizedInput>> {
172        let mut sequences = Vec::new();
173
174        for (start, end) in &packed.packing_info.sequence_boundaries {
175            if *end > packed.tokenized_input.input_ids.len() {
176                return Err(TrustformersError::invalid_input(
177                    "Invalid sequence boundary in packed sequence".to_string(),
178                ));
179            }
180
181            let input_ids = packed.tokenized_input.input_ids[*start..*end].to_vec();
182            let attention_mask = packed.tokenized_input.attention_mask[*start..*end].to_vec();
183
184            let token_type_ids = packed
185                .tokenized_input
186                .token_type_ids
187                .as_ref()
188                .map(|ttids| ttids[*start..*end].to_vec());
189
190            sequences.push(TokenizedInput {
191                input_ids,
192                attention_mask,
193                token_type_ids,
194                special_tokens_mask: None,
195                offset_mapping: None,
196                overflowing_tokens: None,
197            });
198        }
199
200        Ok(sequences)
201    }
202
203    /// Apply the configured packing strategy to sort sequences
204    fn apply_packing_strategy(&self, seq_items: &mut [SequenceItem]) {
205        match self.config.strategy {
206            PackingStrategy::FirstFit => {
207                // No sorting needed, use original order
208            },
209            PackingStrategy::BestFit => {
210                // Sort by length descending for better bin packing
211                seq_items.sort_by_key(|item| std::cmp::Reverse(item.length));
212            },
213            PackingStrategy::SimilarLength => {
214                // Sort by length ascending to group similar lengths
215                seq_items.sort_by_key(|a| a.length);
216            },
217            PackingStrategy::Random => {
218                // Shuffle randomly
219                use scirs2_core::random::*;  // SciRS2 Integration Policy
220                use scirs2_core::random::SliceRandom;  // Explicit for trait methods
221                let mut rng = thread_rng();
222                seq_items.shuffle(rng.rng_mut());
223            },
224        }
225    }
226
227    /// Pack sequences using a greedy algorithm
228    fn pack_sequences_greedy(&self, seq_items: &[SequenceItem]) -> Result<Vec<PackedSequence>> {
229        let mut packed_sequences = Vec::new();
230        let mut used = vec![false; seq_items.len()];
231
232        for i in 0..seq_items.len() {
233            if used[i] {
234                continue;
235            }
236
237            let mut current_pack = vec![i];
238            let mut current_length = seq_items[i].length;
239            used[i] = true;
240
241            // Add separators if configured
242            if self.config.add_separators && self.config.sep_token_id.is_some() {
243                current_length += 1; // Space for separator
244            }
245
246            // Try to fit more sequences
247            for j in (i + 1)..seq_items.len() {
248                if used[j] || current_pack.len() >= self.config.max_sequences_per_pack {
249                    continue;
250                }
251
252                let additional_length = seq_items[j].length;
253                let separator_length =
254                    if self.config.add_separators && self.config.sep_token_id.is_some() {
255                        1
256                    } else {
257                        0
258                    };
259
260                if current_length + additional_length + separator_length
261                    <= self.config.max_packed_length
262                {
263                    current_pack.push(j);
264                    current_length += additional_length + separator_length;
265                    used[j] = true;
266                }
267            }
268
269            // Create packed sequence
270            let packed = self.create_packed_sequence(&current_pack, seq_items)?;
271            packed_sequences.push(packed);
272        }
273
274        Ok(packed_sequences)
275    }
276
277    /// Create a packed sequence from a group of sequence indices
278    fn create_packed_sequence(
279        &self,
280        indices: &[usize],
281        seq_items: &[SequenceItem],
282    ) -> Result<PackedSequence> {
283        let mut packed_input_ids = Vec::new();
284        let mut packed_attention_mask = Vec::new();
285        let mut packed_token_type_ids: Vec<u32> = Vec::new();
286        let mut sequence_ids = Vec::new();
287        let mut sequence_boundaries = Vec::new();
288
289        for (seq_idx, &item_idx) in indices.iter().enumerate() {
290            let item = &seq_items[item_idx];
291            let start_pos = packed_input_ids.len();
292
293            // Add the sequence
294            packed_input_ids.extend(&item.tokenized_input.input_ids);
295            packed_attention_mask.extend(&item.tokenized_input.attention_mask);
296
297            // Add token type IDs
298            if let Some(ref ttids) = item.tokenized_input.token_type_ids {
299                packed_token_type_ids.extend(ttids);
300            } else {
301                packed_token_type_ids.extend(vec![0u32; item.tokenized_input.input_ids.len()]);
302            }
303
304            // Add sequence IDs for tracking
305            sequence_ids.extend(vec![seq_idx as u32; item.tokenized_input.input_ids.len()]);
306
307            let end_pos = packed_input_ids.len();
308            sequence_boundaries.push((start_pos, end_pos));
309
310            // Add separator if not the last sequence and separators are enabled
311            if seq_idx < indices.len() - 1 && self.config.add_separators {
312                if let Some(sep_token_id) = self.config.sep_token_id {
313                    packed_input_ids.push(sep_token_id);
314                    packed_attention_mask.push(1);
315                    packed_token_type_ids.push(0u32);
316                    sequence_ids.push(seq_idx as u32);
317                }
318            }
319        }
320
321        // Pad to max length if needed
322        let current_length = packed_input_ids.len();
323        if current_length < self.config.max_packed_length {
324            let padding_length = self.config.max_packed_length - current_length;
325            packed_input_ids.extend(vec![self.config.pad_token_id; padding_length]);
326            packed_attention_mask.extend(vec![0u8; padding_length]);
327            packed_token_type_ids.extend(vec![0u32; padding_length]);
328            sequence_ids.extend(vec![u32::MAX; padding_length]); // Use MAX to indicate padding
329        }
330
331        let packing_info = PackingInfo {
332            original_indices: indices.iter().map(|&i| seq_items[i].index).collect(),
333            sequence_boundaries,
334            num_sequences: indices.len(),
335            packed_length: current_length,
336            efficiency: current_length as f32 / self.config.max_packed_length as f32,
337        };
338
339        let tokenized_input = TokenizedInput {
340            input_ids: packed_input_ids,
341            attention_mask: packed_attention_mask,
342            token_type_ids: Some(packed_token_type_ids),
343            special_tokens_mask: None,
344            offset_mapping: None,
345            overflowing_tokens: None,
346        };
347
348        Ok(PackedSequence {
349            tokenized_input,
350            packing_info,
351            sequence_ids,
352        })
353    }
354
355    /// Calculate packing statistics
356    fn calculate_stats(
357        &self,
358        original_count: usize,
359        packed_sequences: &[PackedSequence],
360    ) -> PackingStats {
361        let total_packed_sequences = packed_sequences.len();
362        let total_sequences_packed: usize =
363            packed_sequences.iter().map(|p| p.packing_info.num_sequences).sum();
364
365        let avg_sequences_per_pack = if total_packed_sequences > 0 {
366            total_sequences_packed as f32 / total_packed_sequences as f32
367        } else {
368            0.0
369        };
370
371        let avg_efficiency = if total_packed_sequences > 0 {
372            packed_sequences.iter().map(|p| p.packing_info.efficiency).sum::<f32>()
373                / total_packed_sequences as f32
374        } else {
375            0.0
376        };
377
378        let unpacked_sequences = original_count.saturating_sub(total_sequences_packed);
379
380        // Calculate token savings
381        let original_tokens_if_padded = original_count * self.config.max_packed_length;
382        let actual_tokens_used: usize =
383            packed_sequences.iter().map(|_p| self.config.max_packed_length).sum();
384        let tokens_saved = original_tokens_if_padded.saturating_sub(actual_tokens_used);
385
386        let compression_ratio = if actual_tokens_used > 0 {
387            original_tokens_if_padded as f32 / actual_tokens_used as f32
388        } else {
389            1.0
390        };
391
392        PackingStats {
393            total_sequences: original_count,
394            num_packed_sequences: total_packed_sequences,
395            avg_sequences_per_pack,
396            avg_efficiency,
397            unpacked_sequences,
398            tokens_saved,
399            compression_ratio,
400        }
401    }
402}
403
404impl Default for PackingStats {
405    fn default() -> Self {
406        Self {
407            total_sequences: 0,
408            num_packed_sequences: 0,
409            avg_sequences_per_pack: 0.0,
410            avg_efficiency: 0.0,
411            unpacked_sequences: 0,
412            tokens_saved: 0,
413            compression_ratio: 1.0,
414        }
415    }
416}
417
418/// Internal representation of a sequence for packing
419#[derive(Debug, Clone)]
420struct SequenceItem {
421    index: usize,
422    length: usize,
423    tokenized_input: TokenizedInput,
424}
425
426/// Advanced sequence packer with additional features
427pub struct AdvancedSequencePacker {
428    base_packer: SequencePacker,
429    length_histogram: HashMap<usize, usize>,
430    #[allow(dead_code)]
431    packing_cache: HashMap<Vec<usize>, PackedSequence>,
432}
433
434impl AdvancedSequencePacker {
435    /// Create a new advanced sequence packer
436    pub fn new(config: PackingConfig) -> Self {
437        Self {
438            base_packer: SequencePacker::new(config),
439            length_histogram: HashMap::new(),
440            packing_cache: HashMap::new(),
441        }
442    }
443
444    /// Pack sequences with length-aware optimization
445    pub fn pack_with_optimization(
446        &mut self,
447        sequences: &[TokenizedInput],
448    ) -> Result<(Vec<PackedSequence>, PackingStats)> {
449        // Update length histogram
450        self.update_length_histogram(sequences);
451
452        // Use the base packer but with optimized strategy
453        self.base_packer.pack_sequences(sequences)
454    }
455
456    /// Update the length histogram for optimization
457    fn update_length_histogram(&mut self, sequences: &[TokenizedInput]) {
458        for seq in sequences {
459            let length = seq.input_ids.len();
460            *self.length_histogram.entry(length).or_insert(0) += 1;
461        }
462    }
463
464    /// Get length distribution statistics
465    pub fn get_length_stats(&self) -> Vec<(usize, usize)> {
466        let mut stats: Vec<_> =
467            self.length_histogram.iter().map(|(&len, &count)| (len, count)).collect();
468        stats.sort_by_key(|&(len, _)| len);
469        stats
470    }
471
472    /// Suggest optimal packing configuration based on observed data
473    pub fn suggest_config(&self) -> PackingConfig {
474        let mut config = self.base_packer.config.clone();
475
476        if !self.length_histogram.is_empty() {
477            // Calculate percentiles
478            let total_sequences: usize = self.length_histogram.values().sum();
479            let mut cumulative = 0;
480            let mut percentile_95 = 0;
481
482            for (&length, &count) in &self.length_histogram {
483                cumulative += count;
484                if cumulative >= (total_sequences * 95) / 100 {
485                    percentile_95 = length;
486                    break;
487                }
488            }
489
490            // Suggest max length based on 95th percentile
491            if percentile_95 > 0 {
492                config.max_packed_length = (percentile_95 * 2).max(512);
493            }
494
495            // Suggest strategy based on length distribution
496            let length_variance = self.calculate_length_variance();
497            if length_variance < 100.0 {
498                config.strategy = PackingStrategy::SimilarLength;
499            } else {
500                config.strategy = PackingStrategy::BestFit;
501            }
502        }
503
504        config
505    }
506
507    /// Calculate variance in sequence lengths
508    fn calculate_length_variance(&self) -> f64 {
509        if self.length_histogram.is_empty() {
510            return 0.0;
511        }
512
513        let total_sequences: usize = self.length_histogram.values().sum();
514        let mean: f64 = self
515            .length_histogram
516            .iter()
517            .map(|(&len, &count)| len as f64 * count as f64)
518            .sum::<f64>()
519            / total_sequences as f64;
520
521        let variance: f64 = self
522            .length_histogram
523            .iter()
524            .map(|(&len, &count)| {
525                let diff = len as f64 - mean;
526                diff * diff * count as f64
527            })
528            .sum::<f64>()
529            / total_sequences as f64;
530
531        variance
532    }
533}
534
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a sequence with token IDs `0..length`, fully attended, all in segment 0.
    fn create_test_sequence(length: usize) -> TokenizedInput {
        TokenizedInput {
            input_ids: (0..length).map(|i| i as u32).collect(),
            attention_mask: vec![1u8; length],
            token_type_ids: Some(vec![0u32; length]),
            special_tokens_mask: None,
            offset_mapping: None,
            overflowing_tokens: None,
        }
    }

    #[test]
    fn test_basic_packing() {
        let config = PackingConfig {
            max_packed_length: 100,
            pad_token_id: 0,
            ..Default::default()
        };
        let packer = SequencePacker::new(config);

        let sequences = vec![
            create_test_sequence(30),
            create_test_sequence(25),
            create_test_sequence(40),
        ];

        let (packed, stats) = packer.pack_sequences(&sequences).expect("Operation failed in test");

        // 30 + 25 + 40 = 95 tokens fit in a single 100-token row.
        assert_eq!(packed.len(), 1);
        assert_eq!(packed[0].packing_info.num_sequences, 3);
        assert_eq!(packed[0].packing_info.packed_length, 95);
        assert_eq!(packed[0].tokenized_input.input_ids.len(), 100);
        assert_eq!(stats.total_sequences, 3);
        assert_eq!(stats.unpacked_sequences, 0);
    }

    #[test]
    fn test_packing_with_separators() {
        let config = PackingConfig {
            max_packed_length: 100,
            pad_token_id: 0,
            sep_token_id: Some(999),
            add_separators: true,
            ..Default::default()
        };
        let packer = SequencePacker::new(config);

        let sequences = vec![create_test_sequence(20), create_test_sequence(20)];

        let (packed, _) = packer.pack_sequences(&sequences).expect("Operation failed in test");

        assert_eq!(packed.len(), 1);
        let ids = &packed[0].tokenized_input.input_ids;
        // Exactly one separator, placed between the two 20-token sequences.
        assert_eq!(ids[20], 999);
        assert_eq!(ids.iter().filter(|&&t| t == 999).count(), 1);
        assert_eq!(packed[0].packing_info.sequence_boundaries, vec![(0, 20), (21, 41)]);
    }

    #[test]
    fn test_unpacking() {
        let config = PackingConfig {
            max_packed_length: 100,
            pad_token_id: 0,
            ..Default::default()
        };
        let packer = SequencePacker::new(config);

        let original_sequences = vec![create_test_sequence(30), create_test_sequence(25)];

        let (packed, _) =
            packer.pack_sequences(&original_sequences).expect("Operation failed in test");
        let unpacked = packer.unpack_sequence(&packed[0]).expect("Operation failed in test");

        assert_eq!(unpacked.len(), packed[0].packing_info.num_sequences);
        assert_eq!(unpacked.len(), original_sequences.len());
        // Round trip must reproduce every member exactly (padding stripped).
        for (restored, original) in unpacked.iter().zip(&original_sequences) {
            assert_eq!(restored.input_ids, original.input_ids);
            assert_eq!(restored.attention_mask, original.attention_mask);
            assert_eq!(restored.token_type_ids, original.token_type_ids);
        }
    }

    #[test]
    fn test_unpacking_rejects_out_of_range_boundary() {
        let packer = SequencePacker::new(PackingConfig {
            max_packed_length: 100,
            ..Default::default()
        });

        let (mut packed, _) = packer
            .pack_sequences(&[create_test_sequence(30)])
            .expect("Operation failed in test");
        packed[0].packing_info.sequence_boundaries = vec![(0, 101)];

        assert!(packer.unpack_sequence(&packed[0]).is_err());
    }

    #[test]
    fn test_packing_strategies() {
        let config = PackingConfig {
            max_packed_length: 100,
            strategy: PackingStrategy::BestFit,
            ..Default::default()
        };
        let packer = SequencePacker::new(config);

        let sequences = vec![
            create_test_sequence(80),
            create_test_sequence(10),
            create_test_sequence(15),
            create_test_sequence(20),
        ];

        let (packed, stats) = packer.pack_sequences(&sequences).expect("Operation failed in test");

        // Longest-first: [80, 20] fills a row exactly, then [15, 10].
        assert_eq!(packed.len(), 2);
        assert_eq!(packed[0].packing_info.original_indices, vec![0, 3]);
        assert_eq!(packed[0].packing_info.efficiency, 1.0);
        assert!(stats.avg_efficiency > 0.0);
    }

    #[test]
    fn test_advanced_packer() {
        let config = PackingConfig::default();
        let mut advanced_packer = AdvancedSequencePacker::new(config);

        let sequences = vec![
            create_test_sequence(50),
            create_test_sequence(55),
            create_test_sequence(48),
            create_test_sequence(52),
        ];

        let (packed, stats) = advanced_packer
            .pack_with_optimization(&sequences)
            .expect("Operation failed in test");

        assert_eq!(packed.len(), 1);
        assert_eq!(stats.total_sequences, 4);

        let length_stats = advanced_packer.get_length_stats();
        assert_eq!(length_stats, vec![(48, 1), (50, 1), (52, 1), (55, 1)]);

        let suggested_config = advanced_packer.suggest_config();
        // Twice any observed length is well below the 512 floor.
        assert_eq!(suggested_config.max_packed_length, 512);
        // Lengths are tightly clustered (variance < 100).
        assert_eq!(suggested_config.strategy, PackingStrategy::SimilarLength);
    }

    #[test]
    fn test_efficiency_calculation() {
        let config = PackingConfig {
            max_packed_length: 100,
            ..Default::default()
        };
        let packer = SequencePacker::new(config);

        // Perfect packing scenario
        let sequences = vec![create_test_sequence(50), create_test_sequence(50)];

        let (packed, stats) = packer.pack_sequences(&sequences).expect("Operation failed in test");

        assert_eq!(packed.len(), 1);
        assert_eq!(packed[0].packing_info.efficiency, 1.0); // Perfect efficiency
        assert!(stats.avg_efficiency > 0.9);
    }

    #[test]
    fn test_max_sequences_per_pack() {
        let config = PackingConfig {
            max_packed_length: 1000,
            max_sequences_per_pack: 2,
            ..Default::default()
        };
        let packer = SequencePacker::new(config);

        let sequences = vec![
            create_test_sequence(10),
            create_test_sequence(10),
            create_test_sequence(10),
            create_test_sequence(10),
        ];

        let (packed, _) = packer.pack_sequences(&sequences).expect("Operation failed in test");

        // Should create 2 packs with exactly 2 sequences each
        assert_eq!(packed.len(), 2);
        for pack in packed {
            assert_eq!(pack.packing_info.num_sequences, 2);
        }
    }
}