// tokenizers/models/unigram/lattice.rs

1use rand::distributions::WeightedIndex;
2use rand::prelude::*;
3use std::cell::RefCell;
4use std::cmp::{min, Ordering};
5use std::collections::BinaryHeap;
6use std::rc::Rc;
7
8type NodeRef = Rc<RefCell<Node>>;
9type HypothesisRef = Rc<RefCell<Hypothesis>>;
10type Agenda = BinaryHeap<Hypothesis>;
11
12struct Hypothesis {
13    node_ref: NodeRef,
14    next: Option<HypothesisRef>,
15    fx: f64,
16    gx: f64,
17}
18impl Hypothesis {
19    pub fn new(node_ref: NodeRef, next: Option<HypothesisRef>, fx: f64, gx: f64) -> Self {
20        Self {
21            node_ref,
22            next,
23            fx,
24            gx,
25        }
26    }
27}
28impl PartialEq for Hypothesis {
29    fn eq(&self, other: &Self) -> bool {
30        self.fx == other.fx
31    }
32}
33impl Eq for Hypothesis {}
34impl PartialOrd for Hypothesis {
35    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
36        Some(self.cmp(other))
37    }
38}
39// TODO Maybe use Ordered Floats (https://docs.rs/ordered-float/1.0.2/ordered_float/)
40impl Ord for Hypothesis {
41    fn cmp(&self, other: &Self) -> Ordering {
42        if self.fx < other.fx {
43            Ordering::Less
44        } else {
45            Ordering::Greater
46        }
47    }
48}
49
50/// Structure to implement Viterbi algorithm to find the best encoding, or sample
51/// from all possible encodings of a given sentence.
52#[derive(Debug)]
53pub struct Lattice<'a> {
54    pub(super) sentence: &'a str,
55    len: usize,
56    nodes: Vec<NodeRef>,
57    pub(super) begin_nodes: Vec<Vec<NodeRef>>,
58    pub(super) end_nodes: Vec<Vec<NodeRef>>,
59    _bos_id: usize,
60    _eos_id: usize,
61}
62
63impl std::fmt::Display for Lattice<'_> {
64    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
65        let display_pieces = |nodes: &Vec<Vec<NodeRef>>| {
66            nodes
67                .iter()
68                .map(|l| {
69                    l.iter()
70                        .map(|n| self.piece(&n.borrow()))
71                        .collect::<Vec<_>>()
72                })
73                .collect::<Vec<_>>()
74        };
75
76        f.debug_struct("Lattice")
77            .field("sentence", &self.sentence)
78            .field("begin_nodes", &display_pieces(&self.begin_nodes))
79            .field("end_nodes", &display_pieces(&self.end_nodes))
80            .finish()
81    }
82}
83
84/// A node from the lattice, that helps reconstruct the underlying `String`
85#[derive(Debug, Clone)]
86pub struct Node {
87    // Vocabulary id
88    pub(super) id: usize,
89    // Local lattice identifier
90    pub(super) node_id: usize,
91    pos: usize,
92    length: usize,
93    prev: Option<NodeRef>,
94    backtrace_score: f64,
95    score: f64,
96}
97
98impl PartialEq for Node {
99    fn eq(&self, other: &Node) -> bool {
100        self.id == other.id
101    }
102}
103
104impl Node {
105    pub fn new(id: usize, node_id: usize, pos: usize, length: usize, score: f64) -> Self {
106        Self {
107            id,
108            node_id,
109            pos,
110            length,
111            prev: None,
112            score,
113            backtrace_score: 0.0,
114        }
115    }
116}
117
/// Numerically stable `ln(exp(x) + exp(y))`.
///
/// With `init_mode == true`, `x` is ignored and `y` is returned, so that
/// `ln(sum_i exp(a[i]))` can be folded as
/// `acc = log_sum_exp(acc, a[i], i == 0)` without a separate initial value.
/// When the arguments differ by more than 50, the smaller one is dropped
/// (its contribution is below `exp(-50)` relative to the larger).
fn log_sum_exp(x: f64, y: f64, init_mode: bool) -> f64 {
    const K_MINUS_LOG_EPSILON: f64 = 50.0;

    if init_mode {
        return y;
    }
    let (small, large) = if x > y { (y, x) } else { (x, y) };
    if large > small + K_MINUS_LOG_EPSILON {
        large
    } else {
        large + (1.0 + (small - large).exp()).ln()
    }
}
136
137impl<'a> Lattice<'a> {
138    pub fn from(sentence: &'a str, bos_id: usize, eos_id: usize) -> Self {
139        let len = sentence.len();
140        let k_reserved_node_size = 16;
141        // We are adding 2 tokens, bos and eos
142        let mut nodes: Vec<NodeRef> = Vec::with_capacity(k_reserved_node_size);
143        let mut begin_nodes = vec![Vec::with_capacity(k_reserved_node_size); len + 1];
144        let mut end_nodes = vec![Vec::with_capacity(k_reserved_node_size); len + 1];
145
146        let bos = Rc::new(RefCell::new(Node::new(bos_id, 0, 0, 0, 0.0)));
147        let eos = Rc::new(RefCell::new(Node::new(eos_id, 1, len, 0, 0.0)));
148
149        begin_nodes[len].push(Rc::clone(&eos));
150        end_nodes[0].push(Rc::clone(&bos));
151
152        nodes.push(bos);
153        nodes.push(eos);
154
155        Self {
156            sentence,
157            len,
158            nodes,
159            begin_nodes,
160            end_nodes,
161            _bos_id: bos_id,
162            _eos_id: eos_id,
163        }
164    }
165
166    pub fn insert(&mut self, pos: usize, length: usize, score: f64, id: usize) {
167        let node_id = self.nodes.len();
168        let node = Rc::new(RefCell::new(Node::new(id, node_id, pos, length, score)));
169
170        self.begin_nodes[pos].push(Rc::clone(&node));
171        self.end_nodes[pos + length].push(Rc::clone(&node));
172
173        self.nodes.push(node);
174    }
175
176    pub fn viterbi(&mut self) -> Vec<NodeRef> {
177        let len = self.len;
178        let mut pos = 0;
179        while pos <= len {
180            if self.begin_nodes[pos].is_empty() {
181                return vec![];
182            }
183            for rnode in &self.begin_nodes[pos] {
184                rnode.borrow_mut().prev = None;
185                let mut best_score = 0.0;
186                let mut best_node: Option<NodeRef> = None;
187                for lnode in &self.end_nodes[pos] {
188                    let score = lnode.borrow().backtrace_score + rnode.borrow().score;
189                    if best_node.is_none() || score > best_score {
190                        // TODO can we remove this clone ?
191                        best_node = Some(lnode.clone());
192                        best_score = score
193                    }
194                }
195                match best_node {
196                    Some(bnode) => {
197                        rnode.borrow_mut().prev = Some(Rc::clone(&bnode));
198                        rnode.borrow_mut().backtrace_score = best_score;
199                    }
200                    None => return vec![],
201                }
202            }
203            if let Some(c) = self.sentence[pos..].chars().next() {
204                pos += c.len_utf8();
205            } else {
206                break;
207            }
208        }
209
210        let mut results: Vec<NodeRef> = vec![];
211        let root = self.begin_nodes[len][0].borrow();
212        let prev = root.prev.as_ref();
213        if prev.is_none() {
214            return vec![];
215        }
216        let mut node: NodeRef = prev.unwrap().clone();
217        while node.borrow().prev.is_some() {
218            results.push(node.clone());
219            let n = node.borrow().clone();
220            node = n.prev.as_ref().unwrap().clone();
221        }
222        results.reverse();
223        results
224    }
225
226    pub fn piece(&self, node: &Node) -> String {
227        self.sentence[node.pos..node.pos + node.length].to_owned()
228    }
229
230    pub fn tokens(&mut self) -> Vec<String> {
231        self.viterbi()
232            .iter()
233            .map(|node| self.piece(&node.borrow()))
234            .collect()
235    }
236
237    pub fn nbest(&mut self, n: usize) -> Vec<Vec<NodeRef>> {
238        match n {
239            0 => vec![],
240            1 => vec![self.viterbi()],
241            _ => {
242                // let k_reserved_hypothesis_size = 512;
243                let mut agenda: Agenda = BinaryHeap::new();
244                let mut hypotheses: Vec<Vec<NodeRef>> = vec![];
245                let eos = self.eos_node();
246                let score = eos.borrow().score;
247                let hypo = Hypothesis::new(eos, None, score, score);
248                agenda.push(hypo);
249
250                // Fill backtrace scores
251                self.viterbi();
252
253                while !agenda.is_empty() {
254                    let top = Rc::new(RefCell::new(agenda.pop().unwrap()));
255                    let node = Rc::clone(&top.borrow().node_ref);
256                    if node.borrow().id == self.bos_node().borrow().id {
257                        let mut hypothesis = vec![];
258                        let mut next: HypothesisRef =
259                            Rc::clone(top.borrow().next.as_ref().unwrap());
260                        while next.borrow().next.is_some() {
261                            hypothesis.push(next.borrow().node_ref.clone());
262                            let c: HypothesisRef = next.clone();
263                            // let c: Ref<Hypothesis> = next.clone().borrow();
264                            next = Rc::clone(c.borrow().next.as_ref().unwrap());
265                        }
266                        hypotheses.push(hypothesis);
267                        if hypotheses.len() == n {
268                            return hypotheses;
269                        }
270                    } else {
271                        for lnode in &self.end_nodes[node.borrow().pos] {
272                            let top_gx = top.borrow().gx;
273                            let fx = lnode.borrow().backtrace_score + top_gx;
274                            let gx = lnode.borrow().score + top_gx;
275                            let hyp =
276                                Hypothesis::new(Rc::clone(lnode), Some(Rc::clone(&top)), fx, gx);
277                            agenda.push(hyp);
278                        }
279                        // When the input is too long or contains duplicated phrases,
280                        // `agenda` will get extremely big. Here we avoid this case by
281                        // dynamically shrinking the agenda.
282                        let k_max_agenda_size = 100_000;
283                        let k_min_agenda_size = 512;
284                        if agenda.len() > k_max_agenda_size {
285                            let mut new_agenda = BinaryHeap::new();
286                            let len = min(k_min_agenda_size, n * 10);
287                            for _i in 0..len {
288                                new_agenda.push(agenda.pop().unwrap());
289                            }
290                            agenda = new_agenda;
291                        }
292                    }
293                }
294                hypotheses
295            }
296        }
297    }
298
299    pub fn nbest_tokens(&mut self, n: usize) -> Vec<Vec<String>> {
300        self.nbest(n)
301            .iter()
302            .map(|v| v.iter().map(|node| self.piece(&node.borrow())).collect())
303            .collect()
304    }
305
306    pub fn len(&self) -> usize {
307        self.len
308    }
309
310    pub fn is_empty(&self) -> bool {
311        self.len == 0
312    }
313
314    pub fn bos_node(&self) -> NodeRef {
315        Rc::clone(&self.end_nodes[0][0])
316    }
317    pub fn eos_node(&self) -> NodeRef {
318        Rc::clone(&self.begin_nodes[self.len][0])
319    }
320
321    pub fn surface(&self, n: usize) -> &str {
322        match self.sentence.char_indices().nth(n) {
323            Some((pos, _)) => &self.sentence[pos..],
324            None => "",
325        }
326    }
327    pub fn sentence(&self) -> &str {
328        self.sentence
329    }
330
331    pub fn populate_marginal(&self, freq: f64, expected: &mut [f64]) -> f64 {
332        let len = self.len();
333        let n_nodes = self.nodes.len();
334        let mut alpha = vec![0.0; n_nodes];
335        let mut beta = vec![0.0; n_nodes];
336        for pos in 0..=len {
337            for rnode in &self.begin_nodes[pos] {
338                for lnode in &self.end_nodes[pos] {
339                    let lid = lnode.borrow().node_id;
340                    let rid = rnode.borrow().node_id;
341                    alpha[rid] = log_sum_exp(
342                        alpha[rid],
343                        lnode.borrow().score + alpha[lid],
344                        *lnode == self.end_nodes[pos][0],
345                    );
346                }
347            }
348        }
349        for pos in (0..=len).rev() {
350            // let rpos = len - pos;
351            for lnode in &self.end_nodes[pos] {
352                for rnode in &self.begin_nodes[pos] {
353                    let lid = lnode.borrow().node_id;
354                    let rid = rnode.borrow().node_id;
355                    beta[lid] = log_sum_exp(
356                        beta[lid],
357                        rnode.borrow().score + beta[rid],
358                        *rnode == self.begin_nodes[pos][0],
359                    );
360                }
361            }
362        }
363
364        let eos_id = self.begin_nodes[len][0].borrow().node_id;
365        let z = alpha[eos_id];
366        for pos in 0..len {
367            for node in &self.begin_nodes[pos] {
368                let node_id = node.borrow().node_id;
369                let id = node.borrow().id;
370                let a = alpha[node_id];
371                let b = beta[node_id];
372                let total = a + node.borrow().score + b - z;
373                let update = freq * total.exp();
374                expected[id] += update;
375            }
376        }
377        freq * z
378    }
379
380    pub fn sample(&self, theta: f64) -> Vec<NodeRef> {
381        let len = self.len();
382        if len == 0 {
383            return vec![];
384        }
385        let mut alpha = vec![0.0; self.nodes.len()];
386        for pos in 0..=len {
387            for rnode in &self.begin_nodes[pos] {
388                for lnode in &self.end_nodes[pos] {
389                    let lid = lnode.borrow().node_id;
390                    let rid = rnode.borrow().node_id;
391                    alpha[rid] = log_sum_exp(
392                        alpha[rid],
393                        theta * (lnode.borrow().score + alpha[lid]),
394                        *lnode == self.end_nodes[pos][0],
395                    );
396                }
397            }
398        }
399
400        let mut rng = thread_rng();
401        let mut results: Vec<NodeRef> = vec![];
402        let mut probs: Vec<f64> = vec![];
403        let mut z = alpha[self.eos_node().borrow().node_id];
404        let mut node = self.eos_node();
405        loop {
406            probs.clear();
407            let pos = node.borrow().pos;
408            for lnode in &self.end_nodes[pos] {
409                let lid = lnode.borrow().node_id;
410                probs.push((alpha[lid] + theta * lnode.borrow().score - z).exp())
411            }
412            let dist = WeightedIndex::new(&probs).unwrap();
413            let index = dist.sample(&mut rng);
414            node = Rc::clone(&self.end_nodes[pos][index]);
415            if node == self.bos_node() {
416                break;
417            }
418            z = alpha[node.borrow().node_id];
419            results.push(Rc::clone(&node));
420        }
421        results.reverse();
422        results
423    }
424
425    pub fn sample_token(&self, theta: f64) -> Vec<String> {
426        self.sample(theta)
427            .iter()
428            .map(|node| self.piece(&node.borrow()))
429            .collect()
430    }
431}
432
#[cfg(test)]
mod tests {
    use super::*;
    use assert_approx_eq::assert_approx_eq;

