Skip to main content

uni_algo/algo/algorithms/
astar.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! A* Search Algorithm.
5//!
//! A* is an informed best-first search algorithm for weighted graphs: starting from a given
//! source node, it finds a lowest-cost path to a given goal node.
//! It uses a heuristic function `h(n)` to estimate the remaining cost from node `n` to the
//! goal and always expands the open node with the smallest `g(n) + h(n)`, where `g(n)` is
//! the cost of the best known path to `n`. Edge weights must be non-negative.
9
10use crate::algo::GraphProjection;
11use crate::algo::algorithms::Algorithm;
12use std::cmp::Reverse;
13use std::collections::{BinaryHeap, HashMap};
14use uni_common::core::id::Vid;
15
/// A* shortest-path search between two vertices of a [`GraphProjection`], guided by a
/// caller-supplied heuristic.
///
/// Configure a run with [`AStarConfig`]; the outcome is reported as an [`AStarResult`].
#[derive(Debug, Clone, Copy, Default)]
pub struct AStar;
17
18#[derive(Debug, Clone)]
19pub struct AStarConfig {
20    pub source: Vid,
21    pub target: Vid,
22    /// Heuristic values h(n) for each node n.
23    /// If a node is missing, h(n) is assumed to be 0.0.
24    pub heuristic: HashMap<Vid, f64>,
25}
26
27impl Default for AStarConfig {
28    fn default() -> Self {
29        Self {
30            source: Vid::from(0),
31            target: Vid::from(0),
32            heuristic: HashMap::new(),
33        }
34    }
35}
36
37pub struct AStarResult {
38    pub distance: Option<f64>,
39    pub path: Option<Vec<Vid>>,
40    pub visited_count: usize,
41}
42
43impl Algorithm for AStar {
44    type Config = AStarConfig;
45    type Result = AStarResult;
46
47    fn name() -> &'static str {
48        "astar"
49    }
50
51    fn run(graph: &GraphProjection, config: Self::Config) -> Self::Result {
52        let source_slot = match graph.to_slot(config.source) {
53            Some(slot) => slot,
54            None => {
55                return AStarResult {
56                    distance: None,
57                    path: None,
58                    visited_count: 0,
59                };
60            }
61        };
62
63        let target_slot = match graph.to_slot(config.target) {
64            Some(slot) => slot,
65            None => {
66                return AStarResult {
67                    distance: None,
68                    path: None,
69                    visited_count: 0,
70                };
71            }
72        };
73
74        let n = graph.vertex_count();
75        // g_score[u] is the cost of the cheapest path from source to u currently known.
76        let mut g_score = vec![f64::INFINITY; n];
77        let mut prev: Vec<Option<u32>> = vec![None; n];
78
79        // Priority queue stores (f_score, u), where f_score = g_score[u] + h(u).
80        // We use Reverse for min-heap behavior.
81        // Storing bits for f64 ordering.
82        let mut heap = BinaryHeap::new();
83
84        g_score[source_slot as usize] = 0.0;
85        let h_source = config.heuristic.get(&config.source).copied().unwrap_or(0.0);
86        let f_source = 0.0 + h_source;
87
88        heap.push(Reverse((f_source.to_bits(), source_slot)));
89
90        let mut visited_count = 0;
91
92        while let Some(Reverse((f_bits, u))) = heap.pop() {
93            let f_current = f64::from_bits(f_bits);
94
95            // If we reached the target, we are done (if heuristic is consistent/monotone).
96            // A* with consistent heuristic guarantees optimal path first time target is popped.
97            if u == target_slot {
98                let dist = g_score[u as usize];
99
100                // Reconstruct path
101                let mut path = Vec::new();
102                let mut curr = Some(u);
103                while let Some(slot) = curr {
104                    path.push(graph.to_vid(slot));
105                    if slot == source_slot {
106                        break;
107                    }
108                    curr = prev[slot as usize];
109                }
110                path.reverse();
111
112                return AStarResult {
113                    distance: Some(dist),
114                    path: Some(path),
115                    visited_count,
116                };
117            }
118
119            // Lazy deletion / check if we found a better g_score already that makes this entry stale
120            // f_current is the f_score stored in heap.
121            // Current best f would be g_score[u] + h(u).
122            // If f_current > g_score[u] + h(u), then this heap entry is stale.
123            let h_u = config
124                .heuristic
125                .get(&graph.to_vid(u))
126                .copied()
127                .unwrap_or(0.0);
128            if f_current > g_score[u as usize] + h_u {
129                continue;
130            }
131
132            visited_count += 1;
133
134            for (i, &v) in graph.out_neighbors(u).iter().enumerate() {
135                let weight = if graph.has_weights() {
136                    graph.out_weight(u, i)
137                } else {
138                    1.0
139                };
140
141                let tentative_g = g_score[u as usize] + weight;
142
143                if tentative_g < g_score[v as usize] {
144                    prev[v as usize] = Some(u);
145                    g_score[v as usize] = tentative_g;
146
147                    let h_v = config
148                        .heuristic
149                        .get(&graph.to_vid(v))
150                        .copied()
151                        .unwrap_or(0.0);
152                    let f_v = tentative_g + h_v;
153
154                    heap.push(Reverse((f_v.to_bits(), v)));
155                }
156            }
157        }
158
159        // Path not found
160        AStarResult {
161            distance: None,
162            path: None,
163            visited_count,
164        }
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algo::test_utils::build_test_graph;

    #[test]
    fn test_astar_simple() {
        // Chain 0 -> 1 -> 2. build_test_graph builds an unweighted projection (weights
        // unset), so AStar charges 1.0 per edge. The heuristic h(0)=2, h(1)=1, h(2)=0 is
        // the exact remaining distance, so every vertex on the chain has f = 2.
        let vids = vec![Vid::from(0), Vid::from(1), Vid::from(2)];
        let edges = vec![(Vid::from(0), Vid::from(1)), (Vid::from(1), Vid::from(2))];
        let graph = build_test_graph(vids, edges);

        let mut heuristic = HashMap::new();
        heuristic.insert(Vid::from(0), 2.0);
        heuristic.insert(Vid::from(1), 1.0);
        heuristic.insert(Vid::from(2), 0.0);

        let config = AStarConfig {
            source: Vid::from(0),
            target: Vid::from(2),
            heuristic,
        };

        let result = AStar::run(&graph, config);

        assert_eq!(result.distance, Some(2.0));
        assert_eq!(
            result.path,
            Some(vec![Vid::from(0), Vid::from(1), Vid::from(2)])
        );
        // Vertices 0 and 1 are expanded; popping the target ends the search.
        assert_eq!(result.visited_count, 2);
    }

    #[test]
    fn test_astar_heuristic_guides() {
        // Diamond: 0 -> 1 -> 3 and 0 -> 2 -> 3. The projection is unweighted, so both
        // routes cost 2 and only the heuristic decides which one is explored first.
        //
        // h(1) = 0.5, h(2) = 0.1, h(3) = 0.0 (all admissible: the true remaining cost
        // from 1 and 2 is 1). After expanding 0:
        //   f(1) = g(1) + h(1) = 1 + 0.5 = 1.5
        //   f(2) = g(2) + h(2) = 1 + 0.1 = 1.1
        // so 2 is expanded first and records itself as the predecessor of 3. Expanding 1
        // afterwards finds no cheaper route to 3, so that predecessor is kept.
        let vids = vec![Vid::from(0), Vid::from(1), Vid::from(2), Vid::from(3)];
        let edges = vec![
            (Vid::from(0), Vid::from(1)),
            (Vid::from(1), Vid::from(3)),
            (Vid::from(0), Vid::from(2)),
            (Vid::from(2), Vid::from(3)),
        ];
        let graph = build_test_graph(vids, edges);

        let mut heuristic = HashMap::new();
        heuristic.insert(Vid::from(1), 0.5);
        heuristic.insert(Vid::from(2), 0.1);
        heuristic.insert(Vid::from(3), 0.0);

        let config = AStarConfig {
            source: Vid::from(0),
            target: Vid::from(3),
            heuristic,
        };

        let result = AStar::run(&graph, config);

        assert_eq!(result.distance, Some(2.0));
        assert_eq!(
            result.path,
            Some(vec![Vid::from(0), Vid::from(2), Vid::from(3)])
        );
        // 0, 2 and 1 are expanded (f = 0.0, 1.1, 1.5) before the target (f = 2.0) is popped.
        assert_eq!(result.visited_count, 3);
    }

    #[test]
    fn test_astar_source_equals_target() {
        let graph = build_test_graph(
            vec![Vid::from(0), Vid::from(1)],
            vec![(Vid::from(0), Vid::from(1))],
        );
        let config = AStarConfig {
            source: Vid::from(0),
            target: Vid::from(0),
            heuristic: HashMap::new(),
        };

        let result = AStar::run(&graph, config);

        assert_eq!(result.distance, Some(0.0));
        assert_eq!(result.path, Some(vec![Vid::from(0)]));
        // The source is popped as the target before any expansion happens.
        assert_eq!(result.visited_count, 0);
    }

    #[test]
    fn test_astar_unreachable_target() {
        // Vertex 2 is in the projection but has no edges at all.
        let graph = build_test_graph(
            vec![Vid::from(0), Vid::from(1), Vid::from(2)],
            vec![(Vid::from(0), Vid::from(1))],
        );
        let config = AStarConfig {
            source: Vid::from(0),
            target: Vid::from(2),
            heuristic: HashMap::new(),
        };

        let result = AStar::run(&graph, config);

        assert_eq!(result.distance, None);
        assert!(result.path.is_none());
        // Both vertices reachable from the source are expanded before the search gives up.
        assert_eq!(result.visited_count, 2);
    }

    #[test]
    fn test_astar_unknown_endpoint() {
        let graph = build_test_graph(
            vec![Vid::from(0), Vid::from(1)],
            vec![(Vid::from(0), Vid::from(1))],
        );

        for (source, target) in [
            (Vid::from(99), Vid::from(1)),
            (Vid::from(0), Vid::from(99)),
        ] {
            let config = AStarConfig {
                source,
                target,
                heuristic: HashMap::new(),
            };

            let result = AStar::run(&graph, config);

            assert_eq!(result.distance, None);
            assert!(result.path.is_none());
            assert_eq!(result.visited_count, 0);
        }
    }
}