tokenizers/models/unigram/
model.rs

1use super::{
2    lattice::Lattice,
3    trainer::UnigramTrainer,
4    trie::{Trie, TrieBuilder},
5};
6use crate::tokenizer::{Model, Result, Token};
7use crate::utils::cache::{Cache, MAX_LENGTH};
8
9use std::collections::HashMap;
10use std::convert::TryInto;
11use std::fs::read_to_string;
12use std::path::{Path, PathBuf};
13
/// Maps a token string to its id in the vocabulary.
type TokenMap = HashMap<String, u32>;
/// The vocabulary: `(token, score)` pairs, indexed by token id.
type Vocab = Vec<(String, f64)>;
16
/// A `Unigram` model to encode sentences.
pub struct Unigram {
    // Maps each token string to its vocabulary id.
    token_to_ids: TokenMap,
    // `(token, score)` pairs, indexed by token id.
    pub(crate) vocab: Vocab,
    // Memoizes `encode` results, keyed by the input sentence.
    cache: Cache<String, Vec<String>>,
    // Byte-level trie over all tokens, used for common-prefix search.
    trie: Trie<u8>,
    // Lowest score present in `vocab`; unknown tokens are scored
    // `min_score - K_UNK_PENALTY`.
    pub min_score: f64,
    // Vocabulary id of the unknown token, if any.
    pub(super) unk_id: Option<usize>,
    // Lattice-only sentence-boundary ids (they live past the vocabulary).
    pub(super) bos_id: usize,
    pub(super) eos_id: usize,