    #[test]
    fn set_sentence() {
        let lattice = Lattice::from("", 1, 2);
        assert_eq!(lattice.len(), 0);
        assert!(lattice.is_empty());
        assert_eq!(lattice.sentence(), "");
        assert_eq!(lattice.surface(0), "");

        let lattice = Lattice::from("test", 1, 2);
        assert_eq!(lattice.len(), 4);
        assert_eq!(lattice.sentence(), "test");
        assert_eq!(lattice.surface(0), "test");
        assert_eq!(lattice.surface(1), "est");
        assert_eq!(lattice.surface(2), "st");
        assert_eq!(lattice.surface(3), "t");

        let bos = lattice.bos_node();
        let eos = lattice.eos_node();

        assert_eq!(bos.borrow().id, 1);
        assert_eq!(eos.borrow().id, 2);
        assert_eq!(
            lattice.end_nodes[0].first().unwrap().borrow().id,
            bos.borrow().id
        );
        assert_eq!(
            lattice.begin_nodes[4].first().unwrap().borrow().id,
            eos.borrow().id
        );

        // `len` is in bytes, `surface` indexes characters.
        let lattice = Lattice::from("テストab", 1, 2);
        assert_eq!(lattice.len(), 11);
        assert_eq!(lattice.sentence(), "テストab");
        assert_eq!(lattice.surface(0), "テストab");
        assert_eq!(lattice.surface(1), "ストab");
        assert_eq!(lattice.surface(2), "トab");
        assert_eq!(lattice.surface(3), "ab");
        assert_eq!(lattice.surface(4), "b");
    }

