rustfst/algorithms/queues/shortest_first_queue.rs

1use std::cmp::Ordering;
2use std::fmt::Debug;
3use std::fmt::Formatter;
4
5use anyhow::Result;
6use binary_heap_plus::{BinaryHeap, FnComparator};
7
8use crate::algorithms::{Queue, QueueType};
9use crate::semirings::Semiring;
10use crate::StateId;
11
12#[derive(Clone)]
13pub struct StateWeightCompare<W: Semiring, C: Clone + Fn(&W, &W) -> Result<bool>> {
14    less: C,
15    weights: Vec<W>,
16}
17
18impl<W: Semiring, C: Clone + Fn(&W, &W) -> Result<bool>> StateWeightCompare<W, C> {
19    pub fn new(weights: Vec<W>, less: C) -> Self {
20        Self { less, weights }
21    }
22
23    pub fn compare(&self, s1: StateId, s2: StateId) -> Result<bool> {
24        (self.less)(&self.weights[s1 as usize], &self.weights[s2 as usize])
25    }
26}
27
28pub fn natural_less<W: Semiring>(w1: &W, w2: &W) -> Result<bool> {
29    Ok((&w1.plus(w2)? == w1) && (w1 != w2))
30}
31
32#[derive(Clone)]
33pub struct ShortestFirstQueue<C: Clone + FnMut(&StateId, &StateId) -> Ordering> {
34    heap: BinaryHeap<StateId, FnComparator<C>>,
35}
36
37impl<C: Clone + FnMut(&StateId, &StateId) -> Ordering> Debug for ShortestFirstQueue<C> {
38    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
39        f.write_str(format!("ShortestFirstQueue {{ heap: {:?} }}", self.heap).as_str())
40    }
41}
42
43impl<C: Clone + FnMut(&StateId, &StateId) -> Ordering> ShortestFirstQueue<C> {
44    pub fn new(c: C) -> Self {
45        Self {
46            heap: BinaryHeap::new_by(c),
47        }
48    }
49}
50
51impl<C: Clone + FnMut(&StateId, &StateId) -> Ordering> Queue for ShortestFirstQueue<C> {
52    fn head(&mut self) -> Option<StateId> {
53        self.heap.peek().cloned()
54    }
55
56    fn enqueue(&mut self, state: StateId) {
57        self.heap.push(state);
58    }
59
60    fn dequeue(&mut self) -> Option<StateId> {
61        self.heap.pop()
62    }
63
64    fn update(&mut self, _state: StateId) {
65        unimplemented!()
66    }
67
68    fn is_empty(&self) -> bool {
69        self.heap.is_empty()
70    }
71
72    fn clear(&mut self) {
73        self.heap.clear()
74    }
75
76    fn queue_type(&self) -> QueueType {
77        QueueType::ShortestFirstQueue
78    }
79}
80
81#[derive(Debug)]
82pub struct NaturalShortestFirstQueue {
83    queue: Box<dyn Queue>,
84}
85
86impl NaturalShortestFirstQueue {
87    pub fn new<W: 'static + Semiring>(weights: Vec<W>) -> Self {
88        let a = StateWeightCompare::new(weights, natural_less);
89        let heap = ShortestFirstQueue::new(move |v1, v2| {
90            if a.compare(*v1, *v2).unwrap() {
91                Ordering::Less
92            } else {
93                Ordering::Greater
94            }
95        });
96        NaturalShortestFirstQueue {
97            queue: Box::new(heap),
98        }
99    }
100}
101
102impl Queue for NaturalShortestFirstQueue {
103    fn head(&mut self) -> Option<StateId> {
104        self.queue.head()
105    }
106
107    fn enqueue(&mut self, state: StateId) {
108        self.queue.enqueue(state)
109    }
110
111    fn dequeue(&mut self) -> Option<StateId> {
112        self.queue.dequeue()
113    }
114
115    fn update(&mut self, state: StateId) {
116        self.queue.update(state)
117    }
118
119    fn is_empty(&self) -> bool {
120        self.queue.is_empty()
121    }
122
123    fn clear(&mut self) {
124        self.queue.clear()
125    }
126
127    fn queue_type(&self) -> QueueType {
128        self.queue.queue_type()
129    }
130}