    // When true, consecutive unknown pieces are merged into a single piece.
    fuse_unk: bool,
    // Selects `encode_optimized` (direct Viterbi) over the lattice-based path.
    is_optimized: bool,
    // When true, unknown pieces may be expanded into `<0xXX>` byte tokens.
    byte_fallback: bool,
}
32impl PartialEq for Unigram {
33    fn eq(&self, other: &Self) -> bool {
34        self.unk_id == other.unk_id && self.vocab == other.vocab
35    }
36}
37
38impl Clone for Unigram {
39    // `Clone` can't be derive because it's not implemented for `Cache`.
40    // To keep things simple when we clone, the new Unigram will start with a fresh cache.
41    fn clone(&self) -> Self {
42        let fresh_cache = self.cache.fresh();
43        Self {
44            vocab: self.vocab.clone(),
45            cache: fresh_cache,
46            token_to_ids: self.token_to_ids.clone(),
47            trie: self.trie.clone(),
48            min_score: self.min_score,
49            unk_id: self.unk_id,
50            bos_id: self.bos_id,
51            eos_id: self.eos_id,
52            fuse_unk: self.fuse_unk,
53            is_optimized: self.is_optimized,
54            byte_fallback: self.byte_fallback,
55        }
56    }
57}
58
59impl std::fmt::Debug for Unigram {
60    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
61        fmt.debug_struct("Unigram")
62            .field("vocab", &self.vocab.len())
63            .field("unk_id", &self.unk_id)
64            .field("byte_fallback", &self.byte_fallback)
65            .finish()
66    }
67}
68
/// Penalty subtracted from the lowest vocabulary score to derive the score of
/// unknown tokens (see `populate_nodes` / `encode_optimized`).
// Idiomatic Rust uses `const` for a plain immutable scalar; `static` is
// reserved for items that need a fixed memory address.
const K_UNK_PENALTY: f64 = 10.0;
70
/// Errors returned when constructing or using a [`Unigram`] model.
#[derive(thiserror::Error, Debug)]
pub enum UnigramError {
    /// `unk_id` was provided but the vocabulary contains no entries.
    #[error("The vocabulary is empty but at least <unk> is needed")]
    EmptyVocabulary,
    /// `unk_id` points outside the vocabulary (`unk_id >= vocab.len()`).
    #[error("The `unk_id` is larger than vocabulary size")]
    UnkIdNotInVocabulary,
    /// An unknown token had to be emitted while the model has no `unk_id`.
    #[error("Encountered an unknown token but `unk_id` is missing")]
    MissingUnkId,
}
80
81impl Default for Unigram {
82    fn default() -> Self {
83        let vocab = vec![("<unk>".to_string(), 0.0)];
84        Self::from(vocab, Some(0), false).unwrap()
85    }
86}
87
impl Unigram {
    /// Create a `Unigram` model from a given vocabulary.
    /// Vocabulary are the various tokens and their associated score which is a sort of a logprob of
    /// their frequency, which will enable tokenization and sampling.
    /// unk_id, is the index within the vocabulary.
    /// For now `Unigram` *requires* at least `unk` because we might find a never seen char.
    /// Further versions might allow that part to be hidden.
    ///
    /// # Errors
    /// [`UnigramError::EmptyVocabulary`] when `unk_id` is `Some` but the
    /// vocabulary is empty; [`UnigramError::UnkIdNotInVocabulary`] when
    /// `unk_id` is out of bounds.
    pub fn from(
        vocab: Vec<(String, f64)>,
        unk_id: Option<usize>,
        byte_fallback: bool,
    ) -> Result<Self> {
        let n = vocab.len();
        let mut token_to_ids: TokenMap = HashMap::new();
        let mut builder = TrieBuilder::default();

        // `unk_id` is only validated when present; a model without an unk
        // token is allowed but will error later on never-seen characters.
        if let Some(unk_id) = unk_id {
            if vocab.is_empty() {
                return Err(Box::new(UnigramError::EmptyVocabulary));
            }
            if unk_id >= vocab.len() {
                return Err(Box::new(UnigramError::UnkIdNotInVocabulary));
            }
        }
        // bos/eos are lattice-only ids placed just past the vocabulary.
        let bos_id = n + 1;
        let eos_id = n + 2;

        // Single pass: build the token -> id map, feed the trie, and track the
        // minimum score (used to derive the unknown-token score).
        let mut min_score = f64::INFINITY;
        for (id, (token, score)) in vocab.iter().enumerate() {
            token_to_ids.insert(token.to_string(), id as u32);
            let bytes: Vec<u8> = token.bytes().collect();
            builder.push(&bytes);
            if score < &min_score {
                min_score = *score;
            }
        }
        let trie = builder.build();
        let fuse_unk = true;
        let is_optimized = true;

        Ok(Self {
            vocab,
            token_to_ids,
            trie,
            min_score,
            bos_id,
            eos_id,
            unk_id,
            fuse_unk,
            cache: Cache::default(),
            is_optimized,
            byte_fallback,
        })
    }

    #[cfg(test)]
    pub(super) fn set_fuse_unk(&mut self, fuse_unk: bool) {
        self.fuse_unk = fuse_unk;
        // Cached encodings depend on `fuse_unk`, so they must be dropped.
        self.cache = self.cache.fresh();
    }

    // NOTE(review): unlike `set_fuse_unk`, this does not reset the cache;
    // both paths are presumably expected to produce identical output.
    #[cfg(test)]
    pub(super) fn set_optimized(&mut self, is_optimized: bool) {
        self.is_optimized = is_optimized;
    }
    /// Whether `<0xXX>` byte-token fallback is enabled for unknown pieces.
    pub fn byte_fallback(&self) -> bool {
        self.byte_fallback
    }
    /// Number of entries in the vocabulary.
    pub(super) fn len(&self) -> usize {
        self.vocab.len()
    }

