1use rand::distributions::WeightedIndex;
2use rand::prelude::*;
3use std::cell::RefCell;
4use std::cmp::{min, Ordering};
5use std::collections::BinaryHeap;
6use std::rc::Rc;
7
/// Shared, interior-mutable handle to a lattice [`Node`].
type NodeRef = Rc<RefCell<Node>>;
/// Shared, interior-mutable handle to a [`Hypothesis`]; lets one hypothesis
/// point at the hypothesis it was expanded from.
type HypothesisRef = Rc<RefCell<Hypothesis>>;
/// Priority queue of hypotheses used by `Lattice::nbest`
/// (`BinaryHeap` pops the greatest element first).
type Agenda = BinaryHeap<Hypothesis>;
11
/// A partial path explored by the n-best search in `Lattice::nbest`.
///
/// Hypotheses are created starting from the end-of-sentence node and are
/// extended with nodes that end where the current node begins.
struct Hypothesis {
    /// Lattice node this hypothesis currently stands on.
    node_ref: NodeRef,
    /// Hypothesis this one was expanded from (`None` for the initial one).
    next: Option<HypothesisRef>,
    /// Priority used to order the agenda: the node's `backtrace_score`
    /// plus the parent hypothesis' `gx`.
    fx: f64,
    /// Accumulated score: the node's own `score` plus the parent's `gx`.
    gx: f64,
}
18impl Hypothesis {
19 pub fn new(node_ref: NodeRef, next: Option<HypothesisRef>, fx: f64, gx: f64) -> Self {
20 Self {
21 node_ref,
22 next,
23 fx,
24 gx,
25 }
26 }
27}
28impl PartialEq for Hypothesis {
29 fn eq(&self, other: &Self) -> bool {
30 self.fx == other.fx
31 }
32}
33impl Eq for Hypothesis {}
34impl PartialOrd for Hypothesis {
35 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
36 Some(self.cmp(other))
37 }
38}
39impl Ord for Hypothesis {
41 fn cmp(&self, other: &Self) -> Ordering {
42 if self.fx < other.fx {
43 Ordering::Less
44 } else {
45 Ordering::Greater
46 }
47 }
48}
49
/// A lattice of candidate pieces over one sentence.
///
/// Positions used throughout are byte offsets into `sentence`.
#[derive(Debug)]
pub struct Lattice<'a> {
    /// The sentence the lattice is built over.
    pub(super) sentence: &'a str,
    /// Length of `sentence` in bytes.
    len: usize,
    /// Every node of the lattice; a node's `node_id` is its index here.
    nodes: Vec<NodeRef>,
    /// `begin_nodes[pos]` holds the nodes that start at byte offset `pos`
    /// (the end-of-sentence node is stored at `begin_nodes[len]`).
    pub(super) begin_nodes: Vec<Vec<NodeRef>>,
    /// `end_nodes[pos]` holds the nodes that end at byte offset `pos`
    /// (the beginning-of-sentence node is stored at `end_nodes[0]`).
    pub(super) end_nodes: Vec<Vec<NodeRef>>,
    /// Id given to the beginning-of-sentence node at construction.
    _bos_id: usize,
    /// Id given to the end-of-sentence node at construction.
    _eos_id: usize,
}
62
63impl std::fmt::Display for Lattice<'_> {
64 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
65 let display_pieces = |nodes: &Vec<Vec<NodeRef>>| {
66 nodes
67 .iter()
68 .map(|l| {
69 l.iter()
70 .map(|n| self.piece(&n.borrow()))
71 .collect::<Vec<_>>()
72 })
73 .collect::<Vec<_>>()
74 };
75
76 f.debug_struct("Lattice")
77 .field("sentence", &self.sentence)
78 .field("begin_nodes", &display_pieces(&self.begin_nodes))
79 .field("end_nodes", &display_pieces(&self.end_nodes))
80 .finish()
81 }
82}
83
/// One candidate piece in a [`Lattice`].
#[derive(Debug, Clone)]
pub struct Node {
    /// Caller-supplied identifier of the piece (the BOS/EOS ids for the two
    /// sentinel nodes); `populate_marginal` uses it to index `expected`.
    pub(super) id: usize,
    /// Index of this node inside the owning lattice's `nodes` vector.
    pub(super) node_id: usize,
    /// Byte offset in the sentence where the piece starts.
    pos: usize,
    /// Length of the piece in bytes.
    length: usize,
    /// Best predecessor, filled in by `Lattice::viterbi`.
    prev: Option<NodeRef>,
    /// Best cumulative score up to and including this node, filled in by
    /// `Lattice::viterbi`.
    backtrace_score: f64,
    /// Score of this node on its own.
    score: f64,
}
97
98impl PartialEq for Node {
99 fn eq(&self, other: &Node) -> bool {
100 self.id == other.id
101 }
102}
103
104impl Node {
105 pub fn new(id: usize, node_id: usize, pos: usize, length: usize, score: f64) -> Self {
106 Self {
107 id,
108 node_id,
109 pos,
110 length,
111 prev: None,
112 score,
113 backtrace_score: 0.0,
114 }
115 }
116}
117
/// Computes `ln(exp(x) + exp(y))` without leaving log space.
///
/// When `init_mode` is `true` the accumulator `x` is ignored and `y` is
/// returned unchanged, so a running sum can be started without a sentinel.
///
/// Two equal infinities return that infinity instead of the `NaN` that
/// `inf - inf` would otherwise produce.
fn log_sum_exp(x: f64, y: f64, init_mode: bool) -> f64 {
    // Once the operands are further apart than this, the smaller one
    // contributes less than `exp(-50)` relative to the larger one and is
    // dropped.
    const MINUS_LOG_EPSILON: f64 = 50.0;

    if init_mode {
        return y;
    }

    let (vmin, vmax) = if x > y { (y, x) } else { (x, y) };
    if vmin == vmax && vmax.is_infinite() {
        return vmax;
    }

    if vmax > vmin + MINUS_LOG_EPSILON {
        vmax
    } else {
        // `ln_1p` keeps the tiny correction term that `ln(1 + e)` rounds away.
        vmax + (vmin - vmax).exp().ln_1p()
    }
}
136
137impl<'a> Lattice<'a> {
138 pub fn from(sentence: &'a str, bos_id: usize, eos_id: usize) -> Self {
139 let len = sentence.len();
140 let k_reserved_node_size = 16;
141 let mut nodes: Vec<NodeRef> = Vec::with_capacity(k_reserved_node_size);
143 let mut begin_nodes = vec![Vec::with_capacity(k_reserved_node_size); len + 1];
144 let mut end_nodes = vec![Vec::with_capacity(k_reserved_node_size); len + 1];
145
146 let bos = Rc::new(RefCell::new(Node::new(bos_id, 0, 0, 0, 0.0)));
147 let eos = Rc::new(RefCell::new(Node::new(eos_id, 1, len, 0, 0.0)));
148
149 begin_nodes[len].push(Rc::clone(&eos));
150 end_nodes[0].push(Rc::clone(&bos));
151
152 nodes.push(bos);
153 nodes.push(eos);
154
155 Self {
156 sentence,
157 len,
158 nodes,
159 begin_nodes,
160 end_nodes,
161 _bos_id: bos_id,
162 _eos_id: eos_id,
163 }
164 }
165
166 pub fn insert(&mut self, pos: usize, length: usize, score: f64, id: usize) {
167 let node_id = self.nodes.len();
168 let node = Rc::new(RefCell::new(Node::new(id, node_id, pos, length, score)));
169
170 self.begin_nodes[pos].push(Rc::clone(&node));
171 self.end_nodes[pos + length].push(Rc::clone(&node));
172
173 self.nodes.push(node);
174 }
175
176 pub fn viterbi(&mut self) -> Vec<NodeRef> {
177 let len = self.len;
178 let mut pos = 0;
179 while pos <= len {
180 if self.begin_nodes[pos].is_empty() {
181 return vec![];
182 }
183 for rnode in &self.begin_nodes[pos] {
184 rnode.borrow_mut().prev = None;
185 let mut best_score = 0.0;
186 let mut best_node: Option<NodeRef> = None;
187 for lnode in &self.end_nodes[pos] {
188 let score = lnode.borrow().backtrace_score + rnode.borrow().score;
189 if best_node.is_none() || score > best_score {
190 best_node = Some(lnode.clone());
192 best_score = score
193 }
194 }
195 match best_node {
196 Some(bnode) => {
197 rnode.borrow_mut().prev = Some(Rc::clone(&bnode));
198 rnode.borrow_mut().backtrace_score = best_score;
199 }
200 None => return vec![],
201 }
202 }
203 if let Some(c) = self.sentence[pos..].chars().next() {
204 pos += c.len_utf8();
205 } else {
206 break;
207 }
208 }
209
210 let mut results: Vec<NodeRef> = vec![];
211 let root = self.begin_nodes[len][0].borrow();
212 let prev = root.prev.as_ref();
213 if prev.is_none() {
214 return vec![];
215 }
216 let mut node: NodeRef = prev.unwrap().clone();
217 while node.borrow().prev.is_some() {
218 results.push(node.clone());
219 let n = node.borrow().clone();
220 node = n.prev.as_ref().unwrap().clone();
221 }
222 results.reverse();
223 results
224 }
225
226 pub fn piece(&self, node: &Node) -> String {
227 self.sentence[node.pos..node.pos + node.length].to_owned()
228 }
229
230 pub fn tokens(&mut self) -> Vec<String> {
231 self.viterbi()
232 .iter()
233 .map(|node| self.piece(&node.borrow()))
234 .collect()
235 }
236
237 pub fn nbest(&mut self, n: usize) -> Vec<Vec<NodeRef>> {
238 match n {
239 0 => vec![],
240 1 => vec![self.viterbi()],
241 _ => {
242 let mut agenda: Agenda = BinaryHeap::new();
244 let mut hypotheses: Vec<Vec<NodeRef>> = vec![];
245 let eos = self.eos_node();
246 let score = eos.borrow().score;
247 let hypo = Hypothesis::new(eos, None, score, score);
248 agenda.push(hypo);
249
250 self.viterbi();
252
253 while !agenda.is_empty() {
254 let top = Rc::new(RefCell::new(agenda.pop().unwrap()));
255 let node = Rc::clone(&top.borrow().node_ref);
256 if node.borrow().id == self.bos_node().borrow().id {
257 let mut hypothesis = vec![];
258 let mut next: HypothesisRef =
259 Rc::clone(top.borrow().next.as_ref().unwrap());
260 while next.borrow().next.is_some() {
261 hypothesis.push(next.borrow().node_ref.clone());
262 let c: HypothesisRef = next.clone();
263 next = Rc::clone(c.borrow().next.as_ref().unwrap());
265 }
266 hypotheses.push(hypothesis);
267 if hypotheses.len() == n {
268 return hypotheses;
269 }
270 } else {
271 for lnode in &self.end_nodes[node.borrow().pos] {
272 let top_gx = top.borrow().gx;
273 let fx = lnode.borrow().backtrace_score + top_gx;
274 let gx = lnode.borrow().score + top_gx;
275 let hyp =
276 Hypothesis::new(Rc::clone(lnode), Some(Rc::clone(&top)), fx, gx);
277 agenda.push(hyp);
278 }
279 let k_max_agenda_size = 100_000;
283 let k_min_agenda_size = 512;
284 if agenda.len() > k_max_agenda_size {
285 let mut new_agenda = BinaryHeap::new();
286 let len = min(k_min_agenda_size, n * 10);
287 for _i in 0..len {
288 new_agenda.push(agenda.pop().unwrap());
289 }
290 agenda = new_agenda;
291 }
292 }
293 }
294 hypotheses
295 }
296 }
297 }
298
299 pub fn nbest_tokens(&mut self, n: usize) -> Vec<Vec<String>> {
300 self.nbest(n)
301 .iter()
302 .map(|v| v.iter().map(|node| self.piece(&node.borrow())).collect())
303 .collect()
304 }
305
/// Returns the length of the sentence in bytes, as recorded when the
/// lattice was created.
pub fn len(&self) -> usize {
    self.len
}
309
310 pub fn is_empty(&self) -> bool {
311 self.len == 0
312 }
313
314 pub fn bos_node(&self) -> NodeRef {
315 Rc::clone(&self.end_nodes[0][0])
316 }
317 pub fn eos_node(&self) -> NodeRef {
318 Rc::clone(&self.begin_nodes[self.len][0])
319 }
320
321 pub fn surface(&self, n: usize) -> &str {
322 match self.sentence.char_indices().nth(n) {
323 Some((pos, _)) => &self.sentence[pos..],
324 None => "",
325 }
326 }
/// Returns the sentence this lattice was built over.
pub fn sentence(&self) -> &str {
    self.sentence
}
330
331 pub fn populate_marginal(&self, freq: f64, expected: &mut [f64]) -> f64 {
332 let len = self.len();
333 let n_nodes = self.nodes.len();
334 let mut alpha = vec![0.0; n_nodes];
335 let mut beta = vec![0.0; n_nodes];
336 for pos in 0..=len {
337 for rnode in &self.begin_nodes[pos] {
338 for lnode in &self.end_nodes[pos] {
339 let lid = lnode.borrow().node_id;
340 let rid = rnode.borrow().node_id;
341 alpha[rid] = log_sum_exp(
342 alpha[rid],
343 lnode.borrow().score + alpha[lid],
344 *lnode == self.end_nodes[pos][0],
345 );
346 }
347 }
348 }
349 for pos in (0..=len).rev() {
350 for lnode in &self.end_nodes[pos] {
352 for rnode in &self.begin_nodes[pos] {
353 let lid = lnode.borrow().node_id;
354 let rid = rnode.borrow().node_id;
355 beta[lid] = log_sum_exp(
356 beta[lid],
357 rnode.borrow().score + beta[rid],
358 *rnode == self.begin_nodes[pos][0],
359 );
360 }
361 }
362 }
363
364 let eos_id = self.begin_nodes[len][0].borrow().node_id;
365 let z = alpha[eos_id];
366 for pos in 0..len {
367 for node in &self.begin_nodes[pos] {
368 let node_id = node.borrow().node_id;
369 let id = node.borrow().id;
370 let a = alpha[node_id];
371 let b = beta[node_id];
372 let total = a + node.borrow().score + b - z;
373 let update = freq * total.exp();
374 expected[id] += update;
375 }
376 }
377 freq * z
378 }
379
380 pub fn sample(&self, theta: f64) -> Vec<NodeRef> {
381 let len = self.len();
382 if len == 0 {
383 return vec![];
384 }
385 let mut alpha = vec![0.0; self.nodes.len()];
386 for pos in 0..=len {
387 for rnode in &self.begin_nodes[pos] {
388 for lnode in &self.end_nodes[pos] {
389 let lid = lnode.borrow().node_id;
390 let rid = rnode.borrow().node_id;
391 alpha[rid] = log_sum_exp(
392 alpha[rid],
393 theta * (lnode.borrow().score + alpha[lid]),
394 *lnode == self.end_nodes[pos][0],
395 );
396 }
397 }
398 }
399
400 let mut rng = thread_rng();
401 let mut results: Vec<NodeRef> = vec![];
402 let mut probs: Vec<f64> = vec![];
403 let mut z = alpha[self.eos_node().borrow().node_id];
404 let mut node = self.eos_node();
405 loop {
406 probs.clear();
407 let pos = node.borrow().pos;
408 for lnode in &self.end_nodes[pos] {
409 let lid = lnode.borrow().node_id;
410 probs.push((alpha[lid] + theta * lnode.borrow().score - z).exp())
411 }
412 let dist = WeightedIndex::new(&probs).unwrap();
413 let index = dist.sample(&mut rng);
414 node = Rc::clone(&self.end_nodes[pos][index]);
415 if node == self.bos_node() {
416 break;
417 }
418 z = alpha[node.borrow().node_id];
419 results.push(Rc::clone(&node));
420 }
421 results.reverse();
422 results
423 }
424
425 pub fn sample_token(&self, theta: f64) -> Vec<String> {
426 self.sample(theta)
427 .iter()
428 .map(|node| self.piece(&node.borrow()))
429 .collect()
430 }
431}
432
433#[cfg(test)]
434mod tests {
435 use super::*;
436 use assert_approx_eq::assert_approx_eq;
437
/// Checks length, sentence, per-character suffixes and the sentinel nodes
/// for empty, ASCII and multi-byte sentences.
#[test]
fn set_sentence() {
    let lattice = Lattice::from("", 1, 2);
    assert_eq!(lattice.len(), 0);

    let lattice = Lattice::from("", 1, 2);
    assert_eq!(lattice.len(), 0);
    assert_eq!(lattice.sentence(), "");
    assert_eq!(lattice.surface(0), "");

    let lattice = Lattice::from("test", 1, 2);
    assert_eq!(lattice.len(), 4);
    assert_eq!(lattice.sentence(), "test");
    for (n, suffix) in ["test", "est", "st", "t"].iter().enumerate() {
        assert_eq!(lattice.surface(n), *suffix);
    }

    let (bos, eos) = (lattice.bos_node(), lattice.eos_node());
    assert_eq!(bos.borrow().id, 1);
    assert_eq!(eos.borrow().id, 2);
    assert_eq!(
        lattice.end_nodes[0].first().unwrap().borrow().id,
        bos.borrow().id
    );
    assert_eq!(
        lattice.begin_nodes[4].first().unwrap().borrow().id,
        eos.borrow().id
    );

    // Length is in bytes, while `surface` counts characters.
    let lattice = Lattice::from("テストab", 1, 2);
    assert_eq!(lattice.len(), 11);
    assert_eq!(lattice.sentence(), "テストab");
    for (n, suffix) in ["テストab", "ストab", "トab", "ab", "b"].iter().enumerate() {
        assert_eq!(lattice.surface(n), *suffix);
    }
}
480
/// Inserts seven nodes over a mixed ASCII / multi-byte sentence and checks
/// their fields and how they are registered in `begin_nodes`/`end_nodes`.
#[test]
fn insert_test() {
    let mut lattice = Lattice::from("ABあい", 1, 2);

    // (pos, length, id, expected piece), in insertion order.
    let inserted = [
        (0, 1, 3, "A"),
        (1, 1, 4, "B"),
        (2, 3, 5, "あ"),
        (5, 3, 6, "い"),
        (0, 2, 7, "AB"),
        (1, 4, 8, "Bあ"),
        (2, 6, 9, "あい"),
    ];
    for &(pos, length, id, _) in inserted.iter() {
        lattice.insert(pos, length, 0.0, id);
    }

    // `nodes[0]` and `nodes[1]` are the BOS/EOS sentinels.
    for (i, &(pos, length, id, piece)) in inserted.iter().enumerate() {
        let node = lattice.nodes[i + 2].borrow();
        assert_eq!(lattice.piece(&node), piece);
        assert_eq!(node.pos, pos);
        assert_eq!(node.length, length);
        assert_eq!(node.id, id);
    }

    let bos_id = lattice.bos_node().borrow().id;
    let eos_id = lattice.eos_node().borrow().id;
    assert_eq!(bos_id, 1);
    assert_eq!(eos_id, 2);

    // Ids of the nodes stored in one slot, in insertion order.
    let ids = |slot: &Vec<NodeRef>| -> Vec<usize> {
        slot.iter().map(|node| node.borrow().id).collect()
    };

    assert_eq!(ids(&lattice.begin_nodes[0]), vec![3, 7]);
    assert_eq!(ids(&lattice.begin_nodes[1]), vec![4, 8]);
    assert_eq!(ids(&lattice.begin_nodes[2]), vec![5, 9]);
    assert_eq!(ids(&lattice.begin_nodes[5]), vec![6]);
    assert_eq!(ids(&lattice.begin_nodes[8]), vec![eos_id]);

    assert_eq!(ids(&lattice.end_nodes[0]), vec![bos_id]);
    assert_eq!(ids(&lattice.end_nodes[1]), vec![3]);
    assert_eq!(ids(&lattice.end_nodes[2]), vec![4, 7]);
    assert_eq!(ids(&lattice.end_nodes[5]), vec![5, 8]);
    assert_eq!(ids(&lattice.end_nodes[8]), vec![6, 9]);
}
571
/// `viterbi` yields nothing until every character is covered by a node.
#[test]
fn test_viterbi() {
    let mut lattice = Lattice::from("ABC", 1, 2);
    assert!(lattice.viterbi().is_empty());

    // Covering only "A" still leaves the lattice incomplete.
    lattice.insert(0, 1, 0.0, 3);
    assert!(lattice.viterbi().is_empty());

    for &(pos, id) in [(1, 4), (2, 5)].iter() {
        lattice.insert(pos, 1, 0.0, id);
    }
    assert_eq!(lattice.viterbi().len(), 3);
}
584
/// The best tokenization follows the highest-scoring pieces as
/// better-scoring nodes are inserted.
#[test]
fn test_viterbi2() {
    let mut lattice = Lattice::from("ABC", 1, 2);

    for (pos, id) in (0..3).zip(3..6) {
        lattice.insert(pos, 1, 0.0, id);
    }
    assert_eq!(lattice.tokens(), ["A", "B", "C"]);

    // (pos, length, score, id, expected tokens after the insertion)
    let steps: Vec<(usize, usize, f64, usize, Vec<&str>)> = vec![
        (0, 2, 2.0, 6, vec!["AB", "C"]),
        (1, 2, 5.0, 7, vec!["A", "BC"]),
        (0, 3, 10.0, 8, vec!["ABC"]),
    ];
    for (pos, length, score, id, expected) in steps {
        lattice.insert(pos, length, score, id);
        assert_eq!(lattice.tokens(), expected);
    }
}
604
/// `nbest_tokens` lists all four tokenizations of "ABC" best-first and
/// honours the `n == 0` and `n == 1` special cases.
#[test]
fn test_nbest() {
    let mut lattice = Lattice::from("ABC", 1, 2);

    // (pos, length, score, id)
    let spans = [
        (0, 1, 0.0, 3),
        (1, 1, 0.0, 4),
        (2, 1, 0.0, 5),
        (0, 2, 2.0, 6),
        (1, 2, 5.0, 7),
        (0, 3, 10.0, 8),
    ];
    for &(pos, length, score, id) in spans.iter() {
        lattice.insert(pos, length, score, id);
    }

    let expected = vec![
        vec!["ABC"],
        vec!["A", "BC"],
        vec!["AB", "C"],
        vec!["A", "B", "C"],
    ];
    assert_eq!(lattice.nbest_tokens(10), expected);

    assert!(lattice.nbest_tokens(0).is_empty());
    assert_eq!(lattice.nbest_tokens(1), vec![vec!["ABC"]]);
}
/// Folding values through `log_sum_exp` matches the direct
/// `ln(sum(exp(v)))`.
#[test]
fn test_log_sum_exp() {
    let values: Vec<f64> = vec![1.0, 2.0, 3.0];

    // The first value starts the running sum (`init_mode`).
    let folded = values
        .iter()
        .enumerate()
        .fold(0.0, |acc, (i, value)| log_sum_exp(acc, *value, i == 0));

    let direct = values.iter().map(|value| value.exp()).sum::<f64>().ln();
    assert_approx_eq!(folded, direct, 0.001);
}
639
/// `populate_marginal` returns the log of the total path mass and gives
/// every node the summed probability of the paths that go through it.
#[test]
fn test_populate() {
    let mut lattice = Lattice::from("ABC", 1, 2);
    lattice.insert(0, 1, 1.0, 3); // A
    lattice.insert(1, 1, 1.2, 4); // B
    lattice.insert(2, 1, 2.5, 5); // C
    lattice.insert(0, 2, 3.0, 6); // AB
    lattice.insert(1, 2, 4.0, 7); // BC
    lattice.insert(0, 3, 2.0, 8); // ABC

    let mut probs = vec![0.0; 9];

    // Unnormalised weights of the four paths through the lattice.
    let p1 = (1.0_f64 + 1.2 + 2.5).exp(); // A B C
    let p2 = (3.0_f64 + 2.5).exp(); // AB C
    let p3 = (1.0_f64 + 4.0).exp(); // A BC
    let p4 = 2.0_f64.exp(); // ABC
    let z = p1 + p2 + p3 + p4;

    let log_z = lattice.populate_marginal(1.0, &mut probs);
    assert_approx_eq!(log_z, z.ln(), 0.001);

    // Indexed by node id; ids 0..=2 belong to no inserted node.
    let expected = [
        0.0,
        0.0,
        0.0,
        (p1 + p3) / z,
        p1 / z,
        (p1 + p2) / z,
        p2 / z,
        p3 / z,
        p4 / z,
    ];
    for (got, want) in probs.iter().zip(expected.iter()) {
        assert_approx_eq!(*got, *want, 0.001);
    }
}
670}