// tokenizers/utils/padding.rs

1use crate::parallelism::*;
2use crate::tokenizer::{Encoding, Result};
3use serde::{Deserialize, Serialize};
4
5/// The various possible padding directions.
6#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
7pub enum PaddingDirection {
8    Left,
9    Right,
10}
11
12impl std::convert::AsRef<str> for PaddingDirection {
13    fn as_ref(&self) -> &str {
14        match self {
15            PaddingDirection::Left => "left",
16            PaddingDirection::Right => "right",
17        }
18    }
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct PaddingParams {
23    pub strategy: PaddingStrategy,
24    pub direction: PaddingDirection,
25    pub pad_to_multiple_of: Option<usize>,
26    pub pad_id: u32,
27    pub pad_type_id: u32,
28    pub pad_token: String,
29}
30
31impl Default for PaddingParams {
32    fn default() -> Self {
33        Self {
34            strategy: PaddingStrategy::BatchLongest,
35            direction: PaddingDirection::Right,
36            pad_to_multiple_of: None,
37            pad_id: 0,
38            pad_type_id: 0,
39            pad_token: String::from("[PAD]"),
40        }
41    }
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub enum PaddingStrategy {
46    BatchLongest,
47    Fixed(usize),
48}
49
50pub fn pad_encodings(encodings: &mut [Encoding], params: &PaddingParams) -> Result<()> {
51    if encodings.is_empty() {
52        return Ok(());
53    }
54
55    let mut pad_length = match params.strategy {
56        PaddingStrategy::Fixed(size) => size,
57        PaddingStrategy::BatchLongest => encodings
58            .maybe_par_iter()
59            .map(|e| e.get_ids().len())
60            .max()
61            .unwrap(),
62    };
63
64    if let Some(multiple) = params.pad_to_multiple_of {
65        if multiple > 0 && pad_length % multiple > 0 {
66            pad_length += multiple - pad_length % multiple;
67        }
68    }
69
70    encodings.maybe_par_iter_mut().for_each(|encoding| {
71        encoding.pad(
72            pad_length,
73            params.pad_id,
74            params.pad_type_id,
75            &params.pad_token,
76            params.direction,
77        )
78    });
79
80    Ok(())
81}
82
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::Encoding;
    use std::collections::HashMap;

    /// Builds an encoding holding only `ids`; every other field is empty.
    fn encoding_with_ids(ids: Vec<u32>) -> Encoding {
        Encoding::new(
            ids,
            vec![],
            vec![],
            vec![],
            vec![],
            vec![],
            vec![],
            vec![],
            HashMap::new(),
        )
    }

    /// Two encodings of lengths 5 and 3, so the batch-longest length is 5.
    fn get_encodings() -> [Encoding; 2] {
        [
            encoding_with_ids(vec![0, 1, 2, 3, 4]),
            encoding_with_ids(vec![0, 1, 2]),
        ]
    }

    /// Padding parameters used by the original test, with the given strategy and rounding.
    fn params_with(strategy: PaddingStrategy, pad_to_multiple_of: Option<usize>) -> PaddingParams {
        PaddingParams {
            strategy,
            direction: PaddingDirection::Right,
            pad_to_multiple_of,
            pad_id: 0,
            pad_type_id: 0,
            pad_token: String::from("[PAD]"),
        }
    }

    fn all_have_len(encodings: &[Encoding], len: usize) -> bool {
        encodings.iter().all(|e| e.get_ids().len() == len)
    }

    #[test]
    fn pad_to_multiple() {
        // Fixed length 7, rounded up to the next multiple of 8.
        let mut encodings = get_encodings();
        let mut params = params_with(PaddingStrategy::Fixed(7), Some(8));
        pad_encodings(&mut encodings, &params).unwrap();
        assert!(all_have_len(&encodings, 8));

        // Batch-longest length 5, rounded up to the next multiple of 6.
        let mut encodings = get_encodings();
        params.strategy = PaddingStrategy::BatchLongest;
        params.pad_to_multiple_of = Some(6);
        pad_encodings(&mut encodings, &params).unwrap();
        assert!(all_have_len(&encodings, 6));

        // A multiple of 0 must not panic (no division by zero) and disables
        // rounding. Use fresh encodings so the assertion actually observes that.
        let mut encodings = get_encodings();
        params.pad_to_multiple_of = Some(0);
        pad_encodings(&mut encodings, &params).unwrap();
        assert!(all_have_len(&encodings, 5));
    }

    #[test]
    fn exact_multiple_is_not_rounded() {
        // 5 is already a multiple of 5, so no extra padding is added.
        let mut encodings = get_encodings();
        let params = params_with(PaddingStrategy::BatchLongest, Some(5));
        pad_encodings(&mut encodings, &params).unwrap();
        assert!(all_have_len(&encodings, 5));
    }

    #[test]
    fn batch_longest_without_rounding() {
        let mut encodings = get_encodings();
        let params = params_with(PaddingStrategy::BatchLongest, None);
        pad_encodings(&mut encodings, &params).unwrap();
        assert!(all_have_len(&encodings, 5));
    }

    #[test]
    fn empty_batch_is_ok() {
        let mut encodings: Vec<Encoding> = Vec::new();
        pad_encodings(&mut encodings, &PaddingParams::default()).unwrap();
        assert!(encodings.is_empty());
    }
}