    /// Insert every vocabulary token matching at each byte position of the
    /// lattice's sentence; positions not covered by any single-character
    /// token get a low-scoring unknown node so every position stays reachable.
    pub(super) fn populate_nodes(&self, lattice: &mut Lattice) {
        let unk_score = self.min_score - K_UNK_PENALTY;

        let len = lattice.len();

        let mut begin_pos = 0;
        while begin_pos < len {
            // Byte length of the UTF-8 character starting at `begin_pos`.
            let mblen = lattice.sentence[begin_pos..]
                .chars()
                .next()
                .unwrap()
                .len_utf8();

            // Set when some token covers exactly this one character.
            let mut has_single_node = false;

            // Every vocabulary token that is a prefix of the remaining input.
            for bytes in self
                .trie
                .common_prefix_search(lattice.sentence.bytes().skip(begin_pos))
            {
                let n = bytes.len();
                // Trie entries are whole vocabulary tokens, hence valid UTF-8.
                let tok = String::from_utf8(bytes).unwrap();
                let id = *self.token_to_ids.get(&tok).unwrap();

                let item = &self.vocab[id as usize];
                assert_eq!(item.0, tok);
                let score: f64 = item.1;
                lattice.insert(begin_pos, n, score, id.try_into().unwrap());
                if !has_single_node && n == mblen {
                    has_single_node = true;
                }
            }

            if !has_single_node {
                if let Some(unk_id) = self.unk_id {
                    lattice.insert(begin_pos, mblen, unk_score, unk_id);
                }
            }
            begin_pos += mblen
        }
    }

    /// This functions take a String, and will encode it in a Vec of Strings,
    /// of the best tokenization available to the current model.
    /// ```
    /// use tokenizers::models::unigram::Unigram;
    ///
    /// let pieces = vec![
    ///     ("<unk>".to_string(), 0.0),
    ///     ("a".to_string(), 0.0),
    ///     ("b".to_string(), 0.0),
    ///     ("c".to_string(), 0.0),
    ///     ("d".to_string(), 0.0),
    ///     ("cd".to_string(), 1.0),
    ///     ("ab".to_string(), 2.0),
    ///     ("abc".to_string(), 5.0),
    ///     ("abcd".to_string(), 10.0),
    /// ];
    /// let model = Unigram::from(pieces, Some(0), false).unwrap();
    /// let result = model.encode("abcdacdxx").unwrap();
    /// assert_eq!(result, vec!["abcd", "a", "cd", "xx"]);
    /// ```
    pub fn encode(&self, sentence: &str) -> Result<Vec<String>> {
        if sentence.is_empty() {
            return Ok(vec![]);
        }
        if let Some(result) = self.cache.get(sentence) {
            Ok(result.to_vec())
        } else {
            let result = if self.is_optimized {
                self.encode_optimized(sentence)?
            } else {
                self.encode_unoptimized(sentence)?
            };
            // Only cache sentences below MAX_LENGTH to bound memory use.
            if sentence.len() < MAX_LENGTH {
                self.cache.set(sentence.to_owned(), result.clone());
            }
            Ok(result)
        }
    }

