tokenizers/processors/
bert.rs

1use crate::tokenizer::{Encoding, PostProcessor, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::iter::FromIterator;
5
/// Post-processor that adds BERT's special tokens around encoded sequences:
/// `[CLS] single [SEP]` for one sequence, `[CLS] first [SEP] second [SEP]` for a pair.
///
/// Serialized with an internal `"type": "BertProcessing"` tag (see `#[serde(tag)]`),
/// which is how the processor is identified when deserializing a tokenizer file.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
#[serde(tag = "type")]
pub struct BertProcessing {
    // Separator token and its vocabulary id, e.g. `("[SEP]", 102)`.
    sep: (String, u32),
    // Classification token and its vocabulary id, e.g. `("[CLS]", 101)`.
    cls: (String, u32),
}
12
13impl Default for BertProcessing {
14    fn default() -> Self {
15        Self {
16            sep: ("[SEP]".into(), 102),
17            cls: ("[CLS]".into(), 101),
18        }
19    }
20}
21
22impl BertProcessing {
23    pub fn new(sep: (String, u32), cls: (String, u32)) -> Self {
24        Self { sep, cls }
25    }
26}
27
/// Errors a BERT post-processor can report.
#[derive(thiserror::Error, Debug)]
pub enum BertProcessorError {
    // NOTE(review): never constructed in this file; presumably used by callers
    // validating their encodings vector — confirm before removing.
    #[error("encodings vector length must be either 1 or 2")]
    InvalidEncodingsVecLength,
}
33
34impl PostProcessor for BertProcessing {
35    fn added_tokens(&self, is_pair: bool) -> usize {
36        if is_pair {
37            3
38        } else {
39            2
40        }
41    }
42
43    fn process_encodings(
44        &self,
45        mut encodings: Vec<Encoding>,
46        add_special_tokens: bool,
47    ) -> Result<Vec<Encoding>> {
48        if !add_special_tokens {
49            return Ok(encodings);
50        }
51
52        let encodings: Vec<Encoding> = encodings
53            .iter_mut()
54            .enumerate()
55            .map(|(i, encoding)| {
56                if i == 0 {
57                    let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
58                    let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
59                    let tokens = [
60                        &[self.cls.0.clone()],
61                        encoding.get_tokens(),
62                        &[self.sep.0.clone()],
63                    ]
64                    .concat();
65                    let words = [&[None], encoding.get_word_ids(), &[None]].concat();
66                    let offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
67                    let special_tokens =
68                        [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
69                    let attention_mask = vec![1; ids.len()];
70
71                    // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
72                    // the special tokens.
73                    let sequence_ranges = HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
74                    Encoding::new(
75                        ids,
76                        type_ids,
77                        tokens,
78                        words,
79                        offsets,
80                        special_tokens,
81                        attention_mask,
82                        encoding
83                            .take_overflowing()
84                            .into_iter()
85                            .map(|encoding| {
86                                let ids =
87                                    [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
88                                let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
89                                let tokens = [
90                                    &[self.cls.0.clone()],
91                                    encoding.get_tokens(),
92                                    &[self.sep.0.clone()],
93                                ]
94                                .concat();
95                                let words = [&[None], encoding.get_word_ids(), &[None]].concat();
96                                let offsets =
97                                    [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
98                                let special_tokens =
99                                    [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]]
100                                        .concat();
101                                let attention_mask = vec![1; ids.len()];
102
103                                // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't
104                                // contain the special tokens.
105                                let sequence_ranges =
106                                    HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
107                                Encoding::new(
108                                    ids,
109                                    type_ids,
110                                    tokens,
111                                    words,
112                                    offsets,
113                                    special_tokens,
114                                    attention_mask,
115                                    vec![],
116                                    sequence_ranges,
117                                )
118                            })
119                            .collect(),
120                        sequence_ranges,
121                    )
122                } else {
123                    let pair_ids = [encoding.get_ids(), &[self.sep.1]].concat();
124                    let pair_type_ids = [encoding.get_type_ids(), &[1]].concat();
125                    let pair_tokens = [encoding.get_tokens(), &[self.sep.0.clone()]].concat();
126                    let pair_words = [encoding.get_word_ids(), &[None]].concat();
127                    let pair_offsets = [encoding.get_offsets(), &[(0, 0)]].concat();
128                    let pair_special_tokens =
129                        [&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
130                    let pair_attention_mask = vec![1; pair_ids.len()];
131
132                    // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
133                    // the special tokens.
134                    let pair_sequence_ranges = HashMap::from_iter(vec![(1, 0..pair_ids.len() - 1)]);
135                    Encoding::new(
136                        pair_ids,
137                        pair_type_ids,
138                        pair_tokens,
139                        pair_words,
140                        pair_offsets,
141                        pair_special_tokens,
142                        pair_attention_mask,
143                        encoding
144                            .take_overflowing()
145                            .into_iter()
146                            .map(|encoding| {
147                                let pair_ids = [encoding.get_ids(), &[self.sep.1]].concat();
148                                let pair_type_ids = [encoding.get_type_ids(), &[1]].concat();
149                                let pair_tokens =
150                                    [encoding.get_tokens(), &[self.sep.0.clone()]].concat();
151                                let pair_words = [encoding.get_word_ids(), &[None]].concat();
152                                let pair_offsets = [encoding.get_offsets(), &[(0, 0)]].concat();
153                                let pair_special_tokens =
154                                    [&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
155                                let pair_attention_mask = vec![1; pair_ids.len()];
156
157                                // For compatibility with `TemplateProcessing`, the sequence_ranges
158                                // shouldn't contain the special tokens.
159                                let pair_sequence_ranges =
160                                    HashMap::from_iter(vec![(1, 0..pair_ids.len() - 1)]);
161                                Encoding::new(
162                                    pair_ids,
163                                    pair_type_ids,
164                                    pair_tokens,
165                                    pair_words,
166                                    pair_offsets,
167                                    pair_special_tokens,
168                                    pair_attention_mask,
169                                    vec![],
170                                    pair_sequence_ranges,
171                                )
172                            })
173                            .collect(),
174                        pair_sequence_ranges,
175                    )
176                }
177            })
178            .collect();
179
180        Ok(encodings)
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    // The serde round-trip must keep producing the exact tagged JSON shape
    // stored in tokenizer files; the `"type"` field comes from
    // `#[serde(tag = "type")]` on the struct.
    #[test]
    fn serde() {
        let bert = BertProcessing::default();
        let bert_r = r#"{"type":"BertProcessing","sep":["[SEP]",102],"cls":["[CLS]",101]}"#;
        assert_eq!(serde_json::to_string(&bert).unwrap(), bert_r);
        assert_eq!(
            serde_json::from_str::<BertProcessing>(bert_r).unwrap(),
            bert
        );
    }

    #[test]
    fn bert_processing() {
        let processor = BertProcessing::default();
        // [CLS] + [SEP] for a single sequence; one extra [SEP] for a pair.
        assert_eq!(processor.added_tokens(false), 2);
        assert_eq!(processor.added_tokens(true), 3);

        use crate::Token;
        let encoding = Encoding::from_tokens(
            vec![
                Token::new(12, "Hello".into(), (0, 5)),
                Token::new(14, "there".into(), (6, 11)),
            ],
            0,
        );
        let pair = Encoding::from_tokens(vec![Token::new(15, "pair".into(), (0, 4))], 0);
        // Single sequence: [CLS] Hello there [SEP], all type ids 0, special
        // positions flagged in the special-tokens mask, (0, 0) offsets for the
        // added tokens, and a sequence range that excludes the special tokens.
        let single_encoding = processor.process(encoding.clone(), None, true).unwrap();
        assert_eq!(
            single_encoding,
            Encoding::new(
                vec![101, 12, 14, 102],
                vec![0, 0, 0, 0],
                vec![
                    "[CLS]".into(),
                    "Hello".into(),
                    "there".into(),
                    "[SEP]".into()
                ],
                vec![None, None, None, None],
                vec![(0, 0), (0, 5), (6, 11), (0, 0)],
                vec![1, 0, 0, 1],
                vec![1, 1, 1, 1],
                vec![],
                HashMap::from_iter(vec![(0, 1..3)]),
            )
        );
        // Special tokens map to no sequence; regular tokens map to sequence 0.
        assert_eq!(single_encoding.token_to_sequence(2), Some(0));
        assert_eq!(single_encoding.token_to_sequence(3), None);
        // Pair: [CLS] Hello there [SEP] pair [SEP]; the pair's tokens get type
        // id 1 and belong to sequence 1.
        let pair_encoding = processor
            .process(encoding.clone(), Some(pair.clone()), true)
            .unwrap();
        assert_eq!(
            pair_encoding,
            Encoding::new(
                vec![101, 12, 14, 102, 15, 102],
                vec![0, 0, 0, 0, 1, 1],
                vec![
                    "[CLS]".into(),
                    "Hello".into(),
                    "there".into(),
                    "[SEP]".into(),
                    "pair".into(),
                    "[SEP]".into()
                ],
                vec![None, None, None, None, None, None],
                vec![(0, 0), (0, 5), (6, 11), (0, 0), (0, 4), (0, 0)],
                vec![1, 0, 0, 1, 0, 1],
                vec![1, 1, 1, 1, 1, 1],
                vec![],
                HashMap::from_iter(vec![(0, 1..3), (1, 4..5)]),
            )
        );
        assert_eq!(pair_encoding.token_to_sequence(2), Some(0));
        assert_eq!(pair_encoding.token_to_sequence(3), None);
        assert_eq!(pair_encoding.token_to_sequence(4), Some(1));
        assert_eq!(pair_encoding.token_to_sequence(5), None);

        // No special tokens: encodings are merged untouched, and the sequence
        // ranges cover every token.
        let pair_encoding = processor.process(encoding, Some(pair), false).unwrap();
        assert_eq!(
            pair_encoding,
            Encoding::new(
                vec![12, 14, 15],
                vec![0, 0, 1],
                vec!["Hello".into(), "there".into(), "pair".into(),],
                vec![None, None, None],
                vec![(0, 5), (6, 11), (0, 4)],
                vec![0, 0, 0],
                vec![1, 1, 1],
                vec![],
                HashMap::from_iter(vec![(0, 0..2), (1, 2..3)]),
            )
        );
        assert_eq!(pair_encoding.token_to_sequence(0), Some(0));
        assert_eq!(pair_encoding.token_to_sequence(1), Some(0));
        assert_eq!(pair_encoding.token_to_sequence(2), Some(1));
    }
}