tokenizers/utils/
padding.rs1use crate::parallelism::*;
2use crate::tokenizer::{Encoding, Result};
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
7pub enum PaddingDirection {
8 Left,
9 Right,
10}
11
12impl std::convert::AsRef<str> for PaddingDirection {
13 fn as_ref(&self) -> &str {
14 match self {
15 PaddingDirection::Left => "left",
16 PaddingDirection::Right => "right",
17 }
18 }
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct PaddingParams {
23 pub strategy: PaddingStrategy,
24 pub direction: PaddingDirection,
25 pub pad_to_multiple_of: Option<usize>,
26 pub pad_id: u32,
27 pub pad_type_id: u32,
28 pub pad_token: String,
29}
30
31impl Default for PaddingParams {
32 fn default() -> Self {
33 Self {
34 strategy: PaddingStrategy::BatchLongest,
35 direction: PaddingDirection::Right,
36 pad_to_multiple_of: None,
37 pad_id: 0,
38 pad_type_id: 0,
39 pad_token: String::from("[PAD]"),
40 }
41 }
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub enum PaddingStrategy {
46 BatchLongest,
47 Fixed(usize),
48}
49
50pub fn pad_encodings(encodings: &mut [Encoding], params: &PaddingParams) -> Result<()> {
51 if encodings.is_empty() {
52 return Ok(());
53 }
54
55 let mut pad_length = match params.strategy {
56 PaddingStrategy::Fixed(size) => size,
57 PaddingStrategy::BatchLongest => encodings
58 .maybe_par_iter()
59 .map(|e| e.get_ids().len())
60 .max()
61 .unwrap(),
62 };
63
64 if let Some(multiple) = params.pad_to_multiple_of {
65 if multiple > 0 && pad_length % multiple > 0 {
66 pad_length += multiple - pad_length % multiple;
67 }
68 }
69
70 encodings.maybe_par_iter_mut().for_each(|encoding| {
71 encoding.pad(
72 pad_length,
73 params.pad_id,
74 params.pad_type_id,
75 ¶ms.pad_token,
76 params.direction,
77 )
78 });
79
80 Ok(())
81}
82
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::Encoding;
    use std::collections::HashMap;

    #[test]
    fn pad_to_multiple() {
        // Two fresh, unpadded encodings of lengths 5 and 3 (only `ids` is populated).
        fn get_encodings() -> [Encoding; 2] {
            [
                Encoding::new(
                    vec![0, 1, 2, 3, 4],
                    vec![],
                    vec![],
                    vec![],
                    vec![],
                    vec![],
                    vec![],
                    vec![],
                    HashMap::new(),
                ),
                Encoding::new(
                    vec![0, 1, 2],
                    vec![],
                    vec![],
                    vec![],
                    vec![],
                    vec![],
                    vec![],
                    vec![],
                    HashMap::new(),
                ),
            ]
        }

        // True when every encoding holds exactly `len` ids.
        fn all_have_len(encodings: &[Encoding], len: usize) -> bool {
            encodings.iter().all(|e| e.get_ids().len() == len)
        }

        // Fixed length 7 is rounded up to the next multiple of 8.
        let mut encodings = get_encodings();
        let mut params = PaddingParams {
            strategy: PaddingStrategy::Fixed(7),
            direction: PaddingDirection::Right,
            pad_to_multiple_of: Some(8),
            pad_id: 0,
            pad_type_id: 0,
            pad_token: String::from("[PAD]"),
        };
        pad_encodings(&mut encodings, &params).unwrap();
        assert!(all_have_len(&encodings, 8));

        // A fixed length that is already a multiple is left unchanged.
        let mut encodings = get_encodings();
        params.strategy = PaddingStrategy::Fixed(8);
        params.pad_to_multiple_of = Some(4);
        pad_encodings(&mut encodings, &params).unwrap();
        assert!(all_have_len(&encodings, 8));

        // Batch-longest length 5 is rounded up to the next multiple of 6.
        let mut encodings = get_encodings();
        params.strategy = PaddingStrategy::BatchLongest;
        params.pad_to_multiple_of = Some(6);
        pad_encodings(&mut encodings, &params).unwrap();
        assert!(all_have_len(&encodings, 6));

        // A multiple of 0 must not divide by zero and disables rounding. The batch is
        // already 6 long here, so it must stay at 6.
        params.pad_to_multiple_of = Some(0);
        pad_encodings(&mut encodings, &params).unwrap();
        assert!(all_have_len(&encodings, 6));

        // Without `pad_to_multiple_of`, batch-longest pads to the longest encoding (5).
        let mut encodings = get_encodings();
        params.pad_to_multiple_of = None;
        pad_encodings(&mut encodings, &params).unwrap();
        assert!(all_have_len(&encodings, 5));
    }
}