rustfst/algorithms/queues/
shortest_first_queue.rs1use std::cmp::Ordering;
2use std::fmt::Debug;
3use std::fmt::Formatter;
4
5use anyhow::Result;
6use binary_heap_plus::{BinaryHeap, FnComparator};
7
8use crate::algorithms::{Queue, QueueType};
9use crate::semirings::Semiring;
10use crate::StateId;
11
12#[derive(Clone)]
13pub struct StateWeightCompare<W: Semiring, C: Clone + Fn(&W, &W) -> Result<bool>> {
14 less: C,
15 weights: Vec<W>,
16}
17
18impl<W: Semiring, C: Clone + Fn(&W, &W) -> Result<bool>> StateWeightCompare<W, C> {
19 pub fn new(weights: Vec<W>, less: C) -> Self {
20 Self { less, weights }
21 }
22
23 pub fn compare(&self, s1: StateId, s2: StateId) -> Result<bool> {
24 (self.less)(&self.weights[s1 as usize], &self.weights[s2 as usize])
25 }
26}
27
28pub fn natural_less<W: Semiring>(w1: &W, w2: &W) -> Result<bool> {
29 Ok((&w1.plus(w2)? == w1) && (w1 != w2))
30}
31
32#[derive(Clone)]
33pub struct ShortestFirstQueue<C: Clone + FnMut(&StateId, &StateId) -> Ordering> {
34 heap: BinaryHeap<StateId, FnComparator<C>>,
35}
36
37impl<C: Clone + FnMut(&StateId, &StateId) -> Ordering> Debug for ShortestFirstQueue<C> {
38 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
39 f.write_str(format!("ShortestFirstQueue {{ heap: {:?} }}", self.heap).as_str())
40 }
41}
42
43impl<C: Clone + FnMut(&StateId, &StateId) -> Ordering> ShortestFirstQueue<C> {
44 pub fn new(c: C) -> Self {
45 Self {
46 heap: BinaryHeap::new_by(c),
47 }
48 }
49}
50
51impl<C: Clone + FnMut(&StateId, &StateId) -> Ordering> Queue for ShortestFirstQueue<C> {
52 fn head(&mut self) -> Option<StateId> {
53 self.heap.peek().cloned()
54 }
55
56 fn enqueue(&mut self, state: StateId) {
57 self.heap.push(state);
58 }
59
60 fn dequeue(&mut self) -> Option<StateId> {
61 self.heap.pop()
62 }
63
64 fn update(&mut self, _state: StateId) {
65 unimplemented!()
66 }
67
68 fn is_empty(&self) -> bool {
69 self.heap.is_empty()
70 }
71
72 fn clear(&mut self) {
73 self.heap.clear()
74 }
75
76 fn queue_type(&self) -> QueueType {
77 QueueType::ShortestFirstQueue
78 }
79}
80
81#[derive(Debug)]
82pub struct NaturalShortestFirstQueue {
83 queue: Box<dyn Queue>,
84}
85
86impl NaturalShortestFirstQueue {
87 pub fn new<W: 'static + Semiring>(weights: Vec<W>) -> Self {
88 let a = StateWeightCompare::new(weights, natural_less);
89 let heap = ShortestFirstQueue::new(move |v1, v2| {
90 if a.compare(*v1, *v2).unwrap() {
91 Ordering::Less
92 } else {
93 Ordering::Greater
94 }
95 });
96 NaturalShortestFirstQueue {
97 queue: Box::new(heap),
98 }
99 }
100}
101
102impl Queue for NaturalShortestFirstQueue {
103 fn head(&mut self) -> Option<StateId> {
104 self.queue.head()
105 }
106
107 fn enqueue(&mut self, state: StateId) {
108 self.queue.enqueue(state)
109 }
110
111 fn dequeue(&mut self) -> Option<StateId> {
112 self.queue.dequeue()
113 }
114
115 fn update(&mut self, state: StateId) {
116 self.queue.update(state)
117 }
118
119 fn is_empty(&self) -> bool {
120 self.queue.is_empty()
121 }
122
123 fn clear(&mut self) {
124 self.queue.clear()
125 }
126
127 fn queue_type(&self) -> QueueType {
128 self.queue.queue_type()
129 }
130}