tokenizers/normalizers/
bert.rs

1use crate::tokenizer::{NormalizedString, Normalizer, Result};
2
3use serde::{Deserialize, Serialize};
4use unicode_categories::UnicodeCategories;
5
6/// Checks whether a character is whitespace
/// Checks whether a character should be treated as whitespace.
///
/// `\t`, `\n` and `\r` are Unicode control characters, but BERT's cleanup
/// treats them as whitespace; everything else defers to `char::is_whitespace`.
fn is_whitespace(c: char) -> bool {
    if matches!(c, '\t' | '\n' | '\r') {
        return true;
    }
    c.is_whitespace()
}
14
15/// Checks whether a character is a control character
16fn is_control(c: char) -> bool {
17    // These are technically control characters but we count them as whitespace
18    match c {
19        '\t' | '\n' | '\r' => false,
20        // The definition of `is_control` here is quite large and contains also
21        // Cc, Cf, Cn or Co
22        // cf. https://unicode.org/reports/tr44/ (Table 12)
23        _ => c.is_other(),
24    }
25}
26
27/// Checks whether a character is chinese
28/// This defines a "chinese character" as anything in the CJK Unicode block:
29///   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
30///
31/// Note that the CJK Unicode block is NOT all Japanese and Korean characters,
32/// despite its name. The modern Korean Hangul alphabet is a different block,
33/// as is Japanese Hiragana and Katakana. Those alphabets are used to write
34/// space-separated words, so they are not treated specially and handled
35/// like for all of the other languages.
/// Checks whether a character is chinese
/// This defines a "chinese character" as anything in the CJK Unicode block:
///   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
///
/// Note that the CJK Unicode block is NOT all Japanese and Korean characters,
/// despite its name. The modern Korean Hangul alphabet is a different block,
/// as is Japanese Hiragana and Katakana. Those alphabets are used to write
/// space-separated words, so they are not treated specially and handled
/// like for all of the other languages.
fn is_chinese_char(c: char) -> bool {
    matches!(
        c as u32,
        // CJK Unified Ideographs
        0x4E00..=0x9FFF |
        // Extension A
        0x3400..=0x4DBF |
        // Extensions B, C, D
        0x20000..=0x2A6DF |
        0x2A700..=0x2B73F |
        0x2B740..=0x2B81F |
        // Extension E starts at U+2B820; the previous lower bound 0x2B920
        // wrongly excluded U+2B820..=U+2B91F. This now matches the reference
        // BERT tokenizer's `_is_chinese_char`.
        0x2B820..=0x2CEAF |
        // CJK Compatibility Ideographs (+ supplement)
        0xF900..=0xFAFF |
        0x2F800..=0x2FA1F
    )
}
49
/// The BERT normalizer: cleaning, CJK-character spacing, optional
/// accent stripping and optional lowercasing, applied in that order
/// by the `Normalizer` impl below.
#[derive(Copy, Clone, Debug, Deserialize, Serialize)]
#[serde(tag = "type")]
#[non_exhaustive]
pub struct BertNormalizer {
    /// Whether to do the bert basic cleaning:
    ///   1. Remove any control characters
    ///   2. Replace all sorts of whitespace by the classic one ` `
    pub clean_text: bool,
    /// Whether to put spaces around chinese characters so they get split
    pub handle_chinese_chars: bool,
    /// Whether to strip accents. When `None`, this defaults to the value
    /// of `lowercase` (accents are stripped iff the input is lowercased).
    pub strip_accents: Option<bool>,
    /// Whether to lowercase the input
    pub lowercase: bool,
}
65
66impl Default for BertNormalizer {
67    fn default() -> Self {
68        Self {
69            clean_text: true,
70            handle_chinese_chars: true,
71            strip_accents: None,
72            lowercase: true,
73        }
74    }
75}
76
77impl BertNormalizer {
78    pub fn new(
79        clean_text: bool,
80        handle_chinese_chars: bool,
81        strip_accents: Option<bool>,
82        lowercase: bool,
83    ) -> Self {
84        Self {
85            clean_text,
86            handle_chinese_chars,
87            strip_accents,
88            lowercase,
89        }
90    }
91
92    fn do_clean_text(&self, normalized: &mut NormalizedString) {
93        normalized
94            .filter(|c| !(c as usize == 0 || c as usize == 0xfffd || is_control(c)))
95            .map(|c| if is_whitespace(c) { ' ' } else { c });
96    }
97
98    fn do_handle_chinese_chars(&self, normalized: &mut NormalizedString) {
99        let mut new_chars: Vec<(char, isize)> = vec![];
100        normalized.for_each(|c| {
101            if is_chinese_char(c) {
102                new_chars.extend([(' ', 0), (c, 1), (' ', 1)]);
103            } else {
104                new_chars.push((c, 0));
105            }
106        });
107        normalized.transform(new_chars, 0);
108    }
109
110    fn do_strip_accents(&self, normalized: &mut NormalizedString) {
111        normalized.nfd().filter(|c| !c.is_mark_nonspacing());
112    }
113
114    fn do_lowercase(&self, normalized: &mut NormalizedString) {
115        normalized.lowercase();
116    }
117}
118
119impl Normalizer for BertNormalizer {
120    fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
121        if self.clean_text {
122            self.do_clean_text(normalized);
123        }
124        if self.handle_chinese_chars {
125            self.do_handle_chinese_chars(normalized);
126        }
127        let strip_accents = self.strip_accents.unwrap_or(self.lowercase);
128        if strip_accents {
129            self.do_strip_accents(normalized);
130        }
131        if self.lowercase {
132            self.do_lowercase(normalized);
133        }
134
135        Ok(())
136    }
137}