tokenizers/pre_tokenizers/byte_level.rs

1use std::collections::{HashMap, HashSet};
2
3use crate::utils::SysRegex;
4use serde::{Deserialize, Serialize};
5
6use crate::tokenizer::{
7    Decoder, Encoding, PostProcessor, PreTokenizedString, PreTokenizer, Result,
8    SplitDelimiterBehavior,
9};
10use crate::utils::macro_rules_attribute;
11
12/// Converts bytes to unicode characters.
13/// See https://github.com/openai/gpt-2/blob/master/src/encoder.py#L9
/// Converts bytes to unicode characters.
/// See https://github.com/openai/gpt-2/blob/master/src/encoder.py#L9
///
/// Every possible byte value gets a dedicated printable character so that an
/// arbitrary byte sequence can be stored losslessly inside a `String`:
/// - "printable" bytes (`!..=~`, `0xA1..=0xAC`, `0xAE..=0xFF`) map to the
///   char with the same code point;
/// - the remaining bytes map, in increasing byte order, to `U+0100 + n`.
pub(crate) fn bytes_char() -> HashMap<u8, char> {
    // Fixed-size membership table for the self-mapping bytes; avoids the
    // per-byte linear `Vec::contains` scan of the original GPT-2 port.
    let mut printable = [false; 256];
    for b in (b'!'..=b'~').chain(0xA1..=0xAC).chain(0xAE..=0xFF) {
        printable[b as usize] = true;
    }

    let mut map = HashMap::with_capacity(256);
    // Number of non-printable bytes remapped so far.
    let mut n: u32 = 0;
    for b in 0..=255u8 {
        let code = if printable[b as usize] {
            u32::from(b)
        } else {
            let shifted = (1u32 << 8) + n;
            n += 1;
            shifted
        };
        // There are only 68 remapped bytes, so every code point is < 0x144 + 0x100
        // (far below the surrogate range); the safe conversion cannot fail.
        map.insert(
            b,
            char::from_u32(code).expect("byte-level code point is a valid scalar"),
        );
    }
    map
}
39
40lazy_static! {
41    /// Regex that matches exactly one token.
42    /// See https://github.com/openai/gpt-2/blob/master/src/encoder.py#L98
43    static ref RE: SysRegex = SysRegex::new(
44        r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
45    )
46    .unwrap();
47    static ref BYTES_CHAR: HashMap<u8, char> = bytes_char();
48    static ref CHAR_BYTES: HashMap<char, u8> =
49        bytes_char().into_iter().map(|(c, b)| (b, c)).collect();
50}
51
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
/// Provides all the necessary steps to handle the BPE tokenization at the byte-level. Takes care
/// of all the required processing steps to transform a UTF-8 string as needed before and after the
/// BPE model does its job.
#[macro_rules_attribute(impl_serde_type!)]
#[non_exhaustive]
pub struct ByteLevel {
    /// Whether to add a leading space to the first word. This allows to treat the leading word
    /// just as any other word.
    pub add_prefix_space: bool,
    /// Whether the post processing step should trim offsets to avoid including whitespaces.
    pub trim_offsets: bool,