    #[test]
    fn insert_test() {
        let mut lattice = Lattice::from("ABあい", 1, 2);

        lattice.insert(0, 1, 0.0, 3);
        lattice.insert(1, 1, 0.0, 4);
        lattice.insert(2, 3, 0.0, 5);
        lattice.insert(5, 3, 0.0, 6);
        lattice.insert(0, 2, 0.0, 7);
        lattice.insert(1, 4, 0.0, 8);
        lattice.insert(2, 6, 0.0, 9);
        // 0 & 1 are bos and eos
        let node0 = lattice.nodes[2].borrow();
        let node1 = lattice.nodes[3].borrow();
        let node2 = lattice.nodes[4].borrow();
        let node3 = lattice.nodes[5].borrow();
        let node4 = lattice.nodes[6].borrow();
        let node5 = lattice.nodes[7].borrow();
        let node6 = lattice.nodes[8].borrow();

        assert_eq!(lattice.piece(&node0), "A");
        assert_eq!(lattice.piece(&node1), "B");
        assert_eq!(lattice.piece(&node2), "あ");
        assert_eq!(lattice.piece(&node3), "い");
        assert_eq!(lattice.piece(&node4), "AB");
        assert_eq!(lattice.piece(&node5), "Bあ");
        assert_eq!(lattice.piece(&node6), "あい");

        assert_eq!(node0.pos, 0);
        assert_eq!(node1.pos, 1);
        assert_eq!(node2.pos, 2);
        assert_eq!(node3.pos, 5);
        assert_eq!(node4.pos, 0);
        assert_eq!(node5.pos, 1);
        assert_eq!(node6.pos, 2);

        assert_eq!(node0.length, 1);
        assert_eq!(node1.length, 1);
        assert_eq!(node2.length, 3);
        assert_eq!(node3.length, 3);
        assert_eq!(node4.length, 2);
        assert_eq!(node5.length, 4);
        assert_eq!(node6.length, 6);

        assert_eq!(lattice.bos_node().borrow().id, 1);
        assert_eq!(lattice.eos_node().borrow().id, 2);
        assert_eq!(node0.id, 3);
        assert_eq!(node1.id, 4);
        assert_eq!(node2.id, 5);
        assert_eq!(node3.id, 6);
        assert_eq!(node4.id, 7);
        assert_eq!(node5.id, 8);
        assert_eq!(node6.id, 9);

        assert_eq!(lattice.begin_nodes[0].len(), 2);
        assert_eq!(lattice.begin_nodes[1].len(), 2);
        assert_eq!(lattice.begin_nodes[2].len(), 2);
        assert_eq!(lattice.begin_nodes[5].len(), 1);
        assert_eq!(lattice.begin_nodes[8].len(), 1);

        assert_eq!(lattice.end_nodes[0].len(), 1);
        assert_eq!(lattice.end_nodes[1].len(), 1);
        assert_eq!(lattice.end_nodes[2].len(), 2);
        assert_eq!(lattice.end_nodes[5].len(), 2);
        assert_eq!(lattice.end_nodes[8].len(), 2);

        assert_eq!(lattice.begin_nodes[0][0].borrow().id, node0.id);
        assert_eq!(lattice.begin_nodes[0][1].borrow().id, node4.id);
        assert_eq!(lattice.begin_nodes[1][0].borrow().id, node1.id);
        assert_eq!(lattice.begin_nodes[1][1].borrow().id, node5.id);
        assert_eq!(lattice.begin_nodes[2][0].borrow().id, node2.id);
        assert_eq!(lattice.begin_nodes[2][1].borrow().id, node6.id);
        assert_eq!(lattice.begin_nodes[5][0].borrow().id, node3.id);
        assert_eq!(
            lattice.eos_node().borrow().id,
            lattice.begin_nodes[8][0].borrow().id
        );

        assert_eq!(
            lattice.bos_node().borrow().id,
            lattice.end_nodes[0][0].borrow().id
        );
        assert_eq!(node0.id, lattice.end_nodes[1][0].borrow().id);
        assert_eq!(node1.id, lattice.end_nodes[2][0].borrow().id);
        assert_eq!(node4.id, lattice.end_nodes[2][1].borrow().id);
        assert_eq!(node2.id, lattice.end_nodes[5][0].borrow().id);
        assert_eq!(node5.id, lattice.end_nodes[5][1].borrow().id);
        assert_eq!(node3.id, lattice.end_nodes[8][0].borrow().id);
        assert_eq!(node6.id, lattice.end_nodes[8][1].borrow().id);
    }

