1use super::{
2 lattice::Lattice,
3 trainer::UnigramTrainer,
4 trie::{Trie, TrieBuilder},
5};
6use crate::tokenizer::{Model, Result, Token};
7use crate::utils::cache::{Cache, MAX_LENGTH};
8
9use std::collections::HashMap;
10use std::convert::TryInto;
11use std::fs::read_to_string;
12use std::path::{Path, PathBuf};
13
14type TokenMap = HashMap<String, u32>;
15type Vocab = Vec<(String, f64)>;
16
/// A unigram language-model tokenizer: a scored vocabulary plus the
/// data structures needed to segment text into maximum-likelihood pieces.
pub struct Unigram {
    // Token string -> vocabulary id (inverse of `vocab` ordering).
    token_to_ids: TokenMap,
    // Ordered (token, score) entries; a token's index is its id.
    pub(crate) vocab: Vocab,
    // Memoizes `encode` results per input sentence.
    cache: Cache<String, Vec<String>>,
    // Byte-level trie over token strings, used for common-prefix search.
    trie: Trie<u8>,
    // Lowest score present in `vocab`; unknown-token scores derive from it.
    pub min_score: f64,
    // Id of the unknown token, when the vocabulary defines one.
    pub(super) unk_id: Option<usize>,
    // Sentence-boundary ids handed to the lattice (placed past the vocab range).
    pub(super) bos_id: usize,
    pub(super) eos_id: usize,

