Skip to main content

sentencepiece_rs/
normalizer.rs

1use crate::darts::DoubleArray;
2use crate::proto::{NormalizerSpec, TrainerSpec};
3use crate::util::{SPACE_SYMBOL, first_char_len};
4use crate::{Error, Result};
5
6/// SentencePiece-compatible normalizer.
7///
8/// It supports the model flags used by the runtime and the precompiled
9/// normalization trie embedded in standard SentencePiece model files.
10#[derive(Clone, Debug)]
11pub struct Normalizer {
12    spec: NormalizerSpec,
13    treat_whitespace_as_suffix: bool,
14    charsmap: Option<PrecompiledCharsMap>,
15    user_symbols: Vec<String>,
16}
17
18#[derive(Clone, Debug)]
19struct PrecompiledCharsMap {
20    trie: DoubleArray,
21    normalized: Vec<u8>,
22}
23
24impl Normalizer {
25    pub(crate) fn new(spec: NormalizerSpec, trainer_spec: &TrainerSpec) -> Result<Self> {
26        let charsmap = if spec.precompiled_charsmap.is_empty() {
27            None
28        } else {
29            Some(PrecompiledCharsMap::decode(&spec.precompiled_charsmap)?)
30        };
31
32        Ok(Self {
33            spec,
34            treat_whitespace_as_suffix: trainer_spec.treat_whitespace_as_suffix,
35            charsmap,
36            user_symbols: Vec::new(),
37        })
38    }
39
40    pub(crate) fn new_denormalizer(spec: NormalizerSpec) -> Result<Self> {
41        let charsmap = if spec.precompiled_charsmap.is_empty() {
42            None
43        } else {
44            Some(PrecompiledCharsMap::decode(&spec.precompiled_charsmap)?)
45        };
46
47        Ok(Self {
48            spec,
49            treat_whitespace_as_suffix: false,
50            charsmap,
51            user_symbols: Vec::new(),
52        })
53    }
54
55    pub(crate) fn set_user_symbols(&mut self, mut user_symbols: Vec<String>) {
56        user_symbols.sort_by_key(|symbol| std::cmp::Reverse(symbol.len()));
57        self.user_symbols = user_symbols;
58    }
59
60    /// Normalizes a UTF-8 string with the model's SentencePiece rules.
61    pub fn normalize(&self, input: &str) -> Result<String> {
62        if input.is_empty() {
63            return Ok(String::new());
64        }
65
66        let mut cursor = 0;
67        let bytes = input.as_bytes();
68
69        if self.spec.remove_extra_whitespaces {
70            while cursor < bytes.len() {
71                let (normalized, consumed) = self.normalize_prefix(&input[cursor..]);
72                if normalized.as_slice() != b" " {
73                    break;
74                }
75                cursor += consumed;
76            }
77        }
78
79        if cursor == bytes.len() {
80            return Ok(String::new());
81        }
82
83        let mut output = Vec::with_capacity((bytes.len() - cursor) * 3);
84        let add_ws = |output: &mut Vec<u8>| {
85            if self.spec.escape_whitespaces {
86                output.extend_from_slice(SPACE_SYMBOL.as_bytes());
87            } else {
88                output.push(b' ');
89            }
90        };
91
92        if !self.treat_whitespace_as_suffix && self.spec.add_dummy_prefix {
93            add_ws(&mut output);
94        }
95
96        let mut is_prev_space = self.spec.remove_extra_whitespaces;
97        while cursor < bytes.len() {
98            let (mut normalized, consumed) = self.normalize_prefix(&input[cursor..]);
99
100            while is_prev_space && normalized.first() == Some(&b' ') {
101                normalized.remove(0);
102            }
103
104            if !normalized.is_empty() {
105                for byte in normalized.iter().copied() {
106                    if self.spec.escape_whitespaces && byte == b' ' {
107                        output.extend_from_slice(SPACE_SYMBOL.as_bytes());
108                    } else {
109                        output.push(byte);
110                    }
111                }
112                is_prev_space = normalized.last() == Some(&b' ');
113            }
114
115            cursor += consumed;
116            if !self.spec.remove_extra_whitespaces {
117                is_prev_space = false;
118            }
119        }
120
121        if self.spec.remove_extra_whitespaces {
122            let suffix = if self.spec.escape_whitespaces {
123                SPACE_SYMBOL.as_bytes()
124            } else {
125                b" "
126            };
127            while output.ends_with(suffix) {
128                let new_len = output.len() - suffix.len();
129                output.truncate(new_len);
130            }
131        }
132
133        if self.treat_whitespace_as_suffix && self.spec.add_dummy_prefix {
134            add_ws(&mut output);
135        }
136
137        String::from_utf8(output)
138            .map_err(|_| Error::model_parse("normalization produced invalid UTF-8"))
139    }
140
141    pub(crate) fn add_dummy_prefix(&self) -> bool {
142        self.spec.add_dummy_prefix
143    }
144
145    pub(crate) fn remove_extra_whitespaces(&self) -> bool {
146        self.spec.remove_extra_whitespaces
147    }
148
149    fn normalize_prefix(&self, input: &str) -> (Vec<u8>, usize) {
150        if input.is_empty() {
151            return (Vec::new(), 0);
152        }
153
154        if let Some(symbol) = self
155            .user_symbols
156            .iter()
157            .find(|symbol| input.as_bytes().starts_with(symbol.as_bytes()))
158        {
159            return (symbol.as_bytes().to_vec(), symbol.len());
160        }
161
162        if let Some(charsmap) = &self.charsmap
163            && let Some((offset, length)) = charsmap.longest_match(input.as_bytes())
164            && let Some(normalized) = charsmap.normalized_at(offset)
165        {
166            return (normalized.to_vec(), length);
167        }
168
169        let len = first_char_len(input);
170        (input.as_bytes()[..len].to_vec(), len)
171    }
172}
173
174impl PrecompiledCharsMap {
175    fn decode(blob: &[u8]) -> Result<Self> {
176        if blob.len() <= 4 {
177            return Err(Error::model_parse("normalization rule blob is broken"));
178        }
179
180        let trie_blob_size = u32::from_le_bytes([blob[0], blob[1], blob[2], blob[3]]) as usize;
181        if trie_blob_size >= blob.len() {
182            return Err(Error::model_parse(
183                "normalization trie data exceeds the input blob size",
184            ));
185        }
186        if trie_blob_size < 1024 || (trie_blob_size & 0x3ff) != 0 {
187            return Err(Error::model_parse(
188                "normalization trie data size is not divisible by 1024",
189            ));
190        }
191
192        let trie_start = 4;
193        let trie_end = trie_start + trie_blob_size;
194        let normalized = blob[trie_end..].to_vec();
195        if normalized.is_empty() || normalized.last() != Some(&0) {
196            return Err(Error::model_parse(
197                "normalization data block must be null-terminated",
198            ));
199        }
200
201        Ok(Self {
202            trie: DoubleArray::from_le_blob(&blob[trie_start..trie_end])?,
203            normalized,
204        })
205    }
206
207    fn longest_match(&self, input: &[u8]) -> Option<(usize, usize)> {
208        self.trie
209            .common_prefix_search(input)
210            .into_iter()
211            .max_by_key(|(_, length)| *length)
212            .map(|(offset, length)| (offset as usize, length))
213    }
214
215    fn normalized_at(&self, offset: usize) -> Option<&[u8]> {
216        if offset >= self.normalized.len() {
217            return None;
218        }
219        let tail = &self.normalized[offset..];
220        let end = tail.iter().position(|byte| *byte == 0)?;
221        Some(&tail[..end])
222    }
223}