1use std::collections::HashMap;
2use trustformers_core::errors::{Result, TrustformersError};
3use trustformers_core::traits::TokenizedInput;
4
5#[derive(Debug, Clone)]
7pub struct PackingConfig {
8 pub max_packed_length: usize,
10
11 pub pad_token_id: u32,
13
14 pub sep_token_id: Option<u32>,
16
17 pub add_separators: bool,
19
20 pub min_sequence_length: usize,
22
23 pub max_sequences_per_pack: usize,
25
26 pub strategy: PackingStrategy,
28
29 pub preserve_boundaries: bool,
31}
32
33impl Default for PackingConfig {
34 fn default() -> Self {
35 Self {
36 max_packed_length: 512,
37 pad_token_id: 0,
38 sep_token_id: None,
39 add_separators: false,
40 min_sequence_length: 10,
41 max_sequences_per_pack: 4,
42 strategy: PackingStrategy::FirstFit,
43 preserve_boundaries: true,
44 }
45 }
46}
47
/// Ordering applied to the candidate sequences before greedy packing.
///
/// The default is [`PackingStrategy::FirstFit`], matching the strategy used by
/// the default packing configuration.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum PackingStrategy {
    /// Keep the sequences in their input order.
    #[default]
    FirstFit,

    /// Sort the sequences from longest to shortest before packing.
    BestFit,

    /// Sort the sequences from shortest to longest before packing.
    SimilarLength,

    /// Shuffle the sequences before packing.
    Random,
}
63
/// Bookkeeping that describes how one pack was assembled.
#[derive(Debug, Clone, PartialEq)]
pub struct PackingInfo {
    /// For each sequence in the pack, its index in the slice that was handed
    /// to the packer.
    pub original_indices: Vec<usize>,

    /// Half-open `(start, end)` token ranges of each sequence inside the pack.
    /// Separator and padding tokens lie outside these ranges.
    pub sequence_boundaries: Vec<(usize, usize)>,

    /// Number of sequences stored in the pack.
    pub num_sequences: usize,

    /// Number of tokens in the pack before padding (separators included).
    pub packed_length: usize,

    /// `packed_length` divided by the configured maximum pack length.
    pub efficiency: f32,
}
82
83#[derive(Debug, Clone)]
85pub struct PackedSequence {
86 pub tokenized_input: TokenizedInput,
88
89 pub packing_info: PackingInfo,
91
92 pub sequence_ids: Vec<u32>,
94}
95
/// Summary figures for one packing run.
#[derive(Debug, Clone, PartialEq)]
pub struct PackingStats {
    /// Number of input sequences handed to the packer, including any that
    /// were skipped because of their length.
    pub total_sequences: usize,

    /// Number of packs that were produced.
    pub num_packed_sequences: usize,

    /// Mean number of sequences per pack (0.0 when no pack was produced).
    pub avg_sequences_per_pack: f32,

    /// Mean pack efficiency (0.0 when no pack was produced).
    pub avg_efficiency: f32,

    /// Number of input sequences that did not end up in any pack.
    pub unpacked_sequences: usize,

    /// Difference between padding every input sequence on its own to the
    /// maximum pack length and the padded length of all produced packs.
    pub tokens_saved: usize,

