1use crate::tokenizer::TransformerTokenizer;
29
/// Selects how the sequences of a batch are brought to a common width.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PaddingStrategy {
    /// Add no pad tokens; every sequence keeps its own length.
    NoPadding,
    /// Pad up to the configured `max_length` (the longest sequence in the batch is used
    /// when no `max_length` is configured).
    MaxLength,
    /// Pad up to the longest sequence present in the batch.
    LongestInBatch,
}
44
/// Selects which end of an over-long sequence is discarded.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TruncationStrategy {
    /// Never shorten a sequence.
    NoTruncation,
    /// Drop tokens from the end, keeping the first `max_length` tokens.
    Right,
    /// Drop tokens from the start, keeping the last `max_length` tokens.
    Left,
}
55
/// Selects the end of a sequence at which pad tokens are inserted.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PaddingSide {
    /// Pad tokens follow the real tokens.
    Right,
    /// Pad tokens precede the real tokens.
    Left,
}
64
65#[derive(Debug, Clone)]
67pub struct BatchConfig {
68 pub max_length: Option<usize>,
71 pub padding: PaddingStrategy,
73 pub truncation: TruncationStrategy,
75 pub pad_token_id: u32,
77}
78
79impl Default for BatchConfig {
80 fn default() -> Self {
81 Self {
82 max_length: None,
83 padding: PaddingStrategy::LongestInBatch,
84 truncation: TruncationStrategy::Right,
85 pad_token_id: 0,
86 }
87 }
88}
89
90#[derive(Debug, Clone)]
92pub struct BatchConfigExt {
93 pub base: BatchConfig,
95 pub padding_side: PaddingSide,
97}
98
99impl Default for BatchConfigExt {
100 fn default() -> Self {
101 Self {
102 base: BatchConfig::default(),
103 padding_side: PaddingSide::Right,
104 }
105 }
106}
107
/// Result of encoding a batch of texts: token ids plus the matching attention masks.
#[derive(Clone, Debug)]
pub struct BatchEncoding {
    /// Token ids, one inner vector per input text.
    pub input_ids: Vec<Vec<u32>>,
    /// Per-position flags parallel to `input_ids`; the encoders in this file write `1`
    /// for real tokens and `0` for padding.
    pub attention_mask: Vec<Vec<u32>>,
    /// Per-text token counts; the encoders in this file record them right after
    /// tokenization, before any truncation or padding.
    pub lengths: Vec<usize>,
}

impl BatchEncoding {
    /// Number of sequences in the batch.
    pub fn batch_size(&self) -> usize {
        self.input_ids.len()
    }

    /// Length of the first sequence, or `0` for an empty batch.
    ///
    /// Only the first row is inspected, so for unpadded batches this is not
    /// necessarily the length of every row.
    pub fn seq_length(&self) -> usize {
        match self.input_ids.first() {
            Some(row) => row.len(),
            None => 0,
        }
    }