    // When true, consecutive unknown pieces are merged into one token.
    fuse_unk: bool,
    // Selects the optimized Viterbi encoder over the lattice-based one.
    is_optimized: bool,
    // When true, unknown pieces may be decomposed into <0xNN> byte tokens.
    byte_fallback: bool,
}
impl PartialEq for Unigram {
    // Equality considers only the vocabulary and `unk_id`; the cache and the
    // behaviour flags (fuse_unk, is_optimized, byte_fallback) are ignored.
    fn eq(&self, other: &Self) -> bool {
        self.unk_id == other.unk_id && self.vocab == other.vocab
    }
}
37
38impl Clone for Unigram {
39 fn clone(&self) -> Self {
42 let fresh_cache = self.cache.fresh();
43 Self {
44 vocab: self.vocab.clone(),
45 cache: fresh_cache,
46 token_to_ids: self.token_to_ids.clone(),
47 trie: self.trie.clone(),
48 min_score: self.min_score,
49 unk_id: self.unk_id,
50 bos_id: self.bos_id,
51 eos_id: self.eos_id,
52 fuse_unk: self.fuse_unk,
53 is_optimized: self.is_optimized,
54 byte_fallback: self.byte_fallback,
55 }
56 }
57}
58
59impl std::fmt::Debug for Unigram {
60 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
61 fmt.debug_struct("Unigram")
62 .field("vocab", &self.vocab.len())
63 .field("unk_id", &self.unk_id)
64 .field("byte_fallback", &self.byte_fallback)
65 .finish()
66 }
67}
68
/// Penalty subtracted from `min_score` to produce the score of
/// unknown-token nodes, keeping them below every vocabulary entry.
/// A `const` (not `static`): it is a private compile-time scalar.
const K_UNK_PENALTY: f64 = 10.0;
70
/// Errors that can arise while constructing or using a `Unigram` model.
#[derive(thiserror::Error, Debug)]
pub enum UnigramError {
    /// Construction requires at least one entry when an `unk_id` is given.
    #[error("The vocabulary is empty but at least <unk> is needed")]
    EmptyVocabulary,
    /// The provided `unk_id` does not index into the vocabulary.
    #[error("The `unk_id` is larger than vocabulary size")]
    UnkIdNotInVocabulary,
    /// An unknown token was needed but the model defines no `unk_id`.
    #[error("Encountered an unknown token but `unk_id` is missing")]
    MissingUnkId,
}
80
81impl Default for Unigram {
82 fn default() -> Self {
83 let vocab = vec![("<unk>".to_string(), 0.0)];
84 Self::from(vocab, Some(0), false).unwrap()
85 }
86}
87
88impl Unigram {
89 pub fn from(
96 vocab: Vec<(String, f64)>,
97 unk_id: Option<usize>,
98 byte_fallback: bool,
99 ) -> Result<Self> {
100 let n = vocab.len();
101 let mut token_to_ids: TokenMap = HashMap::new();
102 let mut builder = TrieBuilder::default();
103
104 if let Some(unk_id) = unk_id {
105 if vocab.is_empty() {
106 return Err(Box::new(UnigramError::EmptyVocabulary));
107 }
108 if unk_id >= vocab.len() {
109 return Err(Box::new(UnigramError::UnkIdNotInVocabulary));
110 }
111 }
112 let bos_id = n + 1;
113 let eos_id = n + 2;
114
115 let mut min_score = f64::INFINITY;
116 for (id, (token, score)) in vocab.iter().enumerate() {
117 token_to_ids.insert(token.to_string(), id as u32);
118 let bytes: Vec<u8> = token.bytes().collect();
119 builder.push(&bytes);
120 if score < &min_score {
121 min_score = *score;
122 }
123 }
124 let trie = builder.build();
125 let fuse_unk = true;
126 let is_optimized = true;
127
128 Ok(Self {
129 vocab,
130 token_to_ids,
131 trie,
132 min_score,
133 bos_id,
134 eos_id,
135 unk_id,
136 fuse_unk,
137 cache: Cache::default(),
138 is_optimized,
139 byte_fallback,
140 })
141 }
142
    /// Test-only: toggles merging of consecutive unknown tokens. The cache is
    /// reset because memoized encodings depend on this flag.
    #[cfg(test)]
    pub(super) fn set_fuse_unk(&mut self, fuse_unk: bool) {
        self.fuse_unk = fuse_unk;
        self.cache = self.cache.fresh();
    }

    /// Test-only: selects the optimized or the lattice-based encoder.
    // NOTE(review): unlike `set_fuse_unk` this does not reset the cache —
    // presumably both encoders produce identical output; confirm.
    #[cfg(test)]
    pub(super) fn set_optimized(&mut self, is_optimized: bool) {
        self.is_optimized = is_optimized;
    }
    /// Whether byte-fallback decomposition is enabled.
    pub fn byte_fallback(&self) -> bool {
        self.byte_fallback
    }
    /// Number of entries in the vocabulary.
    pub(super) fn len(&self) -> usize {
        self.vocab.len()
    }
159
    /// Inserts every candidate token into `lattice`.
    ///
    /// For each byte position, all vocabulary tokens that are a prefix of the
    /// remaining sentence are added as lattice nodes. If no token exactly
    /// covers the single character at a position, an unknown node spanning
    /// that character is inserted (when `unk_id` exists) with a penalized
    /// score, so the lattice always stays connected.
    pub(super) fn populate_nodes(&self, lattice: &mut Lattice) {
        // Unknown tokens score strictly below the worst vocabulary entry.
        let unk_score = self.min_score - K_UNK_PENALTY;

        let len = lattice.len();

        let mut begin_pos = 0;
        while begin_pos < len {
            // Byte length of the UTF-8 character starting at this position.
            let mblen = lattice.sentence[begin_pos..]
                .chars()
                .next()
                .unwrap()
                .len_utf8();

            // Becomes true once some token exactly covers this one character.
            let mut has_single_node = false;

            for bytes in self
                .trie
                .common_prefix_search(lattice.sentence.bytes().skip(begin_pos))
            {
                let n = bytes.len();
                // Trie entries were built from valid UTF-8 tokens, so this
                // conversion cannot fail.
                let tok = String::from_utf8(bytes).unwrap();
                let id = *self.token_to_ids.get(&tok).unwrap();

                let item = &self.vocab[id as usize];
                // Sanity check: `token_to_ids` and `vocab` must agree.
                assert_eq!(item.0, tok);
                let score: f64 = item.1;
                lattice.insert(begin_pos, n, score, id.try_into().unwrap());
                if !has_single_node && n == mblen {
                    has_single_node = true;
                }
            }

            // Connectivity fallback: cover the character with an <unk> node.
            if !has_single_node {
                if let Some(unk_id) = self.unk_id {
                    lattice.insert(begin_pos, mblen, unk_score, unk_id);
                }
            }
            begin_pos += mblen
        }
    }
200
201 pub fn encode(&self, sentence: &str) -> Result<Vec<String>> {
222 if sentence.is_empty() {
223 return Ok(vec![]);
224 }
225 if let Some(result) = self.cache.get(sentence) {
226 Ok(result.to_vec())
227 } else {
228 let result = if self.is_optimized {
229 self.encode_optimized(sentence)?
230 } else {
231 self.encode_unoptimized(sentence)?
232 };
233 if sentence.len() < MAX_LENGTH {
234 self.cache.set(sentence.to_owned(), result.clone());
235 }
236 Ok(result)
237 }
238 }
239
240 fn encode_optimized(&self, sentence: &str) -> Result<Vec<String>> {
241 #[derive(Debug, Clone)]
243 struct BestPathNode {
244 id: usize,
246 best_path_score: f64,
248 starts_at: Option<usize>,
251 }
252 impl Default for BestPathNode {
253 fn default() -> Self {
254 Self {
255 id: 0,
256 best_path_score: 0.0,
257 starts_at: None,
258 }
259 }
260 }
261 let size = sentence.len();
262 let unk_score = self.min_score - K_UNK_PENALTY;
263
264 let mut best_path_ends_at = vec![BestPathNode::default(); size + 1];
265 let mut starts_at = 0;
266 while starts_at < size {
267 let best_path_score_till_here = best_path_ends_at[starts_at].best_path_score;
268 let mut has_single_node = false;
269 let mblen = sentence[starts_at..].chars().next().unwrap().len_utf8();
270 for tok_bytes in self
271 .trie
272 .common_prefix_search(sentence.bytes().skip(starts_at))
273 {
274 let key_pos = starts_at + tok_bytes.len();
275 let token: String = String::from_utf8(tok_bytes).unwrap();
276 let target_node = &mut best_path_ends_at[key_pos];
277 let length = key_pos - starts_at;
278 let id = self.token_to_ids.get(&token).unwrap();
279 let score = self.vocab.get(*id as usize).unwrap().1;
280 let candidate_best_path_score = score + best_path_score_till_here;
281 if target_node.starts_at.is_none()
282 || candidate_best_path_score > target_node.best_path_score
283 {
284 target_node.best_path_score = candidate_best_path_score;
285 target_node.starts_at = Some(starts_at);
286 target_node.id = *id as usize;
287 }
288 if !has_single_node && length == mblen {
289 has_single_node = true;
290 }
291 }
292 if !has_single_node {
293 let target_node = &mut best_path_ends_at[starts_at + mblen];
294 let candidate_best_path_score = unk_score + best_path_score_till_here;
295 if target_node.starts_at.is_none()
296 || candidate_best_path_score > target_node.best_path_score
297 {
298 target_node.best_path_score = candidate_best_path_score;
299 target_node.starts_at = Some(starts_at);
300 target_node.id = self.unk_id.ok_or(UnigramError::MissingUnkId)?;
301 }
302 }
303 starts_at += mblen
304 }
305 let mut ends_at = size;
306 let mut results: Vec<String> = vec![];
307 let mut token = vec![];
308 while ends_at > 0 {
309 let node = &best_path_ends_at[ends_at];
310 let starts_at = node.starts_at.unwrap();
311 if self.fuse_unk
312 && self.unk_id.is_some()
313 && node.id == self.unk_id.ok_or(UnigramError::MissingUnkId)?
314 {
315 token.push(
316 String::from_utf8(sentence[starts_at..ends_at].as_bytes().to_vec()).unwrap(),
317 );
318 } else {
319 if !token.is_empty() {
320 token.reverse();
321 results.push(token.concat());
322 token = vec![];
323 }
324 results.push(
325 String::from_utf8(sentence[starts_at..ends_at].as_bytes().to_vec()).unwrap(),
326 );
327 }
328 ends_at = starts_at;
329 }
330 if !token.is_empty() {
331 token.reverse();
332 results.push(token.concat());
333 }
334 results.reverse();
335 Ok(results)
336 }
337
338 fn encode_unoptimized(&self, sentence: &str) -> Result<Vec<String>> {
339 let mut lattice = Lattice::from(sentence, self.bos_id, self.eos_id);
340 self.populate_nodes(&mut lattice);
341 if self.fuse_unk {
342 let mut results = vec![];
343 let mut token = String::new();
344 for node in lattice.viterbi().iter() {
345 let item = lattice.piece(&node.borrow());
346 if node.borrow().id == self.unk_id.ok_or(UnigramError::MissingUnkId)? {
347 token.push_str(&item);
348 } else {
349 if !token.is_empty() {
350 results.push(token);
351 token = String::new();
352 }
353 results.push(item.to_string());
354 }
355 }
356 if !token.is_empty() {
357 results.push(token);
358 }
359 Ok(results)
360 } else {
361 Ok(lattice.tokens())
362 }
363 }
364
    /// Iterates over the `(token, score)` vocabulary entries in id order.
    pub fn iter(&self) -> UnigramIterator {
        UnigramIterator { model: self, i: 0 }
    }
369
370 pub fn load<P: AsRef<Path>>(path: P) -> Result<Unigram> {
379 let string = read_to_string(path)?;
380 Ok(serde_json::from_str(&string)?)
381 }
382
    /// Drops all memoized encodings.
    pub fn clear_cache(&mut self) {
        self.cache.clear();
    }

    /// Resizes the encoding cache to hold `capacity` entries.
    pub fn resize_cache(&mut self, capacity: usize) {
        self.cache.resize(capacity);
    }
392}
393
/// Borrowing iterator over a `Unigram` vocabulary, yielding
/// `&(token, score)` pairs in id order.
pub struct UnigramIterator<'a> {
    // The model whose vocabulary is being iterated.
    model: &'a Unigram,
    // Next vocabulary index to yield.
    i: usize,
}
399
400impl<'a> Iterator for UnigramIterator<'a> {
401 type Item = &'a (String, f64);
402
403 fn next(&mut self) -> Option<Self::Item> {
404 let i = self.i;
405 if i < self.model.len() {
406 let r = Some(&self.model.vocab[i]);
407 self.i += 1;
408 r
409 } else {
410 None
411 }
412 }
413}
414
415impl Model for Unigram {
416 type Trainer = UnigramTrainer;
417
    /// Returns a copy of the token -> id map.
    fn get_vocab(&self) -> HashMap<String, u32> {
        self.token_to_ids.clone()
    }

    /// Number of tokens in the vocabulary.
    fn get_vocab_size(&self) -> usize {
        self.vocab.len()
    }
425
426 fn tokenize(&self, sentence: &str) -> Result<Vec<Token>> {
427 let str_tokens = self.encode(sentence)?;
428 let mut offset = 0;
429 let mut tokens = Vec::with_capacity(str_tokens.len());
430 for string in str_tokens {
431 let len = string.len();
432 let offsets = (offset, offset + len);
433 let id: u32 = match self.token_to_ids.get(&string) {
434 Some(id) => *id,
435 None => {
436 if self.byte_fallback {
437 let byte_tokens: Option<Vec<_>> = string
438 .bytes()
439 .map(|byte| -> Option<Token> {
440 let byte_string = format!("<0x{byte:02X}>");
441 let id = self.token_to_ids.get(&byte_string);
442 id.map(|id| Token::new(*id, byte_string, (offset, offset + len)))
443 })
444 .collect();
445 if let Some(byte_tokens) = byte_tokens {
446 for token in byte_tokens {
447 tokens.push(token);
448 }
449 offset += len;
450 continue;
451 }
452 }
453 self.unk_id.ok_or(UnigramError::MissingUnkId)? as u32
454 }
455 };
456 offset += len;
457 tokens.push(Token::new(id, string, offsets));
458 }
459 Ok(tokens)
460 }
461
    /// Looks up the id of `token`, if present in the vocabulary.
    fn token_to_id(&self, token: &str) -> Option<u32> {
        self.token_to_ids.get(token).copied()
    }

    /// Looks up the token string for `id`, if within the vocabulary range.
    fn id_to_token(&self, id: u32) -> Option<String> {
        self.vocab.get(id as usize).map(|item| item.0.clone())
    }
469
470 fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
471 let name = match name {
472 Some(name) => format!("{name}-unigram.json"),
473 None => "unigram.json".to_string(),
474 };
475 let mut fullpath = PathBuf::new();
476 fullpath.push(folder);
477 fullpath.push(name);
478 let string = serde_json::to_string_pretty(self)?;
479 std::fs::write(&fullpath, string)?;
480 Ok(vec![fullpath])
481 }
482
    /// Returns a default trainer for this model type.
    fn get_trainer(&self) -> Self::Trainer {
        UnigramTrainer::default()
    }