    /// Ratio of those two token counts (1.0 when no pack was produced).
    pub compression_ratio: f32,
}
120
121pub struct SequencePacker {
123 config: PackingConfig,
124}
125
126impl SequencePacker {
127 pub fn new(config: PackingConfig) -> Self {
129 Self { config }
130 }
131
132 pub fn pack_sequences(
134 &self,
135 sequences: &[TokenizedInput],
136 ) -> Result<(Vec<PackedSequence>, PackingStats)> {
137 if sequences.is_empty() {
138 return Ok((vec![], PackingStats::default()));
139 }
140
141 let mut seq_items: Vec<SequenceItem> = sequences
143 .iter()
144 .enumerate()
145 .map(|(idx, seq)| SequenceItem {
146 index: idx,
147 length: seq.input_ids.len(),
148 tokenized_input: seq.clone(),
149 })
150 .collect();
151
152 seq_items.retain(|item| {
154 item.length >= self.config.min_sequence_length
155 && item.length <= self.config.max_packed_length
156 });
157
158 self.apply_packing_strategy(&mut seq_items);
160
161 let packed_sequences = self.pack_sequences_greedy(&seq_items)?;
163
164 let stats = self.calculate_stats(sequences.len(), &packed_sequences);
166
167 Ok((packed_sequences, stats))
168 }
169
170 pub fn unpack_sequence(&self, packed: &PackedSequence) -> Result<Vec<TokenizedInput>> {
172 let mut sequences = Vec::new();
173
174 for (start, end) in &packed.packing_info.sequence_boundaries {
175 if *end > packed.tokenized_input.input_ids.len() {
176 return Err(TrustformersError::invalid_input(
177 "Invalid sequence boundary in packed sequence".to_string(),
178 ));
179 }
180
181 let input_ids = packed.tokenized_input.input_ids[*start..*end].to_vec();
182 let attention_mask = packed.tokenized_input.attention_mask[*start..*end].to_vec();
183
184 let token_type_ids = packed
185 .tokenized_input
186 .token_type_ids
187 .as_ref()
188 .map(|ttids| ttids[*start..*end].to_vec());
189
190 sequences.push(TokenizedInput {
191 input_ids,
192 attention_mask,
193 token_type_ids,
194 special_tokens_mask: None,
195 offset_mapping: None,
196 overflowing_tokens: None,
197 });
198 }
199
200 Ok(sequences)
201 }
202
203 fn apply_packing_strategy(&self, seq_items: &mut [SequenceItem]) {
205 match self.config.strategy {
206 PackingStrategy::FirstFit => {
207 },
209 PackingStrategy::BestFit => {
210 seq_items.sort_by_key(|item| std::cmp::Reverse(item.length));
212 },
213 PackingStrategy::SimilarLength => {
214 seq_items.sort_by_key(|a| a.length);
216 },
217 PackingStrategy::Random => {
218 use scirs2_core::random::*; use scirs2_core::random::SliceRandom; let mut rng = thread_rng();
222 seq_items.shuffle(rng.rng_mut());
223 },
224 }
225 }
226
227 fn pack_sequences_greedy(&self, seq_items: &[SequenceItem]) -> Result<Vec<PackedSequence>> {
229 let mut packed_sequences = Vec::new();
230 let mut used = vec![false; seq_items.len()];
231
232 for i in 0..seq_items.len() {
233 if used[i] {
234 continue;
235 }
236
237 let mut current_pack = vec![i];
238 let mut current_length = seq_items[i].length;
239 used[i] = true;
240
241 if self.config.add_separators && self.config.sep_token_id.is_some() {
243 current_length += 1; }
245
246 for j in (i + 1)..seq_items.len() {
248 if used[j] || current_pack.len() >= self.config.max_sequences_per_pack {
249 continue;
250 }
251
252 let additional_length = seq_items[j].length;
253 let separator_length =
254 if self.config.add_separators && self.config.sep_token_id.is_some() {
255 1
256 } else {
257 0
258 };
259
260 if current_length + additional_length + separator_length
261 <= self.config.max_packed_length
262 {
263 current_pack.push(j);
264 current_length += additional_length + separator_length;
265 used[j] = true;
266 }
267 }
268
269 let packed = self.create_packed_sequence(¤t_pack, seq_items)?;
271 packed_sequences.push(packed);
272 }
273
274 Ok(packed_sequences)
275 }
276
277 fn create_packed_sequence(
279 &self,
280 indices: &[usize],
281 seq_items: &[SequenceItem],
282 ) -> Result<PackedSequence> {
283 let mut packed_input_ids = Vec::new();
284 let mut packed_attention_mask = Vec::new();
285 let mut packed_token_type_ids: Vec<u32> = Vec::new();
286 let mut sequence_ids = Vec::new();
287 let mut sequence_boundaries = Vec::new();
288
289 for (seq_idx, &item_idx) in indices.iter().enumerate() {
290 let item = &seq_items[item_idx];
291 let start_pos = packed_input_ids.len();
292
293 packed_input_ids.extend(&item.tokenized_input.input_ids);
295 packed_attention_mask.extend(&item.tokenized_input.attention_mask);
296
297 if let Some(ref ttids) = item.tokenized_input.token_type_ids {
299 packed_token_type_ids.extend(ttids);
300 } else {
301 packed_token_type_ids.extend(vec![0u32; item.tokenized_input.input_ids.len()]);
302 }
303
304 sequence_ids.extend(vec![seq_idx as u32; item.tokenized_input.input_ids.len()]);
306
307 let end_pos = packed_input_ids.len();
308 sequence_boundaries.push((start_pos, end_pos));
309
310 if seq_idx < indices.len() - 1 && self.config.add_separators {
312 if let Some(sep_token_id) = self.config.sep_token_id {
313 packed_input_ids.push(sep_token_id);
314 packed_attention_mask.push(1);
315 packed_token_type_ids.push(0u32);
316 sequence_ids.push(seq_idx as u32);
317 }
318 }
319 }
320
321 let current_length = packed_input_ids.len();
323 if current_length < self.config.max_packed_length {
324 let padding_length = self.config.max_packed_length - current_length;
325 packed_input_ids.extend(vec![self.config.pad_token_id; padding_length]);
326 packed_attention_mask.extend(vec![0u8; padding_length]);
327 packed_token_type_ids.extend(vec![0u32; padding_length]);
328 sequence_ids.extend(vec![u32::MAX; padding_length]); }
330
331 let packing_info = PackingInfo {
332 original_indices: indices.iter().map(|&i| seq_items[i].index).collect(),
333 sequence_boundaries,
334 num_sequences: indices.len(),
335 packed_length: current_length,
336 efficiency: current_length as f32 / self.config.max_packed_length as f32,
337 };
338
339 let tokenized_input = TokenizedInput {
340 input_ids: packed_input_ids,
341 attention_mask: packed_attention_mask,
342 token_type_ids: Some(packed_token_type_ids),
343 special_tokens_mask: None,
344 offset_mapping: None,
345 overflowing_tokens: None,
346 };
347
348 Ok(PackedSequence {
349 tokenized_input,
350 packing_info,
351 sequence_ids,
352 })
353 }
354
355 fn calculate_stats(
357 &self,
358 original_count: usize,
359 packed_sequences: &[PackedSequence],
360 ) -> PackingStats {
361 let total_packed_sequences = packed_sequences.len();
362 let total_sequences_packed: usize =
363 packed_sequences.iter().map(|p| p.packing_info.num_sequences).sum();
364
365 let avg_sequences_per_pack = if total_packed_sequences > 0 {
366 total_sequences_packed as f32 / total_packed_sequences as f32
367 } else {
368 0.0
369 };
370
371 let avg_efficiency = if total_packed_sequences > 0 {
372 packed_sequences.iter().map(|p| p.packing_info.efficiency).sum::<f32>()
373 / total_packed_sequences as f32
374 } else {
375 0.0
376 };
377
378 let unpacked_sequences = original_count.saturating_sub(total_sequences_packed);
379
380 let original_tokens_if_padded = original_count * self.config.max_packed_length;
382 let actual_tokens_used: usize =
383 packed_sequences.iter().map(|_p| self.config.max_packed_length).sum();
384 let tokens_saved = original_tokens_if_padded.saturating_sub(actual_tokens_used);
385
386 let compression_ratio = if actual_tokens_used > 0 {
387 original_tokens_if_padded as f32 / actual_tokens_used as f32
388 } else {
389 1.0
390 };
391
392 PackingStats {
393 total_sequences: original_count,
394 num_packed_sequences: total_packed_sequences,
395 avg_sequences_per_pack,
396 avg_efficiency,
397 unpacked_sequences,
398 tokens_saved,
399 compression_ratio,
400 }
401 }
402}
403
404impl Default for PackingStats {
405 fn default() -> Self {
406 Self {
407 total_sequences: 0,
408 num_packed_sequences: 0,
409 avg_sequences_per_pack: 0.0,
410 avg_efficiency: 0.0,
411 unpacked_sequences: 0,
412 tokens_saved: 0,
413 compression_ratio: 1.0,
414 }
415 }
416}
417
418#[derive(Debug, Clone)]
420struct SequenceItem {
421 index: usize,
422 length: usize,
423 tokenized_input: TokenizedInput,
424}
425
426pub struct AdvancedSequencePacker {
428 base_packer: SequencePacker,
429 length_histogram: HashMap<usize, usize>,
430 #[allow(dead_code)]
431 packing_cache: HashMap<Vec<usize>, PackedSequence>,
432}
433
434impl AdvancedSequencePacker {
435 pub fn new(config: PackingConfig) -> Self {
437 Self {
438 base_packer: SequencePacker::new(config),
439 length_histogram: HashMap::new(),
440 packing_cache: HashMap::new(),
441 }
442 }
443
444 pub fn pack_with_optimization(
446 &mut self,
447 sequences: &[TokenizedInput],
448 ) -> Result<(Vec<PackedSequence>, PackingStats)> {
449 self.update_length_histogram(sequences);
451
452 self.base_packer.pack_sequences(sequences)
454 }
455
456 fn update_length_histogram(&mut self, sequences: &[TokenizedInput]) {
458 for seq in sequences {
459 let length = seq.input_ids.len();
460 *self.length_histogram.entry(length).or_insert(0) += 1;
461 }
462 }
463
464 pub fn get_length_stats(&self) -> Vec<(usize, usize)> {
466 let mut stats: Vec<_> =
467 self.length_histogram.iter().map(|(&len, &count)| (len, count)).collect();
468 stats.sort_by_key(|&(len, _)| len);
469 stats
470 }
471
472 pub fn suggest_config(&self) -> PackingConfig {
474 let mut config = self.base_packer.config.clone();
475
476 if !self.length_histogram.is_empty() {
477 let total_sequences: usize = self.length_histogram.values().sum();
479 let mut cumulative = 0;
480 let mut percentile_95 = 0;
481
482 for (&length, &count) in &self.length_histogram {
483 cumulative += count;
484 if cumulative >= (total_sequences * 95) / 100 {
485 percentile_95 = length;
486 break;
487 }
488 }
489
490 if percentile_95 > 0 {
492 config.max_packed_length = (percentile_95 * 2).max(512);
493 }
494
495 let length_variance = self.calculate_length_variance();
497 if length_variance < 100.0 {
498 config.strategy = PackingStrategy::SimilarLength;
499 } else {
500 config.strategy = PackingStrategy::BestFit;
501 }
502 }
503
504 config
505 }
506
507 fn calculate_length_variance(&self) -> f64 {
509 if self.length_histogram.is_empty() {
510 return 0.0;
511 }
512
513 let total_sequences: usize = self.length_histogram.values().sum();
514 let mean: f64 = self
515 .length_histogram
516 .iter()
517 .map(|(&len, &count)| len as f64 * count as f64)
518 .sum::<f64>()
519 / total_sequences as f64;
520
521 let variance: f64 = self
522 .length_histogram
523 .iter()
524 .map(|(&len, &count)| {
525 let diff = len as f64 - mean;
526 diff * diff * count as f64
527 })
528 .sum::<f64>()
529 / total_sequences as f64;
530
531 variance
532 }
533}
534
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a sequence of `length` tokens `0, 1, 2, ...` with a full
    /// attention mask and all-zero token type ids.
    fn create_test_sequence(length: usize) -> TokenizedInput {
        let input_ids: Vec<u32> = (0..length).map(|i| i as u32).collect();
        TokenizedInput {
            input_ids,
            attention_mask: vec![1u8; length],
            token_type_ids: Some(vec![0u32; length]),
            special_tokens_mask: None,
            offset_mapping: None,
            overflowing_tokens: None,
        }
    }

    /// Builds one test sequence per requested length.
    fn sequences_of(lengths: &[usize]) -> Vec<TokenizedInput> {
        lengths.iter().map(|&length| create_test_sequence(length)).collect()
    }

    /// Packing a few short sequences yields at least one pack and counts all
    /// inputs.
    #[test]
    fn test_basic_packing() {
        let packer = SequencePacker::new(PackingConfig {
            max_packed_length: 100,
            pad_token_id: 0,
            ..Default::default()
        });
        let inputs = sequences_of(&[30, 25, 40]);

        let (packs, stats) = packer.pack_sequences(&inputs).expect("Operation failed in test");

        assert!(!packs.is_empty());
        assert_eq!(stats.total_sequences, 3);
    }

    /// With separators enabled, the separator token appears in the first pack.
    #[test]
    fn test_packing_with_separators() {
        let packer = SequencePacker::new(PackingConfig {
            max_packed_length: 100,
            pad_token_id: 0,
            sep_token_id: Some(999),
            add_separators: true,
            ..Default::default()
        });
        let inputs = sequences_of(&[20, 20]);

        let (packs, _) = packer.pack_sequences(&inputs).expect("Operation failed in test");

        assert!(!packs.is_empty());
        let first_pack_ids = &packs[0].tokenized_input.input_ids;
        assert!(first_pack_ids.contains(&999));
    }

    /// Unpacking returns as many sequences as the pack reports.
    #[test]
    fn test_unpacking() {
        let packer = SequencePacker::new(PackingConfig {
            max_packed_length: 100,
            pad_token_id: 0,
            ..Default::default()
        });
        let inputs = sequences_of(&[30, 25]);

        let (packs, _) = packer.pack_sequences(&inputs).expect("Operation failed in test");
        let first_pack = &packs[0];
        let restored = packer.unpack_sequence(first_pack).expect("Operation failed in test");

        assert_eq!(restored.len(), first_pack.packing_info.num_sequences);
    }

    /// The best-fit strategy produces packs with a positive average
    /// efficiency.
    #[test]
    fn test_packing_strategies() {
        let packer = SequencePacker::new(PackingConfig {
            max_packed_length: 100,
            strategy: PackingStrategy::BestFit,
            ..Default::default()
        });
        let inputs = sequences_of(&[80, 10, 15, 20]);

        let (packs, stats) = packer.pack_sequences(&inputs).expect("Operation failed in test");

        assert!(!packs.is_empty());
        assert!(stats.avg_efficiency > 0.0);
    }

    /// The advanced packer packs, records length statistics and suggests a
    /// usable configuration.
    #[test]
    fn test_advanced_packer() {
        let mut advanced_packer = AdvancedSequencePacker::new(PackingConfig::default());
        let inputs = sequences_of(&[50, 55, 48, 52]);

        let (packs, stats) = advanced_packer
            .pack_with_optimization(&inputs)
            .expect("Operation failed in test");

        assert!(!packs.is_empty());
        assert_eq!(stats.total_sequences, 4);

        assert!(!advanced_packer.get_length_stats().is_empty());

        let suggested = advanced_packer.suggest_config();
        assert!(suggested.max_packed_length > 0);
    }

    /// Two sequences that exactly fill one pack give an efficiency of 1.0.
    #[test]
    fn test_efficiency_calculation() {
        let packer = SequencePacker::new(PackingConfig {
            max_packed_length: 100,
            ..Default::default()
        });
        let inputs = sequences_of(&[50, 50]);

        let (packs, stats) = packer.pack_sequences(&inputs).expect("Operation failed in test");

        assert_eq!(packs.len(), 1);
        assert_eq!(packs[0].packing_info.efficiency, 1.0);
        assert!(stats.avg_efficiency > 0.9);
    }

    /// The per-pack sequence limit is respected even when more would fit.
    #[test]
    fn test_max_sequences_per_pack() {
        let packer = SequencePacker::new(PackingConfig {
            max_packed_length: 1000,
            max_sequences_per_pack: 2,
            ..Default::default()
        });
        let inputs = sequences_of(&[10, 10, 10, 10]);

        let (packs, _) = packer.pack_sequences(&inputs).expect("Operation failed in test");

        assert_eq!(packs.len(), 2);
        for pack in &packs {
            assert!(pack.packing_info.num_sequences <= 2);
        }
    }
}