    /// Viterbi search over byte positions without building an explicit
    /// lattice; mirrors the SentencePiece implementation linked below.
    fn encode_optimized(&self, sentence: &str) -> Result<Vec<String>> {
        // https://github.com/google/sentencepiece/blob/d48247191a6d50e469ed1a4a36e877befffd1851/src/unigram_model.cc#L600
        #[derive(Debug, Clone)]
        struct BestPathNode {
            /// The vocab id. (maybe UNK)
            id: usize,
            /// The total score of the best path ending at this node.
            best_path_score: f64,
            /// The starting position (in utf-8) of this node. The entire best
            /// path can be constructed by backtracking along this link.
            starts_at: Option<usize>,
        }
        impl Default for BestPathNode {
            fn default() -> Self {
                Self {
                    id: 0,
                    best_path_score: 0.0,
                    starts_at: None,
                }
            }
        }
        let size = sentence.len();
        let unk_score = self.min_score - K_UNK_PENALTY;

        // best_path_ends_at[pos] holds the best path covering sentence[..pos];
        // `starts_at == None` means the position has not been reached yet.
        let mut best_path_ends_at = vec![BestPathNode::default(); size + 1];
        // Forward pass: relax every token match starting at each character
        // boundary.
        let mut starts_at = 0;
        while starts_at < size {
            let best_path_score_till_here = best_path_ends_at[starts_at].best_path_score;
            let mut has_single_node = false;
            let mblen = sentence[starts_at..].chars().next().unwrap().len_utf8();
            for tok_bytes in self
                .trie
                .common_prefix_search(sentence.bytes().skip(starts_at))
            {
                let key_pos = starts_at + tok_bytes.len();
                let token: String = String::from_utf8(tok_bytes).unwrap();
                let target_node = &mut best_path_ends_at[key_pos];
                let length = key_pos - starts_at;
                let id = self.token_to_ids.get(&token).unwrap();
                let score = self.vocab.get(*id as usize).unwrap().1;
                let candidate_best_path_score = score + best_path_score_till_here;
                // Keep the candidate when the position is unreached or the
                // candidate improves on the recorded best score.
                if target_node.starts_at.is_none()
                    || candidate_best_path_score > target_node.best_path_score
                {
                    target_node.best_path_score = candidate_best_path_score;
                    target_node.starts_at = Some(starts_at);
                    target_node.id = *id as usize;
                }
                if !has_single_node && length == mblen {
                    has_single_node = true;
                }
            }
            // As in `populate_nodes`: when no token covers exactly one
            // character here, bridge the gap with an unknown node.
            if !has_single_node {
                let target_node = &mut best_path_ends_at[starts_at + mblen];
                let candidate_best_path_score = unk_score + best_path_score_till_here;
                if target_node.starts_at.is_none()
                    || candidate_best_path_score > target_node.best_path_score
                {
                    target_node.best_path_score = candidate_best_path_score;
                    target_node.starts_at = Some(starts_at);
                    target_node.id = self.unk_id.ok_or(UnigramError::MissingUnkId)?;
                }
            }
            starts_at += mblen
        }
        // Backward pass: follow the `starts_at` links from the end of the
        // sentence, collecting pieces in reverse order. Runs of consecutive
        // unknown pieces are fused into one piece when `fuse_unk` is set.
        let mut ends_at = size;
        let mut results: Vec<String> = vec![];
        let mut token = vec![];
        while ends_at > 0 {
            let node = &best_path_ends_at[ends_at];
            let starts_at = node.starts_at.unwrap();
            if self.fuse_unk
                && self.unk_id.is_some()
                && node.id == self.unk_id.ok_or(UnigramError::MissingUnkId)?
            {
                // Accumulate unk pieces; they are concatenated (in reverse)
                // once a known piece or the sentence start is reached.
                token.push(
                    String::from_utf8(sentence[starts_at..ends_at].as_bytes().to_vec()).unwrap(),
                );
            } else {
                if !token.is_empty() {
                    token.reverse();
                    results.push(token.concat());
                    token = vec![];
                }
                results.push(
                    String::from_utf8(sentence[starts_at..ends_at].as_bytes().to_vec()).unwrap(),
                );
            }
            ends_at = starts_at;
        }
        // Flush a pending unk run that reached the start of the sentence.
        if !token.is_empty() {
            token.reverse();
            results.push(token.concat());
        }
        results.reverse();
        Ok(results)
    }

    /// Lattice-based encoding: populate all candidate nodes, then take the
    /// Viterbi path. Slower than `encode_optimized` but easier to follow.
    fn encode_unoptimized(&self, sentence: &str) -> Result<Vec<String>> {
        let mut lattice = Lattice::from(sentence, self.bos_id, self.eos_id);
        self.populate_nodes(&mut lattice);
        if self.fuse_unk {
            // Merge runs of consecutive unknown pieces into a single piece.
            let mut results = vec![];
            let mut token = String::new();
            for node in lattice.viterbi().iter() {
                let item = lattice.piece(&node.borrow());
                // NOTE(review): with `fuse_unk` set and `unk_id == None` this
                // errors on the first node, even if no unk piece is present —
                // confirm this is the intended contract.
                if node.borrow().id == self.unk_id.ok_or(UnigramError::MissingUnkId)? {
                    token.push_str(&item);
                } else {
                    if !token.is_empty() {
                        results.push(token);
                        token = String::new();
                    }
                    results.push(item.to_string());
                }
            }
            if !token.is_empty() {
                results.push(token);
            }
            Ok(results)
        } else {
            Ok(lattice.tokens())
        }
    }