    /// Counts the attention-mask entries equal to `1` across the whole batch.
    pub fn total_real_tokens(&self) -> usize {
        self.attention_mask
            .iter()
            .map(|row| row.iter().filter(|&&flag| flag == 1).count())
            .sum()
    }
}
147
148fn truncate(ids: &[u32], strategy: TruncationStrategy, max_length: usize) -> Vec<u32> {
154 if ids.len() <= max_length {
155 return ids.to_vec();
156 }
157 match strategy {
158 TruncationStrategy::NoTruncation => ids.to_vec(),
159 TruncationStrategy::Right => ids[..max_length].to_vec(),
160 TruncationStrategy::Left => ids[ids.len() - max_length..].to_vec(),
161 }
162}
163
/// Produces a sequence of exactly `target_length` tokens together with its attention mask,
/// appending `pad_id` on the right.
///
/// When `ids` is longer than `target_length`, only the first `target_length` tokens are kept.
/// The mask holds `1` for every kept token and `0` for every pad position.
fn pad_right(ids: &[u32], target_length: usize, pad_id: u32) -> (Vec<u32>, Vec<u32>) {
    let kept = ids.len().min(target_length);

    let mut tokens = Vec::with_capacity(target_length);
    tokens.extend_from_slice(&ids[..kept]);
    tokens.resize(target_length, pad_id);

    let mut mask = vec![1u32; kept];
    mask.resize(target_length, 0);

    (tokens, mask)
}
179
/// Produces a sequence of exactly `target_length` tokens together with its attention mask,
/// inserting `pad_id` on the left.
///
/// When `ids` is longer than `target_length`, only the last `target_length` tokens are kept.
/// The mask holds `0` for every pad position and `1` for every kept token.
fn pad_left(ids: &[u32], target_length: usize, pad_id: u32) -> (Vec<u32>, Vec<u32>) {
    let kept = ids.len().min(target_length);
    let fill = target_length - kept;

    let mut tokens = vec![pad_id; fill];
    tokens.extend_from_slice(&ids[ids.len() - kept..]);

    let mut mask = vec![0u32; fill];
    mask.resize(target_length, 1);

    (tokens, mask)
}
196
197pub fn batch_encode<T: TransformerTokenizer>(
210 texts: &[&str],
211 tokenizer: &T,
212 config: &BatchConfig,
213) -> BatchEncoding {
214 if texts.is_empty() {
215 return BatchEncoding {
216 input_ids: Vec::new(),
217 attention_mask: Vec::new(),
218 lengths: Vec::new(),
219 };
220 }
221
222 let mut encoded: Vec<Vec<u32>> = texts.iter().map(|t| tokenizer.encode(t)).collect();
224 let original_lengths: Vec<usize> = encoded.iter().map(|v| v.len()).collect();
225
226 if let Some(max_len) = config.max_length {
228 if config.truncation != TruncationStrategy::NoTruncation {
229 for seq in &mut encoded {
230 *seq = truncate(seq, config.truncation, max_len);
231 }
232 }
233 }
234
235 let target_length = match config.padding {
237 PaddingStrategy::NoPadding => {
238 let attention_mask: Vec<Vec<u32>> =
240 encoded.iter().map(|seq| vec![1u32; seq.len()]).collect();
241 return BatchEncoding {
242 input_ids: encoded,
243 attention_mask,
244 lengths: original_lengths,
245 };
246 }
247 PaddingStrategy::MaxLength => config
248 .max_length
249 .unwrap_or_else(|| encoded.iter().map(|s| s.len()).max().unwrap_or(0)),
250 PaddingStrategy::LongestInBatch => {
251 let longest = encoded.iter().map(|s| s.len()).max().unwrap_or(0);
252 match config.max_length {
254 Some(ml) => longest.min(ml),
255 None => longest,
256 }
257 }
258 };
259
260 let mut input_ids = Vec::with_capacity(encoded.len());
262 let mut attention_mask = Vec::with_capacity(encoded.len());
263
264 for seq in &encoded {
265 let (padded, mask) = pad_right(seq, target_length, config.pad_token_id);
266 input_ids.push(padded);
267 attention_mask.push(mask);
268 }
269
270 BatchEncoding {
271 input_ids,
272 attention_mask,
273 lengths: original_lengths,
274 }
275}
276
277pub fn batch_encode_ext<T: TransformerTokenizer>(
281 texts: &[&str],
282 tokenizer: &T,
283 config: &BatchConfigExt,
284) -> BatchEncoding {
285 if texts.is_empty() {
286 return BatchEncoding {
287 input_ids: Vec::new(),
288 attention_mask: Vec::new(),
289 lengths: Vec::new(),
290 };
291 }
292
293 let mut encoded: Vec<Vec<u32>> = texts.iter().map(|t| tokenizer.encode(t)).collect();
295 let original_lengths: Vec<usize> = encoded.iter().map(|v| v.len()).collect();
296
297 if let Some(max_len) = config.base.max_length {
299 if config.base.truncation != TruncationStrategy::NoTruncation {
300 for seq in &mut encoded {
301 *seq = truncate(seq, config.base.truncation, max_len);
302 }
303 }
304 }
305
306 let target_length = match config.base.padding {
308 PaddingStrategy::NoPadding => {
309 let attention_mask: Vec<Vec<u32>> =
310 encoded.iter().map(|seq| vec![1u32; seq.len()]).collect();
311 return BatchEncoding {
312 input_ids: encoded,
313 attention_mask,
314 lengths: original_lengths,
315 };
316 }
317 PaddingStrategy::MaxLength => config
318 .base
319 .max_length
320 .unwrap_or_else(|| encoded.iter().map(|s| s.len()).max().unwrap_or(0)),
321 PaddingStrategy::LongestInBatch => {
322 let longest = encoded.iter().map(|s| s.len()).max().unwrap_or(0);
323 match config.base.max_length {
324 Some(ml) => longest.min(ml),
325 None => longest,
326 }
327 }
328 };
329
330 let pad_fn = match config.padding_side {
332 PaddingSide::Right => pad_right,
333 PaddingSide::Left => pad_left,
334 };
335
336 let mut input_ids = Vec::with_capacity(encoded.len());
337 let mut attention_mask = Vec::with_capacity(encoded.len());
338
339 for seq in &encoded {
340 let (padded, mask) = pad_fn(seq, target_length, config.base.pad_token_id);
341 input_ids.push(padded);
342 attention_mask.push(mask);
343 }
344
345 BatchEncoding {
346 input_ids,
347 attention_mask,
348 lengths: original_lengths,
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::BPETokenizer;

    /// Trains a small BPE tokenizer on a fixed corpus so every test sees the same vocabulary.
    fn train_tokenizer() -> BPETokenizer {
        let corpus = &[
            "the cat sat on the mat",
            "the dog sat on the log",
            "cats and dogs",
            "the quick brown fox",
        ];
        BPETokenizer::train(corpus, 100).expect("training should succeed")
    }

    /// Token ids of `text` with neither truncation nor padding applied.
    ///
    /// Goes through `batch_encode` so the tests rely only on the API under test.
    fn raw_ids(tok: &BPETokenizer, text: &str) -> Vec<u32> {
        let config = BatchConfig {
            max_length: None,
            padding: PaddingStrategy::NoPadding,
            truncation: TruncationStrategy::NoTruncation,
            pad_token_id: 0,
        };
        batch_encode(&[text], tok, &config)
            .input_ids
            .into_iter()
            .next()
            .expect("one input text should yield one sequence")
    }

    #[test]
    fn test_batch_encode_basic() {
        let tok = train_tokenizer();
        let texts = &["the cat", "the dog sat"];
        let config = BatchConfig {
            padding: PaddingStrategy::LongestInBatch,
            ..Default::default()
        };
        let batch = batch_encode(texts, &tok, &config);
        assert_eq!(batch.batch_size(), 2);
        assert_eq!(batch.input_ids[0].len(), batch.input_ids[1].len());
        assert_eq!(batch.attention_mask[0].len(), batch.attention_mask[1].len());
        assert_eq!(batch.seq_length(), batch.input_ids[1].len());
    }

    #[test]
    fn test_padding_adds_correct_tokens() {
        let tok = train_tokenizer();
        let texts = &["the", "the cat sat on the mat"];
        let config = BatchConfig {
            padding: PaddingStrategy::LongestInBatch,
            pad_token_id: 0,
            ..Default::default()
        };
        let batch = batch_encode(texts, &tok, &config);

        let shorter_len = batch.lengths[0];
        for (i, &token) in batch.input_ids[0].iter().enumerate().skip(shorter_len) {
            assert_eq!(token, 0, "padding token should be 0 at position {i}");
        }

        // Right padding must leave the real tokens untouched at the front.
        let real = raw_ids(&tok, "the");
        assert_eq!(
            &batch.input_ids[0][..real.len()],
            real.as_slice(),
            "real tokens should precede the padding unchanged"
        );
    }

    #[test]
    fn test_attention_mask_correct() {
        let tok = train_tokenizer();
        let texts = &["the", "the cat sat"];
        let config = BatchConfig {
            padding: PaddingStrategy::LongestInBatch,
            pad_token_id: 0,
            ..Default::default()
        };
        let batch = batch_encode(texts, &tok, &config);

        let shorter_len = batch.lengths[0];
        for (i, &flag) in batch.attention_mask[0].iter().enumerate() {
            if i < shorter_len {
                assert_eq!(flag, 1, "real token at {i} should have mask 1");
            } else {
                assert_eq!(flag, 0, "padding at {i} should have mask 0");
            }
        }
    }

    #[test]
    fn test_truncation_right() {
        let tok = train_tokenizer();
        let text = "the cat sat on the mat";
        let config = BatchConfig {
            max_length: Some(3),
            padding: PaddingStrategy::NoPadding,
            truncation: TruncationStrategy::Right,
            pad_token_id: 0,
        };
        let batch = batch_encode(&[text], &tok, &config);
        assert!(
            batch.input_ids[0].len() <= 3,
            "truncated length should be <= 3, got {}",
            batch.input_ids[0].len()
        );

        // Right truncation keeps the head of the untruncated sequence.
        let full = raw_ids(&tok, text);
        let keep = full.len().min(3);
        assert_eq!(
            batch.input_ids[0].as_slice(),
            &full[..keep],
            "right truncation should keep the leading tokens"
        );
    }

    #[test]
    fn test_truncation_left() {
        let tok = train_tokenizer();
        let text = "the cat sat on the mat";
        let config = BatchConfig {
            max_length: Some(3),
            padding: PaddingStrategy::NoPadding,
            truncation: TruncationStrategy::Left,
            pad_token_id: 0,
        };
        let batch = batch_encode(&[text], &tok, &config);
        assert!(
            batch.input_ids[0].len() <= 3,
            "truncated length should be <= 3"
        );

        // Left truncation keeps the tail of the untruncated sequence.
        let full = raw_ids(&tok, text);
        let keep = full.len().min(3);
        assert_eq!(
            batch.input_ids[0].as_slice(),
            &full[full.len() - keep..],
            "left truncation should keep the trailing tokens"
        );
    }

    #[test]
    fn test_no_padding_varying_lengths() {
        let tok = train_tokenizer();
        let texts = &["the", "the cat sat"];
        let config = BatchConfig {
            padding: PaddingStrategy::NoPadding,
            truncation: TruncationStrategy::NoTruncation,
            ..Default::default()
        };
        let batch = batch_encode(texts, &tok, &config);
        assert_eq!(batch.input_ids[0].len(), batch.lengths[0]);
        assert_eq!(batch.input_ids[1].len(), batch.lengths[1]);
    }

    #[test]
    fn test_max_length_padding() {
        let tok = train_tokenizer();
        let texts = &["the"];
        let config = BatchConfig {
            max_length: Some(10),
            padding: PaddingStrategy::MaxLength,
            truncation: TruncationStrategy::Right,
            pad_token_id: 0,
        };
        let batch = batch_encode(texts, &tok, &config);
        assert_eq!(
            batch.input_ids[0].len(),
            10,
            "should be padded to max_length"
        );
        assert_eq!(
            batch.attention_mask[0].len(),
            10,
            "mask should match the padded length"
        );
    }

    #[test]
    fn test_empty_input() {
        let tok = train_tokenizer();
        let texts: &[&str] = &[];
        let config = BatchConfig::default();
        let batch = batch_encode(texts, &tok, &config);
        assert_eq!(batch.batch_size(), 0);
        assert_eq!(batch.seq_length(), 0);
        assert_eq!(batch.total_real_tokens(), 0);
    }

    #[test]
    fn test_empty_string_in_batch() {
        let tok = train_tokenizer();
        let texts = &["", "the cat"];
        let config = BatchConfig {
            padding: PaddingStrategy::LongestInBatch,
            pad_token_id: 0,
            ..Default::default()
        };
        let batch = batch_encode(texts, &tok, &config);
        assert_eq!(batch.batch_size(), 2);
        assert_eq!(batch.lengths[0], 0);
    }

    #[test]
    fn test_left_padding() {
        let tok = train_tokenizer();
        let texts = &["the", "the cat sat"];
        let config = BatchConfigExt {
            base: BatchConfig {
                padding: PaddingStrategy::LongestInBatch,
                pad_token_id: 0,
                ..Default::default()
            },
            padding_side: PaddingSide::Left,
        };
        let batch = batch_encode_ext(texts, &tok, &config);

        let shorter_len = batch.lengths[0];
        let total_len = batch.input_ids[0].len();
        let pad_count = total_len.saturating_sub(shorter_len);

        for i in 0..pad_count {
            assert_eq!(
                batch.attention_mask[0][i], 0,
                "left padding mask at {i} should be 0"
            );
            assert_eq!(
                batch.input_ids[0][i], 0,
                "left padding token at {i} should be pad_id"
            );
        }
        for i in pad_count..total_len {
            assert_eq!(
                batch.attention_mask[0][i], 1,
                "real token mask at {i} should be 1"
            );
        }

        // Left padding must leave the real tokens untouched at the back.
        let real = raw_ids(&tok, "the");
        assert_eq!(
            &batch.input_ids[0][pad_count..],
            real.as_slice(),
            "real tokens should follow the padding unchanged"
        );
    }

    #[test]
    fn test_total_real_tokens() {
        let tok = train_tokenizer();
        let texts = &["the cat", "the"];
        let config = BatchConfig {
            padding: PaddingStrategy::LongestInBatch,
            pad_token_id: 0,
            ..Default::default()
        };
        let batch = batch_encode(texts, &tok, &config);
        let total = batch.total_real_tokens();
        let expected: usize = batch.lengths.iter().sum();
        assert_eq!(
            total, expected,
            "total real tokens should equal sum of lengths"
        );
    }

    #[test]
    fn test_truncation_with_padding() {
        let tok = train_tokenizer();
        let texts = &["the cat sat on the mat", "the"];
        let config = BatchConfig {
            max_length: Some(4),
            padding: PaddingStrategy::MaxLength,
            truncation: TruncationStrategy::Right,
            pad_token_id: 0,
        };
        let batch = batch_encode(texts, &tok, &config);
        for seq in &batch.input_ids {
            assert_eq!(seq.len(), 4, "all sequences should be padded to max_length");
        }
        for mask in &batch.attention_mask {
            assert_eq!(mask.len(), 4, "all masks should match the padded length");
        }
    }
}