    #[test]
    fn test_viterbi() {
        let mut lattice = Lattice::from("ABC", 1, 2);
        assert_eq!(lattice.viterbi(), vec![]);
        // Still incomplete: nothing starts at bytes 1 and 2.
        lattice.insert(0, 1, 0.0, 3);
        assert_eq!(lattice.viterbi(), vec![]);
        lattice.insert(1, 1, 0.0, 4);
        lattice.insert(2, 1, 0.0, 5);
        // Every character is now covered, so the single path A/B/C is found.
        // (SentencePiece's own tests do not check this case.)
        assert_eq!(lattice.viterbi().len(), 3);
    }

    #[test]
    fn test_viterbi2() {
        let mut lattice = Lattice::from("ABC", 1, 2);

        lattice.insert(0, 1, 0.0, 3);
        lattice.insert(1, 1, 0.0, 4);
        lattice.insert(2, 1, 0.0, 5);

        assert_eq!(lattice.tokens(), ["A", "B", "C"]);

        lattice.insert(0, 2, 2.0, 6);
        assert_eq!(lattice.tokens(), ["AB", "C"]);

        lattice.insert(1, 2, 5.0, 7);
        assert_eq!(lattice.tokens(), ["A", "BC"]);

        lattice.insert(0, 3, 10.0, 8);
        assert_eq!(lattice.tokens(), ["ABC"]);
    }