    /// Iterate of vocabulary of the model as a pair of `(token, score)`.
    pub fn iter(&self) -> UnigramIterator {
        UnigramIterator { model: self, i: 0 }
    }

    /// Loads a SentencePiece output model after being trained by tokenizers.
    /// After that you can use the model with tokenizers library.
    /// ```no_run
    /// use tokenizers::models::unigram::Unigram;
    /// use std::path::Path;
    ///
    /// let model = Unigram::load("mymodel-unigram.json").unwrap();
    /// ```
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Unigram> {
        let string = read_to_string(path)?;
        Ok(serde_json::from_str(&string)?)
    }

    /// Clears the internal cache
    pub fn clear_cache(&mut self) {
        self.cache.clear();
    }

    /// Resize the cache
    pub fn resize_cache(&mut self, capacity: usize) {
        self.cache.resize(capacity);
    }
}
393
/// Iterator to iterate of vocabulary of the model, and their relative score.
pub struct UnigramIterator<'a> {
    // Model whose vocabulary is being traversed.
    model: &'a Unigram,
    // Index of the next vocabulary entry to yield.
    i: usize,
}
399
400impl<'a> Iterator for UnigramIterator<'a> {
401    type Item = &'a (String, f64);
402
403    fn next(&mut self) -> Option<Self::Item> {
404        let i = self.i;
405        if i < self.model.len() {
406            let r = Some(&self.model.vocab[i]);
407            self.i += 1;
408            r
409        } else {
410            None
411        }
412    }
413}
414
impl Model for Unigram {
    type Trainer = UnigramTrainer;

    /// Returns a copy of the token -> id mapping.
    fn get_vocab(&self) -> HashMap<String, u32> {
        self.token_to_ids.clone()
    }

    fn get_vocab_size(&self) -> usize {
        self.vocab.len()
    }

    /// Tokenize `sentence` into `Token`s carrying byte offsets.
    ///
    /// A piece missing from the vocabulary is resolved through byte fallback
    /// (`<0xXX>` tokens, when enabled and *every* byte has an entry), else
    /// through `unk_id`; with neither, `UnigramError::MissingUnkId` is
    /// returned.
    fn tokenize(&self, sentence: &str) -> Result<Vec<Token>> {
        let str_tokens = self.encode(sentence)?;
        let mut offset = 0;
        let mut tokens = Vec::with_capacity(str_tokens.len());
        for string in str_tokens {
            let len = string.len();
            let offsets = (offset, offset + len);
            let id: u32 = match self.token_to_ids.get(&string) {
                Some(id) => *id,
                None => {
                    if self.byte_fallback {
                        // Map each byte of the piece to a `<0xXX>` token; all
                        // resulting byte tokens share the whole piece's offsets.
                        let byte_tokens: Option<Vec<_>> = string
                            .bytes()
                            .map(|byte| -> Option<Token> {
                                let byte_string = format!("<0x{byte:02X}>");
                                let id = self.token_to_ids.get(&byte_string);
                                id.map(|id| Token::new(*id, byte_string, (offset, offset + len)))
                            })
                            .collect();
                        // Only use the fallback when all bytes were resolved;
                        // otherwise fall through to the unk token below.
                        if let Some(byte_tokens) = byte_tokens {
                            for token in byte_tokens {
                                tokens.push(token);
                            }
                            offset += len;
                            continue;
                        }
                    }
                    self.unk_id.ok_or(UnigramError::MissingUnkId)? as u32
                }
            };
            offset += len;
            tokens.push(Token::new(id, string, offsets));
        }
        Ok(tokens)
    }

    fn token_to_id(&self, token: &str) -> Option<u32> {
        self.token_to_ids.get(token).copied()
    }

    fn id_to_token(&self, id: u32) -> Option<String> {
        self.vocab.get(id as usize).map(|item| item.0.clone())
    }