486}
487
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_populate_nodes_unk() {
        // With only <unk> in the vocab, every character becomes one unknown node.
        let pieces = vec![("<unk>".to_string(), 0.0)];
        let model = Unigram::from(pieces, Some(0), false).unwrap();

        let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id);
        model.populate_nodes(&mut lattice);

        assert_eq!(lattice.begin_nodes[0].len(), 1);
        assert_eq!(lattice.begin_nodes[1].len(), 1);
        assert_eq!(lattice.begin_nodes[2].len(), 1);
        assert_eq!(lattice.begin_nodes[0][0].borrow().id, 0);
        assert_eq!(lattice.begin_nodes[1][0].borrow().id, 0);
        assert_eq!(lattice.begin_nodes[2][0].borrow().id, 0);
        // NOTE(review): node_id numbering appears to start at 2 — presumably
        // 0 and 1 are reserved for boundary nodes; confirm in `Lattice`.
        assert_eq!(lattice.begin_nodes[0][0].borrow().node_id, 2);
        assert_eq!(lattice.begin_nodes[1][0].borrow().node_id, 3);
        assert_eq!(lattice.begin_nodes[2][0].borrow().node_id, 4);
    }

    #[test]
    fn test_populate_nodes() {
        let pieces = vec![
            ("<unk>".to_string(), 0.0),
            ("a".to_string(), 0.1),
            ("b".to_string(), 0.2),
            ("ab".to_string(), 0.3),
            ("bc".to_string(), 0.4),
        ];
        let model = Unigram::from(pieces, Some(0), false).unwrap();

        let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id);
        model.populate_nodes(&mut lattice);

        // Position 0 starts "a" and "ab"; position 1 starts "b" and "bc";
        // position 2 has no match, so it gets the single <unk> node.
        assert_eq!(lattice.begin_nodes[0].len(), 2);
        assert_eq!(lattice.begin_nodes[1].len(), 2);
        assert_eq!(lattice.begin_nodes[2].len(), 1);
        assert_eq!(lattice.begin_nodes[0][0].borrow().id, 1);
        assert_eq!(lattice.begin_nodes[0][1].borrow().id, 3);
        assert_eq!(lattice.begin_nodes[1][0].borrow().id, 2);
        assert_eq!(lattice.begin_nodes[1][1].borrow().id, 4);
        assert_eq!(lattice.begin_nodes[2][0].borrow().id, 0);
        assert_eq!(lattice.begin_nodes[0][0].borrow().node_id, 2);
        assert_eq!(lattice.begin_nodes[0][1].borrow().node_id, 3);
        assert_eq!(lattice.begin_nodes[1][0].borrow().node_id, 4);
        assert_eq!(lattice.begin_nodes[1][1].borrow().node_id, 5);
        assert_eq!(lattice.begin_nodes[2][0].borrow().node_id, 6);
    }

    #[test]
    fn test_encode() {
        // "abcd" as a single piece scores highest, so it wins over any split.
        let sentencepieces = vec![
            ("<unk>".to_string(), 0.0),
            ("a".to_string(), 0.0),
            ("b".to_string(), 0.0),
            ("c".to_string(), 0.0),
            ("d".to_string(), 0.0),
            ("cd".to_string(), 1.0),
            ("ab".to_string(), 2.0),
            ("abc".to_string(), 5.0),
            ("abcd".to_string(), 10.0),
        ];

        let model = Unigram::from(sentencepieces, Some(0), false).unwrap();
        let result = model.encode("abcd").unwrap();
        assert_eq!(result, vec!["abcd"]);
    }

    #[test]
    fn test_encode2() {
        let sentencepieces = vec![
            ("<unk>".to_string(), 0.0),
            ("ab".to_string(), 0.0),
            ("cd".to_string(), -0.1),
            ("abc".to_string(), -0.2),
            ("a".to_string(), -0.3),
            ("b".to_string(), -0.4),
            ("c".to_string(), -0.5),
            ("ABC".to_string(), -0.5),
            ("abcdabcd".to_string(), 20.0),
            ("q".to_string(), 20.5),
            ("r".to_string(), 20.5),
            ("qr".to_string(), -0.5),
        ];

        let mut model = Unigram::from(sentencepieces, Some(0), false).unwrap();

        // Both encoders must agree on every case below.
        for is_optimized in &[true, false] {
            model.set_optimized(*is_optimized);
            println!("IsOptimized {is_optimized:?}");
            assert_eq!(model.encode("abc").unwrap(), vec!["abc"]);
            assert_eq!(model.encode("AB").unwrap(), vec!["AB"]);

            // fuse_unk off: each unknown character is its own token.
            model.set_fuse_unk(false);
            assert_eq!(model.encode("AB").unwrap(), vec!["A", "B"]);
            model.set_fuse_unk(true);
            assert_eq!(model.encode("AB").unwrap(), vec!["AB"]);

            assert_eq!(model.encode("abcd").unwrap(), vec!["ab", "cd"]);
            assert_eq!(model.encode("abcc").unwrap(), vec!["abc", "c"]);
            assert_eq!(
                model.encode("xabcabaabcdd").unwrap(),
                vec!["x", "abc", "ab", "a", "ab", "cd", "d"]
            );
            // Multibyte characters are kept intact as unknown pieces.
            model.set_fuse_unk(false);
            assert_eq!(
                model.encode("xyz東京").unwrap(),
                vec!["x", "y", "z", "東", "京"]
            );
            model.set_fuse_unk(true);
            assert_eq!(model.encode("xyz東京").unwrap(), vec!["xyz東京"]);

            assert_eq!(model.encode("ABC").unwrap(), vec!["ABC"]);
            assert_eq!(model.encode("abABCcd").unwrap(), vec!["ab", "ABC", "cd"]);
            assert_eq!(
                model.encode("ababcdabcdcd").unwrap(),
                vec!["ab", "abcdabcd", "cd"]
            );
            // "q" and "r" individually outscore "qr".
            assert_eq!(model.encode("abqrcd").unwrap(), vec!["ab", "q", "r", "cd"]);
        }
    }

    #[test]
    fn test_unigram_bytefallback() {
        // "é" is absent from the vocab but both of its UTF-8 bytes are,
        // so byte fallback decomposes it into <0xC3> <0xA9>.
        let sentencepieces = vec![
            ("<unk>".to_string(), 0.0),
            ("<0xC3>".to_string(), -0.01),
            ("<0xA9>".to_string(), -0.03),
        ];
        let unigram = Unigram::from(sentencepieces, Some(0), true).unwrap();
        let tokens: Vec<Token> = unigram.tokenize("é").unwrap();
        assert_eq!(
            tokens,
            [
                Token {
                    id: 1,
                    value: "<0xC3>".to_string(),
                    offsets: (0, 2)
                },
                Token {
                    id: 2,
                    value: "<0xA9>".to_string(),
                    offsets: (0, 2)
                }
            ]
        );

        // "?" has no byte tokens available, so it maps to <unk> (id 0).
        let tokens = unigram.tokenize("?é").unwrap();
        assert_eq!(tokens[0].id, 0);
    }
}