    /// Whether to use the standard GPT2 regex for whitespace splitting
    /// Set it to False if you want to use your own splitting.
    /// Defaults to `true` when absent from serialized configs (see `default_true`),
    /// keeping backward compatibility with configs written before this field existed.
    #[serde(default = "default_true")]
    pub use_regex: bool,
}
70
/// Serde default for `use_regex`: configs serialized before the field was
/// introduced must keep the historical regex-splitting behavior.
fn default_true() -> bool {
    true
}
74
75impl Default for ByteLevel {
76    fn default() -> Self {
77        Self {
78            add_prefix_space: true,
79            trim_offsets: true,
80            use_regex: true,
81        }
82    }
83}
84
85impl ByteLevel {
86    pub fn new(add_prefix_space: bool, trim_offsets: bool, use_regex: bool) -> Self {
87        Self {
88            add_prefix_space,
89            trim_offsets,
90            use_regex,
91        }
92    }
93
94    pub fn alphabet() -> HashSet<char> {
95        BYTES_CHAR.values().copied().collect()
96    }
97
98    #[must_use]
99    pub fn add_prefix_space(mut self, v: bool) -> Self {
100        self.add_prefix_space = v;
101        self
102    }
103
104    #[must_use]
105    pub fn trim_offsets(mut self, v: bool) -> Self {
106        self.trim_offsets = v;
107        self
108    }
109
110    #[must_use]
111    pub fn use_regex(mut self, v: bool) -> Self {
112        self.use_regex = v;
113        self
114    }
115}
116
/// As a `PreTokenizer`, `ByteLevel` is in charge of transforming all the unicode characters into
/// their byte-level counterpart. It also splits the input according to the configured regex.
// TODO: Give the ability to modify this regex
impl PreTokenizer for ByteLevel {
    fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
        // Dereference the lazy static once, outside the per-split closure.
        let re_ref: &SysRegex = &RE;
        // Step 1: optionally prepend a space, then split into word-level pieces
        // (or keep each piece whole when `use_regex` is off).
        pretokenized.split(|_, mut normalized| {
            if self.add_prefix_space && !normalized.get().starts_with(' ') {
                normalized.prepend(" ");
            }
            if self.use_regex {
                normalized.split(re_ref, SplitDelimiterBehavior::Isolated)
            } else {
                Ok(vec![normalized])
            }
        })?;
        // Step 2: replace every UTF-8 byte of every character with its
        // printable byte-level character.
        pretokenized.normalize(|normalized| {
            let s = normalized.get();
            // One output char per input byte, so `s.len()` is the exact capacity.
            let mut transformations: Vec<(char, isize)> = Vec::with_capacity(s.len());
            let mut i = 0;
            for cur_char in s.chars() {
                let size = cur_char.len_utf8();
                let bytes = s[i..i + size].as_bytes();
                i += size;
                transformations.extend(
                    bytes
                        .iter()
                        .enumerate()
                        // NOTE(review): the `isize` is 1 for every byte after the
                        // first of a multi-byte char — presumably an offset-growth
                        // marker consumed by `transform`; confirm against
                        // `NormalizedString::transform`.
                        .map(|(i, b)| (BYTES_CHAR[b], isize::from(i > 0))),
                );
            }
            normalized.transform(transformations, 0);
            Ok(())
        })
    }
}
153
154/// As a `Decoder`, `ByteLevel` is in charge of converting any byte-level characters to their
155/// unicode counterpart, before merging everything back into a single String.
156/// This decoder will consume the tokens and merge them in one step to alleviate
157/// the fact that single token decoded might be a byte not representable as
158/// as String.
159impl Decoder for ByteLevel {
160    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
161        let toks = tokens
162            .into_iter()
163            .flat_map(|t| {
164                t.chars()
165                    .try_fold(vec![], |mut acc, c| {
166                        CHAR_BYTES.get(&c).map(|b| {
167                            acc.push(*b);
168                            acc
169                        })
170                    })
171                    .unwrap_or_else(|| t.as_bytes().to_vec())
172            })
173            .collect::<Vec<u8>>();
174        Ok(vec![String::from_utf8_lossy(&toks).to_string()])
175    }
176}
177
/// As a `PostProcessor`, `ByteLevel` is in charge of trimming the offsets if necessary.
impl PostProcessor for ByteLevel {
    fn added_tokens(&self, _is_pair: bool) -> usize {
        // Byte-level post-processing never inserts special tokens.
        0
    }

