toolbox_rs/
addressable_binary_heap.rs

1use core::hash::Hash;
2use fxhash::FxHashMap;
3use num::{Bounded, Integer};
4use std::fmt::Debug;
5
6struct HeapNode<NodeID: Copy + Integer, Weight: Bounded + Copy + Integer + Debug, Data> {
7    node: NodeID,
8    key: usize,
9    weight: Weight,
10    data: Data,
11}
12
13impl<NodeID: Copy + Integer, Weight: Bounded + Copy + Integer + Debug, Data>
14    HeapNode<NodeID, Weight, Data>
15{
16    fn new(node: NodeID, key: usize, weight: Weight, data: Data) -> Self {
17        Self {
18            node,
19            key,
20            weight,
21            data,
22        }
23    }
24}
25#[derive(Clone, Copy)]
26struct HeapElement<Weight: Bounded + Copy + Integer + Debug> {
27    index: usize,
28    weight: Weight,
29}
30
31impl<Weight: Bounded + Copy + Integer + Debug> Default for HeapElement<Weight> {
32    fn default() -> Self {
33        HeapElement::new(usize::MAX, Weight::min_value())
34    }
35}
36
37impl<Weight: Bounded + Copy + Integer + Debug> HeapElement<Weight> {
38    fn new(index: usize, weight: Weight) -> Self {
39        Self { index, weight }
40    }
41}
42
43pub struct AddressableHeap<NodeID: Copy + Integer, Weight: Bounded + Copy + Integer + Debug, Data> {
44    heap: Vec<HeapElement<Weight>>,
45    inserted_nodes: Vec<HeapNode<NodeID, Weight, Data>>,
46    node_index: FxHashMap<NodeID, usize>,
47}
48
49impl<NodeID: Copy + Hash + Integer, Weight: Bounded + Copy + Integer + Debug, Data> Default
50    for AddressableHeap<NodeID, Weight, Data>
51{
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
57impl<NodeID: Copy + Hash + Integer, Weight: Bounded + Copy + Integer + Debug, Data>
58    AddressableHeap<NodeID, Weight, Data>
59{
60    pub fn new() -> AddressableHeap<NodeID, Weight, Data> {
61        AddressableHeap {
62            heap: vec![HeapElement::default()],
63            inserted_nodes: Vec::new(),
64            node_index: FxHashMap::default(),
65        }
66    }
67
68    pub fn clear(&mut self) {
69        self.heap.clear();
70        self.inserted_nodes.clear();
71        self.heap.push(HeapElement::default());
72        self.node_index.clear();
73    }
74
75    pub fn len(&self) -> usize {
76        self.heap.len() - 1
77    }
78
79    pub fn is_empty(&self) -> bool {
80        self.len() == 0
81    }
82
83    /// return the number of inserted elements since the last time queue was
84    /// cleared. Note that this is not the number of elements currently in
85    /// the heap, nor the number of removed elements.
86    pub fn inserted_len(&self) -> usize {
87        self.inserted_nodes.len()
88    }
89
90    pub fn insert(&mut self, node: NodeID, weight: Weight, data: Data) {
91        let index = self.inserted_nodes.len();
92        let element = HeapElement { index, weight };
93        let key = self.heap.len();
94        self.heap.push(element);
95        self.inserted_nodes
96            .push(HeapNode::new(node, key, weight, data));
97        self.node_index.insert(node, index);
98        self.up_heap(key);
99    }
100
101    pub fn data(&self, node: NodeID) -> &Data {
102        let index = self.node_index.get(&node).unwrap();
103        &self.inserted_nodes.get(*index).unwrap().data
104    }
105
106    pub fn data_mut(&mut self, node: NodeID) -> &mut Data {
107        let index = self.node_index.get(&node).unwrap();
108        &mut self.inserted_nodes.get_mut(*index).unwrap().data
109    }
110
111    pub fn weight(&self, node: NodeID) -> Weight {
112        let index = self.node_index.get(&node);
113        if let Some(index) = index {
114            self.inserted_nodes.get(*index).unwrap().weight
115        } else {
116            Weight::max_value()
117        }
118    }
119
120    pub fn removed(&self, node: NodeID) -> bool {
121        let index = self.node_index.get(&node);
122        if let Some(index) = index {
123            self.inserted_nodes.get(*index).unwrap().key == 0
124        } else {
125            false
126        }
127    }
128
129    pub fn contains(&self, node: NodeID) -> bool {
130        let index = self.node_index.get(&node);
131        if let Some(index) = index {
132            self.inserted_nodes.get(*index).unwrap().key != 0
133        } else {
134            false
135        }
136    }
137
138    pub fn inserted(&self, node: NodeID) -> bool {
139        let index = self.node_index.get(&node);
140        if let Some(index) = index {
141            debug_assert!(index < &self.inserted_nodes.len());
142            self.inserted_nodes.get(*index).unwrap().node == node
143        } else {
144            false
145        }
146    }
147
148    /// Returns the node with minimum weight without removing it from the heap.
149    /// Panics if the heap is empty.
150    ///
151    /// # Examples
152    ///
153    /// ```
154    /// use toolbox_rs::addressable_binary_heap::AddressableHeap;
155    /// let mut heap = AddressableHeap::new();
156    /// heap.insert(1, 1, 0);
157    /// heap.insert(2, 2, 0);
158    /// assert_eq!(heap.min(), 1);
159    /// ```
160    pub fn min(&self) -> NodeID {
161        let index = self.heap[1].index;
162        self.inserted_nodes[index].node
163    }
164
165    /// Returns the node with minimum weight without removing it, or None if the heap is empty.
166    ///
167    /// # Examples
168    ///
169    /// ```
170    /// use toolbox_rs::addressable_binary_heap::AddressableHeap;
171    /// let mut heap = AddressableHeap::new();
172    ///
173    /// assert_eq!(heap.pop(), None);  // Empty heap
174    ///
175    /// heap.insert(1, 1, 0);
176    /// heap.insert(2, 2, 0);
177    /// assert_eq!(heap.pop(), Some(1));  // Returns min without removing
178    /// assert_eq!(heap.pop(), Some(1));  // Still returns 1 as it wasn't removed
179    /// ```
180    #[inline]
181    pub fn pop(&mut self) -> Option<NodeID> {
182        if self.is_empty() {
183            return None;
184        }
185        Some(self.min())
186    }
187
188    pub fn delete_min(&mut self) -> NodeID {
189        let removed_index = self.heap[1].index;
190        let last_index = self.heap.len() - 1;
191        self.heap.swap(1, last_index);
192
193        self.heap.pop();
194        if self.heap.len() > 1 {
195            self.down_heap(1);
196        }
197        self.inserted_nodes[removed_index].key = 0;
198        self.inserted_nodes[removed_index].node
199    }
200
201    pub fn flush(&mut self) {
202        (1..(self.heap.len() - 1)).rev().for_each(|i| {
203            let element = &self.heap[i];
204            self.inserted_nodes[element.index].key = 0;
205        });
206        self.heap.truncate(1);
207        self.heap[0].weight = Weight::max_value();
208    }
209
210    pub fn decrease_key(&mut self, node: NodeID, weight: Weight) {
211        let index = self.node_index[&node];
212        let key = self.inserted_nodes[index].key;
213
214        self.inserted_nodes[index].weight = weight;
215        self.up_heap(key);
216    }
217
218    pub fn decrease_key_and_update_data(&mut self, node: NodeID, weight: Weight, data: Data) {
219        self.decrease_key(node, weight);
220        (*self.data_mut(node)) = data;
221    }
222
223    fn down_heap(&mut self, mut key: usize) {
224        let dropping_index = self.heap[key].index;
225        let weight = self.heap[key].weight;
226
227        let mut next_key = key << 1;
228        while next_key < self.heap.len() {
229            let next_key_sibling = next_key + 1;
230            if next_key_sibling < self.heap.len()
231                && self.heap[next_key].weight > self.heap[next_key_sibling].weight
232            {
233                next_key = next_key_sibling;
234            }
235            if weight <= self.heap[next_key].weight {
236                break;
237            }
238            self.heap[key] = self.heap[next_key];
239            self.inserted_nodes[self.heap[key].index].key = key;
240            key = next_key;
241            next_key <<= 1;
242        }
243        self.heap[key] = HeapElement {
244            index: dropping_index,
245            weight,
246        };
247        self.inserted_nodes[dropping_index].key = key;
248    }
249
250    pub fn up_heap(&mut self, mut key: usize) {
251        let rising_index = self.heap[key].index;
252        let weight = self.heap[key].weight;
253
254        let mut next_key = key >> 1;
255
256        while self.heap[next_key].weight > weight {
257            self.heap[key] = self.heap[next_key];
258            let index = self.heap[key].index;
259
260            self.inserted_nodes[index].key = key;
261            key = next_key;
262            next_key >>= 1;
263        }
264        self.heap[key].index = rising_index;
265        self.heap[key].weight = weight;
266        self.inserted_nodes[rising_index].key = key;
267    }
268}
269
#[cfg(test)]
mod tests {
    use rand::{Rng, SeedableRng, prelude::StdRng};

    use crate::addressable_binary_heap::AddressableHeap;
    type Heap = AddressableHeap<i32, i32, i32>;

    /// Node ids shared by most tests, in insertion order.
    const INPUT: [i32; 5] = [4, 1, 6, 7, 5];

    /// Inserts every id of `INPUT` with weight `weight_of(id)` and data `data_of(id)`.
    fn fill(heap: &mut Heap, weight_of: impl Fn(i32) -> i32, data_of: impl Fn(i32) -> i32) {
        INPUT
            .iter()
            .for_each(|&id| heap.insert(id, weight_of(id), data_of(id)));
    }

    /// Checks the state a heap filled from `INPUT` is expected to be in.
    fn assert_filled(heap: &Heap) {
        assert_eq!(1, heap.min());
        assert!(!heap.is_empty());
        assert_eq!(5, heap.len());
    }

    /// Calls `delete_min` until the heap is empty and returns the ids in removal order.
    fn drain(heap: &mut Heap) -> Vec<i32> {
        let mut order = Vec::with_capacity(heap.len());
        while !heap.is_empty() {
            order.push(heap.delete_min());
        }
        order
    }

    #[test]
    fn empty() {
        assert!(Heap::new().is_empty());
    }

    #[test]
    fn insert_size() {
        let mut heap = Heap::new();
        heap.insert(20, 1, 2);
        assert_eq!(20, heap.min());
        assert!(!heap.is_empty());
        assert_eq!(heap.len(), 1);
    }

    #[test]
    fn heap_sort() {
        let mut heap = Heap::new();
        fill(&mut heap, |id| id, |_| 0);
        assert_eq!(1, heap.min());
        assert!(!heap.is_empty());

        let sorted = drain(&mut heap);
        assert_eq!(sorted.len(), 5);
        assert!(heap.is_empty());

        // Sorting unstable is OK. No observable difference on integers.
        let mut expected = INPUT.to_vec();
        expected.sort_unstable();
        assert_eq!(sorted, expected);
    }

    #[test]
    #[should_panic]
    fn empty_min_panic() {
        Heap::new().min();
    }

    #[test]
    fn heap_sort_random() {
        let mut rng = StdRng::seed_from_u64(0xAAAAAAAA);
        let mut numbers: Vec<i32> = (0..1000).map(|_| rng.random()).collect();

        let mut heap = Heap::new();
        for &number in &numbers {
            heap.insert(number, number, 0);
        }
        assert!(!heap.is_empty());
        assert_eq!(1000, heap.len());
        assert_eq!(1000, numbers.len());

        let sorted = drain(&mut heap);
        assert_eq!(sorted.len(), 1000);
        assert!(heap.is_empty());

        // Sorting unstable is OK. No observable difference on integers.
        numbers.sort_unstable();
        assert_eq!(sorted, numbers);
    }

    #[test]
    fn clear() {
        let mut heap = Heap::new();
        fill(&mut heap, |id| id, |id| id);
        assert_filled(&heap);

        heap.clear();
        assert_eq!(0, heap.len());
    }

    #[test]
    fn data() {
        let mut heap = Heap::new();
        fill(&mut heap, |id| id, |id| id);
        assert_filled(&heap);

        for id in &INPUT {
            assert_eq!(id, heap.data(*id));
        }
    }

    #[test]
    fn data_mut() {
        let mut heap = Heap::new();
        fill(&mut heap, |id| id, |id| id);
        assert_filled(&heap);

        // double all data entries
        for &id in &INPUT {
            *heap.data_mut(id) *= 2;
        }

        for &id in &INPUT {
            assert_eq!(&(2 * id), heap.data(id));
        }
    }

    #[test]
    fn flush() {
        let mut heap = Heap::default();
        fill(&mut heap, |id| id, |id| id);
        assert_filled(&heap);

        heap.flush();
        assert!(heap.is_empty());
        assert_eq!(0, heap.len());
    }

    #[test]
    fn removed() {
        let mut heap = Heap::default();
        fill(&mut heap, |id| id, |id| id);
        assert_filled(&heap);

        for node in 1..=7 {
            assert!(!heap.removed(node));
        }

        drain(&mut heap);

        // ids 2 and 3 were never inserted
        let expected = [true, false, false, true, true, true, true];
        for (node, removed) in (1..=7).zip(expected) {
            assert_eq!(heap.removed(node), removed);
        }
    }

    #[test]
    fn inserted() {
        let mut heap = Heap::default();
        fill(&mut heap, |id| id, |id| id);
        assert_filled(&heap);

        drain(&mut heap);

        // ids 2 and 3 were never inserted
        let expected = [true, false, false, true, true, true, true];
        for (node, inserted) in (1..=7).zip(expected) {
            assert_eq!(heap.inserted(node), inserted);
        }
    }

    #[test]
    fn weight() {
        let mut heap = Heap::default();
        fill(&mut heap, |id| 2 + id, |id| id);
        assert_filled(&heap);

        while !heap.is_empty() {
            let node = heap.delete_min();
            assert_eq!(heap.weight(node), 2 + node);
        }

        let expected = [2 + 1, i32::MAX, i32::MAX, 2 + 4, 2 + 5, 2 + 6, 2 + 7];
        for (node, weight) in (1..=7).zip(expected) {
            assert_eq!(heap.weight(node), weight);
        }
    }

    #[test]
    fn decrease_key() {
        let mut heap = Heap::default();
        fill(&mut heap, |id| 2 + id, |id| id);
        assert_filled(&heap);

        for &id in &INPUT {
            heap.decrease_key(id, id);
        }

        let expected = [1, i32::MAX, i32::MAX, 4, 5, 6, 7];
        for (node, weight) in (1..=7).zip(expected) {
            assert_eq!(heap.weight(node), weight);
        }
    }

    #[test]
    fn decrease_key_with_new_data() {
        let mut heap = Heap::default();
        fill(&mut heap, |id| 2 + id, |id| id);
        assert_eq!(heap.inserted_len(), INPUT.len());
        assert_filled(&heap);

        for &id in &INPUT {
            heap.decrease_key_and_update_data(id, id, id + 10);
        }

        for node in [1, 4, 5, 6, 7] {
            assert_eq!(heap.weight(node), node);
            assert_eq!(*heap.data(node), node + 10);
        }
        // ids 2 and 3 were never inserted
        assert_eq!(heap.weight(2), i32::MAX);
        assert_eq!(heap.weight(3), i32::MAX);
    }

    #[test]
    fn contains() {
        let mut heap = Heap::default();
        fill(&mut heap, |id| id, |id| id);

        // never inserted
        assert!(!heap.contains(16));

        let mut ascending = INPUT;
        ascending.sort_unstable();
        for id in ascending {
            assert!(heap.contains(id));
            assert_eq!(heap.delete_min(), id);
            assert!(!heap.contains(id));
        }
    }

    #[test]
    fn test_pop() {
        let mut heap = Heap::new();

        // Test empty heap
        assert_eq!(heap.pop(), None);

        // Test with multiple elements
        for id in [3, 1, 2] {
            heap.insert(id, id, 0);
        }

        assert_eq!(heap.pop(), Some(1)); // Should return min without removing
        assert_eq!(heap.len(), 3); // Length shouldn't change
        assert_eq!(heap.pop(), Some(1)); // Should still return same min

        heap.delete_min(); // Actually remove the minimum
        assert_eq!(heap.pop(), Some(2)); // Should now return new minimum
        assert_eq!(heap.len(), 2); // Length should be reduced
    }
}