1use crate::tokenizer::{Encoding, PostProcessor, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::iter::FromIterator;
5
/// Post-processor that wraps encodings with the BERT special tokens:
/// `[CLS] seq_a [SEP]` for a single sequence and
/// `[CLS] seq_a [SEP] seq_b [SEP]` for a pair.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
#[serde(tag = "type")]
pub struct BertProcessing {
    /// Separator token and its vocabulary id (default: `"[SEP]"` / 102).
    sep: (String, u32),
    /// Classification token and its vocabulary id (default: `"[CLS]"` / 101).
    cls: (String, u32),
}
12
13impl Default for BertProcessing {
14 fn default() -> Self {
15 Self {
16 sep: ("[SEP]".into(), 102),
17 cls: ("[CLS]".into(), 101),
18 }
19 }
20}
21
22impl BertProcessing {
23 pub fn new(sep: (String, u32), cls: (String, u32)) -> Self {
24 Self { sep, cls }
25 }
26}
27
/// Errors specific to the BERT post-processor.
// NOTE(review): no variant is constructed in this file; presumably returned by
// the `PostProcessor::process` entry point elsewhere — confirm against the trait.
#[derive(thiserror::Error, Debug)]
pub enum BertProcessorError {
    /// Raised when the processor receives neither one nor two encodings.
    #[error("encodings vector length must be either 1 or 2")]
    InvalidEncodingsVecLength,
}
33
34impl PostProcessor for BertProcessing {
35 fn added_tokens(&self, is_pair: bool) -> usize {
36 if is_pair {
37 3
38 } else {
39 2
40 }
41 }
42
43 fn process_encodings(
44 &self,
45 mut encodings: Vec<Encoding>,
46 add_special_tokens: bool,
47 ) -> Result<Vec<Encoding>> {
48 if !add_special_tokens {
49 return Ok(encodings);
50 }
51
52 let encodings: Vec<Encoding> = encodings
53 .iter_mut()
54 .enumerate()
55 .map(|(i, encoding)| {
56 if i == 0 {
57 let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
58 let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
59 let tokens = [
60 &[self.cls.0.clone()],
61 encoding.get_tokens(),
62 &[self.sep.0.clone()],
63 ]
64 .concat();
65 let words = [&[None], encoding.get_word_ids(), &[None]].concat();
66 let offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
67 let special_tokens =
68 [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
69 let attention_mask = vec![1; ids.len()];
70
71 let sequence_ranges = HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
74 Encoding::new(
75 ids,
76 type_ids,
77 tokens,
78 words,
79 offsets,
80 special_tokens,
81 attention_mask,
82 encoding
83 .take_overflowing()
84 .into_iter()
85 .map(|encoding| {
86 let ids =
87 [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
88 let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
89 let tokens = [
90 &[self.cls.0.clone()],
91 encoding.get_tokens(),
92 &[self.sep.0.clone()],
93 ]
94 .concat();
95 let words = [&[None], encoding.get_word_ids(), &[None]].concat();
96 let offsets =
97 [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
98 let special_tokens =
99 [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]]
100 .concat();
101 let attention_mask = vec![1; ids.len()];
102
103 let sequence_ranges =
106 HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
107 Encoding::new(
108 ids,
109 type_ids,
110 tokens,
111 words,
112 offsets,
113 special_tokens,
114 attention_mask,
115 vec![],
116 sequence_ranges,
117 )
118 })
119 .collect(),
120 sequence_ranges,
121 )
122 } else {
123 let pair_ids = [encoding.get_ids(), &[self.sep.1]].concat();
124 let pair_type_ids = [encoding.get_type_ids(), &[1]].concat();
125 let pair_tokens = [encoding.get_tokens(), &[self.sep.0.clone()]].concat();
126 let pair_words = [encoding.get_word_ids(), &[None]].concat();
127 let pair_offsets = [encoding.get_offsets(), &[(0, 0)]].concat();
128 let pair_special_tokens =
129 [&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
130 let pair_attention_mask = vec![1; pair_ids.len()];
131
132 let pair_sequence_ranges = HashMap::from_iter(vec![(1, 0..pair_ids.len() - 1)]);
135 Encoding::new(
136 pair_ids,
137 pair_type_ids,
138 pair_tokens,
139 pair_words,
140 pair_offsets,
141 pair_special_tokens,
142 pair_attention_mask,
143 encoding
144 .take_overflowing()
145 .into_iter()
146 .map(|encoding| {
147 let pair_ids = [encoding.get_ids(), &[self.sep.1]].concat();
148 let pair_type_ids = [encoding.get_type_ids(), &[1]].concat();
149 let pair_tokens =
150 [encoding.get_tokens(), &[self.sep.0.clone()]].concat();
151 let pair_words = [encoding.get_word_ids(), &[None]].concat();
152 let pair_offsets = [encoding.get_offsets(), &[(0, 0)]].concat();
153 let pair_special_tokens =
154 [&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
155 let pair_attention_mask = vec![1; pair_ids.len()];
156
157 let pair_sequence_ranges =
160 HashMap::from_iter(vec![(1, 0..pair_ids.len() - 1)]);
161 Encoding::new(
162 pair_ids,
163 pair_type_ids,
164 pair_tokens,
165 pair_words,
166 pair_offsets,
167 pair_special_tokens,
168 pair_attention_mask,
169 vec![],
170 pair_sequence_ranges,
171 )
172 })
173 .collect(),
174 pair_sequence_ranges,
175 )
176 }
177 })
178 .collect();
179
180 Ok(encodings)
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trips the default processor through serde and pins the exact JSON
    // shape (the "type" tag comes from `#[serde(tag = "type")]` on the struct,
    // and field order follows the declaration order: sep, then cls).
    #[test]
    fn serde() {
        let bert = BertProcessing::default();
        let bert_r = r#"{"type":"BertProcessing","sep":["[SEP]",102],"cls":["[CLS]",101]}"#;
        assert_eq!(serde_json::to_string(&bert).unwrap(), bert_r);
        assert_eq!(
            serde_json::from_str::<BertProcessing>(bert_r).unwrap(),
            bert
        );
    }

    // End-to-end check of the special-token wrapping for single, pair, and
    // no-special-tokens cases, including sequence range bookkeeping.
    #[test]
    fn bert_processing() {
        let processor = BertProcessing::default();
        // [CLS] + [SEP] for a single sequence; one extra [SEP] for a pair.
        assert_eq!(processor.added_tokens(false), 2);
        assert_eq!(processor.added_tokens(true), 3);

        use crate::Token;
        let encoding = Encoding::from_tokens(
            vec![
                Token::new(12, "Hello".into(), (0, 5)),
                Token::new(14, "there".into(), (6, 11)),
            ],
            0,
        );
        let pair = Encoding::from_tokens(vec![Token::new(15, "pair".into(), (0, 4))], 0);
        // Single sequence: expect [CLS](101) Hello there [SEP](102), all type 0,
        // specials flagged at both ends, and sequence 0 covering tokens 1..3.
        let single_encoding = processor.process(encoding.clone(), None, true).unwrap();
        assert_eq!(
            single_encoding,
            Encoding::new(
                vec![101, 12, 14, 102],
                vec![0, 0, 0, 0],
                vec![
                    "[CLS]".into(),
                    "Hello".into(),
                    "there".into(),
                    "[SEP]".into()
                ],
                vec![None, None, None, None],
                vec![(0, 0), (0, 5), (6, 11), (0, 0)],
                vec![1, 0, 0, 1],
                vec![1, 1, 1, 1],
                vec![],
                HashMap::from_iter(vec![(0, 1..3)]),
            )
        );
        // Special tokens belong to no sequence.
        assert_eq!(single_encoding.token_to_sequence(2), Some(0));
        assert_eq!(single_encoding.token_to_sequence(3), None);
        // Pair: the second segment ("pair" + trailing [SEP]) gets type id 1 and
        // sequence id 1 (tokens 4..5).
        let pair_encoding = processor
            .process(encoding.clone(), Some(pair.clone()), true)
            .unwrap();
        assert_eq!(
            pair_encoding,
            Encoding::new(
                vec![101, 12, 14, 102, 15, 102],
                vec![0, 0, 0, 0, 1, 1],
                vec![
                    "[CLS]".into(),
                    "Hello".into(),
                    "there".into(),
                    "[SEP]".into(),
                    "pair".into(),
                    "[SEP]".into()
                ],
                vec![None, None, None, None, None, None],
                vec![(0, 0), (0, 5), (6, 11), (0, 0), (0, 4), (0, 0)],
                vec![1, 0, 0, 1, 0, 1],
                vec![1, 1, 1, 1, 1, 1],
                vec![],
                HashMap::from_iter(vec![(0, 1..3), (1, 4..5)]),
            )
        );
        assert_eq!(pair_encoding.token_to_sequence(2), Some(0));
        assert_eq!(pair_encoding.token_to_sequence(3), None);
        assert_eq!(pair_encoding.token_to_sequence(4), Some(1));
        assert_eq!(pair_encoding.token_to_sequence(5), None);

        // add_special_tokens = false: sequences are merged untouched, with no
        // [CLS]/[SEP] inserted and no tokens flagged as special.
        let pair_encoding = processor.process(encoding, Some(pair), false).unwrap();
        assert_eq!(
            pair_encoding,
            Encoding::new(
                vec![12, 14, 15],
                vec![0, 0, 1],
                vec!["Hello".into(), "there".into(), "pair".into(),],
                vec![None, None, None],
                vec![(0, 5), (6, 11), (0, 4)],
                vec![0, 0, 0],
                vec![1, 1, 1],
                vec![],
                HashMap::from_iter(vec![(0, 0..2), (1, 2..3)]),
            )
        );
        assert_eq!(pair_encoding.token_to_sequence(0), Some(0));
        assert_eq!(pair_encoding.token_to_sequence(1), Some(0));
        assert_eq!(pair_encoding.token_to_sequence(2), Some(1));
    }
}
285}