    fn process_encodings(
        &self,
        mut encodings: Vec<Encoding>,
        _add_special_tokens: bool,
    ) -> Result<Vec<Encoding>> {
        // Optionally shrink each token's offsets so they no longer cover the
        // surrounding whitespace — including inside overflowing encodings.
        if self.trim_offsets {
            for encoding in encodings.iter_mut() {
                process_offsets(encoding, self.add_prefix_space);
                encoding
                    .get_overflowing_mut()
                    .iter_mut()
                    .for_each(|encoding| process_offsets(encoding, self.add_prefix_space));
            }
        }
        // Tag every encoding with its position in the input sequence list.
        for (i, encoding) in encodings.iter_mut().enumerate() {
            encoding.set_sequence_id(i);
        }
        Ok(encodings)
        //<dyn PostProcessor>::default_process(encodings, add_special_tokens)
    }
}
205
/// Shrinks each token's `(start, end)` offsets so they exclude whitespace
/// (including the byte-level space character `Ġ`) at either end of the token,
/// while never letting `start` pass `end` or vice versa.
pub fn process_offsets(encoding: &mut Encoding, add_prefix_space: bool) {
    encoding.process_tokens_with_offsets_mut(|(i, (token, offsets))| {
        // Count whitespace chars at both ends; `BYTES_CHAR[&b' ']` is 'Ġ'.
        let mut leading_spaces = token
            .chars()
            .take_while(|c| *c == BYTES_CHAR[&b' '] || c.is_whitespace())
            .count();
        let trailing_spaces = token
            .chars()
            .rev()
            .take_while(|c| *c == BYTES_CHAR[&b' '] || c.is_whitespace())
            .count();

        if leading_spaces > 0 || trailing_spaces > 0 {
            if leading_spaces > 0 {
                // If user uses `is_pretokenized=True` we might have
                // offsets that might begin at the start of the string but are
                // NOT the first token.
                let is_first = i == 0 || offsets.0 == 0;
                if is_first && add_prefix_space && leading_spaces == 1 {
                    // If we are processing the first pair of offsets, with `add_prefix_space`,
                    // then we shouldn't remove anything we added. If there are more than one
                    // leading spaces though, it means we didn't add them, and they should be
                    // removed.
                    leading_spaces = 0;
                }
                // Clamp so the new start never exceeds the end.
                offsets.0 = std::cmp::min(offsets.0 + leading_spaces, offsets.1);
            }
            // Only trim the end when it cannot underflow below zero.
            if trailing_spaces > 0 && offsets.1 >= trailing_spaces {
                // Clamp so the new end never drops below the start.
                offsets.1 = std::cmp::max(offsets.1 - trailing_spaces, offsets.0);
            }
        }
    });
}
239
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::{
        Decoder, Encoding, OffsetReferential, OffsetType, PostProcessor, PreTokenizedString,
        PreTokenizer,
    };
    use std::iter::FromIterator;

    // Regex splitting produces one piece per word/punctuation, with spaces
    // encoded as the byte-level 'Ġ' character.
    #[test]
    fn pre_tokenization() {
        let bytelevel = ByteLevel::default().add_prefix_space(false);
        let mut pretokenized: PreTokenizedString = "Hello my friend, how is your day going?".into();
        bytelevel.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![
                ("Hello", (0, 5)),
                ("Ġmy", (5, 8)),
                ("Ġfriend", (8, 15)),
                (",", (15, 16)),
                ("Ġhow", (16, 20)),
                ("Ġis", (20, 23)),
                ("Ġyour", (23, 28)),
                ("Ġday", (28, 32)),
                ("Ġgoing", (32, 38)),
                ("?", (38, 39))
            ]
        );
    }

    // With `use_regex(false)` the whole input stays a single piece (bytes are
    // still remapped, and the prefix space is still added).
    #[test]
    fn pre_tokenization_no_regex() {
        let bytelevel = ByteLevel::default().use_regex(false);
        let mut pretokenized: PreTokenizedString = "Hello my friend, how is your day going?".into();
        bytelevel.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![("ĠHelloĠmyĠfriend,ĠhowĠisĠyourĠdayĠgoing?", (0, 39))]
        );
    }

    // Decoding reverses the byte-level char mapping and merges the tokens.
    #[test]
    fn decoding() {
        let bytelevel = ByteLevel::default().add_prefix_space(false);
        assert_eq!(
            bytelevel
                .decode_chain(
                    vec![
                        "Hello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing",
                        "?"
                    ]
                    .into_iter()
                    .map(|s| s.into())
                    .collect::<Vec<String>>()
                )
                .unwrap(),
            vec!["Hello my friend, how is your day going?"]
        );
    }

    // With `add_prefix_space(true)`, inputs with and without a leading space
    // normalize to identical splits (offsets in normalized referential).
    #[test]
    fn add_prefix_space() {
        let bytelevel = ByteLevel::default().add_prefix_space(true);
        for s in &[
            " Hello my friend, how is your day going?",
            "Hello my friend, how is your day going?",
        ] {
            let mut pretokenized = PreTokenizedString::from(*s);
            bytelevel.pre_tokenize(&mut pretokenized).unwrap();
            assert_eq!(
                pretokenized
                    .get_splits(OffsetReferential::Normalized, OffsetType::Byte)
                    .into_iter()
                    .map(|(s, o, _)| (s, o))
                    .collect::<Vec<_>>(),
                vec![
                    ("ĠHello", (0, 7)),
                    ("Ġmy", (7, 11)),
                    ("Ġfriend", (11, 19)),
                    (",", (19, 20)),
                    ("Ġhow", (20, 25)),
                    ("Ġis", (25, 29)),
                    ("Ġyour", (29, 35)),
                    ("Ġday", (35, 40)),
                    ("Ġgoing", (40, 47)),
                    ("?", (47, 48))
                ]
            );
        }
    }

    // Round-trip: even when every byte-level char becomes its own token
    // (splitting multi-byte characters apart), decoding restores the input.
    #[test]
    fn decode_works_on_separated_tokens() {
        let samples = vec![
            "A Nuskhuri abbreviation of იესუ ქრისტე ( iesu kriste ) \" Jesus Christ \"",
            "An equal number have descenders , like p or q in English \
                 : გ , დ , ე , ვ , კ , ლ , ჟ , ტ , უ , ფ , ღ , ყ , ც",
        ];

        let bytelevel = ByteLevel::default().add_prefix_space(false);
        for sample in samples {
            let mut pretokenized = PreTokenizedString::from(sample);
            bytelevel.pre_tokenize(&mut pretokenized).unwrap();
            let separated_tokens = pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .iter()
                .flat_map(|(s, _, _)| s.split("").map(|t| t.into()))
                .collect::<Vec<_>>();
            assert_eq!(
                sample,
                bytelevel.decode_chain(separated_tokens).unwrap().join("")
            );
        }
    }

    // '\n' becomes its own token ('Ċ' = byte-level char for 0x0A).
    #[test]
    fn handling_of_newlines() {
        let mut pretokenized = PreTokenizedString::from("Hello there\nHello there");
        let bytelevel = ByteLevel::default().add_prefix_space(false);
        bytelevel.pre_tokenize(&mut pretokenized).unwrap();

        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![
                ("Hello", (0, 5)),
                ("Ġthere", (5, 11)),
                ("Ċ", (11, 12)),
                ("Hello", (12, 17)),
                ("Ġthere", (17, 23))
            ]
        );
    }

    // A run of spaces is kept as one piece, except the last space which
    // attaches to the following word (the regex's `\s+(?!\S)` branch).
    #[test]
    fn handling_of_multiple_whitespaces() {
        let mut pretokenized = PreTokenizedString::from("Hello there       dear");
        let bytelevel = ByteLevel::default().add_prefix_space(false);
        bytelevel.pre_tokenize(&mut pretokenized).unwrap();

        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![
                ("Hello", (0, 5)),
                ("Ġthere", (5, 11)),
                ("ĠĠĠĠĠĠ", (11, 17)),
                ("Ġdear", (17, 22))
            ]
        );
    }

    // A multi-byte char ('⭢', 3 UTF-8 bytes) expands to 3 byte-level chars in
    // the normalized referential, while original offsets still slice the input.
    #[test]
    fn offsets_when_char_split_up() {
        let input = "i⭢j";
        let mut pretokenized = PreTokenizedString::from(input);
        let bytelevel = ByteLevel::default().add_prefix_space(false);
        bytelevel.pre_tokenize(&mut pretokenized).unwrap();

        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![("i", (0, 1)), ("âŃ¢", (1, 4)), ("j", (4, 5))]
        );
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Normalized, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![("i", (0, 1)), ("âŃ¢", (1, 7)), ("j", (7, 8))]
        );
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(_, o, _)| &input[o.0..o.1])
                .collect::<Vec<_>>(),
            vec!["i", "⭢", "j"]
        );
    }

    #[test]
    fn processor_trims_offsets_pre_tokenized() {
        // If user uses `is_pretokenized=True` we might have
        // offsets that might begin at the start of the string but are
        // NOT the first token. The single added prefix space must be kept.
        let mut encoding = Encoding::new(
            vec![0; 5],
            vec![],
            vec!["Ġl".into(), "ove".into(), "Ġl".into(), "ove".into()],
            vec![],
            vec![(0, 1), (1, 4), (0, 1), (1, 4)],
            vec![],
            vec![],
            vec![],
            HashMap::new(),
        );
        process_offsets(&mut encoding, true);
        // Offsets are unchanged: each leading 'Ġ' is the added prefix space.
        assert_eq!(
            encoding,
            Encoding::new(
                vec![0; 5],
                vec![],
                vec!["Ġl".into(), "ove".into(), "Ġl".into(), "ove".into()],
                vec![],
                vec![(0, 1), (1, 4), (0, 1), (1, 4)],
                vec![],
                vec![],
                vec![],
                HashMap::new(),
            )
        );
    }

    // Whitespace-only and whitespace-padded tokens get their offsets trimmed
    // (and clamped), for single encodings and for pairs.
    #[test]
    fn processor_trims_offsets() {
        let start = Encoding::new(
            vec![0; 5],
            vec![],
            vec![
                "Ġ".into(),
                "ĠĠĠĠHelloĠĠ".into(),
                "ĠĠHello".into(),
                "HelloĠĠ".into(),
                "ĠĠĠĠ".into(),
            ],
            vec![],
            vec![(0, 1), (0, 11), (11, 18), (18, 25), (25, 29)],
            vec![],
            vec![],
            vec![],
            HashMap::new(),
        );
        let expected = Encoding::new(
            vec![0; 5],
            vec![0; 5],
            vec![
                "Ġ".into(),
                "ĠĠĠĠHelloĠĠ".into(),
                "ĠĠHello".into(),
                "HelloĠĠ".into(),
                "ĠĠĠĠ".into(),
            ],
            vec![],
            vec![(0, 0), (4, 9), (13, 18), (18, 23), (29, 29)],
            vec![],
            vec![],
            vec![],
            HashMap::from_iter(vec![(0, 0..5)]),
        );

        let bytelevel = ByteLevel::default().trim_offsets(true);
        assert_eq!(
            expected,
            bytelevel.process(start.clone(), None, false).unwrap()
        );

        let pair_expected = Encoding::new(
            vec![0; 10],
            vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
            vec![
                "Ġ".into(),
                "ĠĠĠĠHelloĠĠ".into(),
                "ĠĠHello".into(),
                "HelloĠĠ".into(),
                "ĠĠĠĠ".into(),
                "Ġ".into(),
                "ĠĠĠĠHelloĠĠ".into(),
                "ĠĠHello".into(),
                "HelloĠĠ".into(),
                "ĠĠĠĠ".into(),
            ],
            vec![],
            vec![
                (0, 0),
                (4, 9),
                (13, 18),
                (18, 23),
                (29, 29),
                (0, 0),
                (4, 9),
                (13, 18),
                (18, 23),
                (29, 29),
            ],
            vec![],
            vec![],
            vec![],
            HashMap::from_iter(vec![(0, 0..5), (1, 5..10)]),
        );
        assert_eq!(
            pair_expected,
            bytelevel
                .process(start.clone(), Some(start), false)
                .unwrap()
        );
    }

    // Tokens with chars outside the byte-level alphabet (here "[PA D]") are
    // passed through as raw UTF-8 instead of failing.
    #[test]
    fn decode_unknown_characters() {
        let byte_level = ByteLevel::default();
        assert_eq!(
            byte_level
                .decode_chain(vec![
                    "Hello".into(),
                    "Ġthere".into(),
                    "Ġdear".into(),
                    "Ġfriend!".into(),
                    "Ġ".into(),
                    "[PA D]".into()
                ])
                .unwrap(),
            vec!["Hello there dear friend! [PA D]"]
        );
    }

    // Backward compatibility: `use_regex` missing from older serialized
    // configs must default to `true`.
    #[test]
    fn deserialization() {
        // Before use_regex
        let byte_level: ByteLevel = serde_json::from_str(
            r#"{"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": false}"#,
        )
        .unwrap();
        assert!(byte_level.use_regex);

        // Loading works, new future BC test.
        let byte_level: ByteLevel = serde_json::from_str(
            r#"{"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": false, "use_regex": true}"#,
        )
        .unwrap();
        assert!(byte_level.use_regex);

        let byte_level: ByteLevel = serde_json::from_str(
            r#"{"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": false, "use_regex": false}"#,
        )
        .unwrap();
        assert!(!byte_level.use_regex);
    }
}
597}