    /// Serializes the model as pretty JSON to `folder`, named
    /// `{name}-unigram.json` (or `unigram.json` without a name), and returns
    /// the written path.
    fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
        let name = match name {
            Some(name) => format!("{name}-unigram.json"),
            None => "unigram.json".to_string(),
        };
        let mut fullpath = PathBuf::new();
        fullpath.push(folder);
        fullpath.push(name);
        let string = serde_json::to_string_pretty(self)?;
        std::fs::write(&fullpath, string)?;
        Ok(vec![fullpath])
    }

    fn get_trainer(&self) -> Self::Trainer {
        UnigramTrainer::default()
    }
}
487
#[cfg(test)]
mod tests {
    use super::*;

    // With only `<unk>` in the vocab, every character becomes an unk node.
    #[test]
    fn test_populate_nodes_unk() {
        let pieces = vec![("<unk>".to_string(), 0.0)];
        let model = Unigram::from(pieces, Some(0), false).unwrap();

        let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id);
        model.populate_nodes(&mut lattice);

        assert_eq!(lattice.begin_nodes[0].len(), 1);
        assert_eq!(lattice.begin_nodes[1].len(), 1);
        assert_eq!(lattice.begin_nodes[2].len(), 1);
        assert_eq!(lattice.begin_nodes[0][0].borrow().id, 0);
        assert_eq!(lattice.begin_nodes[1][0].borrow().id, 0);
        assert_eq!(lattice.begin_nodes[2][0].borrow().id, 0);
        // node_ids 0 and 1 are taken by bos/eos, so content nodes start at 2.
        assert_eq!(lattice.begin_nodes[0][0].borrow().node_id, 2);
        assert_eq!(lattice.begin_nodes[1][0].borrow().node_id, 3);
        assert_eq!(lattice.begin_nodes[2][0].borrow().node_id, 4);
    }

    // Mixed vocab: check that every matching token is inserted at each
    // position, plus an unk node where nothing single-char matches.
    #[test]
    fn test_populate_nodes() {
        let pieces = vec![
            ("<unk>".to_string(), 0.0),
            ("a".to_string(), 0.1),
            ("b".to_string(), 0.2),
            ("ab".to_string(), 0.3),
            ("bc".to_string(), 0.4),
        ];
        let model = Unigram::from(pieces, Some(0), false).unwrap();

        let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id);
        model.populate_nodes(&mut lattice);

        assert_eq!(lattice.begin_nodes[0].len(), 2); // a, ab
        assert_eq!(lattice.begin_nodes[1].len(), 2); // b, bc
        assert_eq!(lattice.begin_nodes[2].len(), 1); // c(unk)

        // Id is the vocabulary id from Unigram model
        // node_id is simply the rank of the given node in the lattice.
        assert_eq!(lattice.begin_nodes[0][0].borrow().id, 1);
        assert_eq!(lattice.begin_nodes[0][1].borrow().id, 3);
        assert_eq!(lattice.begin_nodes[1][0].borrow().id, 2);
        assert_eq!(lattice.begin_nodes[1][1].borrow().id, 4);
        assert_eq!(lattice.begin_nodes[2][0].borrow().id, 0);
        assert_eq!(lattice.begin_nodes[0][0].borrow().node_id, 2);
        assert_eq!(lattice.begin_nodes[0][1].borrow().node_id, 3);
        assert_eq!(lattice.begin_nodes[1][0].borrow().node_id, 4);
        assert_eq!(lattice.begin_nodes[1][1].borrow().node_id, 5);
        assert_eq!(lattice.begin_nodes[2][0].borrow().node_id, 6);
    }

    // The highest-scoring single token should win over any segmentation.
    #[test]
    fn test_encode() {
        let sentencepieces = vec![
            ("<unk>".to_string(), 0.0),
            ("a".to_string(), 0.0),
            ("b".to_string(), 0.0),
            ("c".to_string(), 0.0),
            ("d".to_string(), 0.0),
            ("cd".to_string(), 1.0),
            ("ab".to_string(), 2.0),
            ("abc".to_string(), 5.0),
            ("abcd".to_string(), 10.0),
        ];

        let model = Unigram::from(sentencepieces, Some(0), false).unwrap();
        let result = model.encode("abcd").unwrap();
        assert_eq!(result, vec!["abcd"]);
    }

    // Exercises both encoder paths (optimized and lattice-based) and the
    // fuse_unk toggle; the two paths must agree on every case.
    #[test]
    fn test_encode2() {
        let sentencepieces = vec![
            ("<unk>".to_string(), 0.0),
            ("ab".to_string(), 0.0),
            ("cd".to_string(), -0.1),
            ("abc".to_string(), -0.2),
            ("a".to_string(), -0.3),
            ("b".to_string(), -0.4),
            ("c".to_string(), -0.5),
            ("ABC".to_string(), -0.5),
            ("abcdabcd".to_string(), 20.0), // User defined just max the scores.
            ("q".to_string(), 20.5),
            ("r".to_string(), 20.5),
            ("qr".to_string(), -0.5),
        ];

        let mut model = Unigram::from(sentencepieces, Some(0), false).unwrap();

        for is_optimized in &[true, false] {
            model.set_optimized(*is_optimized);
            println!("IsOptimized {is_optimized:?}");
            assert_eq!(model.encode("abc").unwrap(), vec!["abc"]);
            assert_eq!(model.encode("AB").unwrap(), vec!["AB"]);

            model.set_fuse_unk(false);
            assert_eq!(model.encode("AB").unwrap(), vec!["A", "B"]);
            model.set_fuse_unk(true);
            assert_eq!(model.encode("AB").unwrap(), vec!["AB"]);

            assert_eq!(model.encode("abcd").unwrap(), vec!["ab", "cd"]);
            assert_eq!(model.encode("abcc").unwrap(), vec!["abc", "c"]);
            assert_eq!(
                model.encode("xabcabaabcdd").unwrap(),
                vec!["x", "abc", "ab", "a", "ab", "cd", "d"]
            );
            model.set_fuse_unk(false);
            assert_eq!(
                model.encode("xyz東京").unwrap(),
                vec!["x", "y", "z", "東", "京"]
            );
            model.set_fuse_unk(true);
            assert_eq!(model.encode("xyz東京").unwrap(), vec!["xyz東京"]);

            // User encoded in original version
            assert_eq!(model.encode("ABC").unwrap(), vec!["ABC"]);
            assert_eq!(model.encode("abABCcd").unwrap(), vec!["ab", "ABC", "cd"]);
            assert_eq!(
                model.encode("ababcdabcdcd").unwrap(),
                vec!["ab", "abcdabcd", "cd"]
            );
            assert_eq!(model.encode("abqrcd").unwrap(), vec!["ab", "q", "r", "cd"]);
        }
    }

    // Unknown multi-byte chars decompose into `<0xXX>` tokens when byte
    // fallback is on; bytes with no entry fall back to the unk id instead.
    #[test]
    fn test_unigram_bytefallback() {
        // In [97]: processor.encode_as_pieces("⅐⅛⅑ ")
        // Out[97]: ['▁', '<0xE2>', '<0x85>', '<0x90>', '⅛', '<0xE2>', '<0x85>', '<0x91>', '▁']
        let sentencepieces = vec![
            ("<unk>".to_string(), 0.0),
            ("<0xC3>".to_string(), -0.01),
            ("<0xA9>".to_string(), -0.03),
        ];
        let unigram = Unigram::from(sentencepieces, Some(0), true).unwrap();
        let tokens: Vec<Token> = unigram.tokenize("é").unwrap();
        assert_eq!(
            tokens,
            [
                Token {
                    id: 1,
                    value: "<0xC3>".to_string(),
                    offsets: (0, 2)
                },
                Token {
                    id: 2,
                    value: "<0xA9>".to_string(),
                    offsets: (0, 2)
                }
            ]
        );

        let tokens = unigram.tokenize("?é").unwrap();
        assert_eq!(tokens[0].id, 0);
    }
}
647}