    #[test]
    fn test_nbest() {
        let mut lattice = Lattice::from("ABC", 1, 2);
        lattice.insert(0, 1, 0.0, 3);
        lattice.insert(1, 1, 0.0, 4);
        lattice.insert(2, 1, 0.0, 5);
        lattice.insert(0, 2, 2.0, 6);
        lattice.insert(1, 2, 5.0, 7);
        lattice.insert(0, 3, 10.0, 8);

        let nbests = lattice.nbest_tokens(10);
        assert_eq!(
            nbests,
            vec![
                vec!["ABC"],
                vec!["A", "BC"],
                vec!["AB", "C"],
                vec!["A", "B", "C"]
            ]
        );

        assert!(lattice.nbest_tokens(0).is_empty());
        assert_eq!(lattice.nbest_tokens(1), vec![vec!["ABC"]]);
    }

    #[test]
    fn test_log_sum_exp() {
        let mut x = 0.0;

        let v: Vec<f64> = vec![1.0, 2.0, 3.0];
        for (i, y) in v.iter().enumerate() {
            x = log_sum_exp(x, *y, i == 0);
        }
        assert_approx_eq!(x, v.iter().map(|n| n.exp()).sum::<f64>().ln(), 0.001);
    }

    #[test]
    fn test_populate() {
        let mut lattice = Lattice::from("ABC", 1, 2);
        lattice.insert(0, 1, 1.0, 3); // A
        lattice.insert(1, 1, 1.2, 4); // B
        lattice.insert(2, 1, 2.5, 5); // C
        lattice.insert(0, 2, 3.0, 6); // AB
        lattice.insert(1, 2, 4.0, 7); // BC
        lattice.insert(0, 3, 2.0, 8); // ABC

        let mut probs = vec![0.0; 9];
        let p1 = (1.0_f64 + 1.2 + 2.5).exp();
        let p2 = (3.0_f64 + 2.5).exp();
        let p3 = (1.0_f64 + 4.0).exp();
        let p4 = 2.0_f64.exp();
        let z = p1 + p2 + p3 + p4;

        let log_z = lattice.populate_marginal(1.0, &mut probs);

        assert_approx_eq!(log_z, z.ln(), 0.001);
        assert_approx_eq!(probs[0], 0.0, 0.001);
        assert_approx_eq!(probs[1], 0.0, 0.001);
        assert_approx_eq!(probs[2], 0.0, 0.001);
        assert_approx_eq!(probs[3], (p1 + p3) / z, 0.001);
        assert_approx_eq!(probs[4], (p1) / z, 0.001);
        assert_approx_eq!(probs[5], (p1 + p2) / z, 0.001);
        assert_approx_eq!(probs[6], (p2) / z, 0.001);
        assert_approx_eq!(probs[7], (p3) / z, 0.001);
        assert_approx_eq!(probs[8], (p4) / z, 0.001);
    }

    #[test]
    fn test_sample() {
        // Nothing to sample from an empty sentence.
        let lattice = Lattice::from("", 1, 2);
        assert!(lattice.sample(1.0).is_empty());

        // With a single possible segmentation, every draw returns it,
        // whatever the temperature.
        let mut lattice = Lattice::from("ABC", 1, 2);
        lattice.insert(0, 1, -1.0, 3); // A
        lattice.insert(1, 2, -2.0, 4); // BC
        for theta in [0.1, 1.0, 5.0].iter() {
            assert_eq!(lattice.sample_token(*theta), ["A", "BC"]);
        }
    }
}