uni_algo/algo/algorithms/
astar.rs1use crate::algo::GraphProjection;
11use crate::algo::algorithms::Algorithm;
12use std::cmp::Reverse;
13use std::collections::{BinaryHeap, HashMap};
14use uni_common::core::id::Vid;
15
/// A* shortest-path search.
///
/// This is a stateless marker type: the search itself is exposed through its
/// `Algorithm` implementation, which takes an [`AStarConfig`] and produces an
/// [`AStarResult`].
#[derive(Debug, Clone, Copy, Default)]
pub struct AStar;
17
18#[derive(Debug, Clone)]
19pub struct AStarConfig {
20 pub source: Vid,
21 pub target: Vid,
22 pub heuristic: HashMap<Vid, f64>,
25}
26
27impl Default for AStarConfig {
28 fn default() -> Self {
29 Self {
30 source: Vid::from(0),
31 target: Vid::from(0),
32 heuristic: HashMap::new(),
33 }
34 }
35}
36
37pub struct AStarResult {
38 pub distance: Option<f64>,
39 pub path: Option<Vec<Vid>>,
40 pub visited_count: usize,
41}
42
43impl Algorithm for AStar {
44 type Config = AStarConfig;
45 type Result = AStarResult;
46
47 fn name() -> &'static str {
48 "astar"
49 }
50
51 fn run(graph: &GraphProjection, config: Self::Config) -> Self::Result {
52 let source_slot = match graph.to_slot(config.source) {
53 Some(slot) => slot,
54 None => {
55 return AStarResult {
56 distance: None,
57 path: None,
58 visited_count: 0,
59 };
60 }
61 };
62
63 let target_slot = match graph.to_slot(config.target) {
64 Some(slot) => slot,
65 None => {
66 return AStarResult {
67 distance: None,
68 path: None,
69 visited_count: 0,
70 };
71 }
72 };
73
74 let n = graph.vertex_count();
75 let mut g_score = vec![f64::INFINITY; n];
77 let mut prev: Vec<Option<u32>> = vec![None; n];
78
79 let mut heap = BinaryHeap::new();
83
84 g_score[source_slot as usize] = 0.0;
85 let h_source = config.heuristic.get(&config.source).copied().unwrap_or(0.0);
86 let f_source = 0.0 + h_source;
87
88 heap.push(Reverse((f_source.to_bits(), source_slot)));
89
90 let mut visited_count = 0;
91
92 while let Some(Reverse((f_bits, u))) = heap.pop() {
93 let f_current = f64::from_bits(f_bits);
94
95 if u == target_slot {
98 let dist = g_score[u as usize];
99
100 let mut path = Vec::new();
102 let mut curr = Some(u);
103 while let Some(slot) = curr {
104 path.push(graph.to_vid(slot));
105 if slot == source_slot {
106 break;
107 }
108 curr = prev[slot as usize];
109 }
110 path.reverse();
111
112 return AStarResult {
113 distance: Some(dist),
114 path: Some(path),
115 visited_count,
116 };
117 }
118
119 let h_u = config
124 .heuristic
125 .get(&graph.to_vid(u))
126 .copied()
127 .unwrap_or(0.0);
128 if f_current > g_score[u as usize] + h_u {
129 continue;
130 }
131
132 visited_count += 1;
133
134 for (i, &v) in graph.out_neighbors(u).iter().enumerate() {
135 let weight = if graph.has_weights() {
136 graph.out_weight(u, i)
137 } else {
138 1.0
139 };
140
141 let tentative_g = g_score[u as usize] + weight;
142
143 if tentative_g < g_score[v as usize] {
144 prev[v as usize] = Some(u);
145 g_score[v as usize] = tentative_g;
146
147 let h_v = config
148 .heuristic
149 .get(&graph.to_vid(v))
150 .copied()
151 .unwrap_or(0.0);
152 let f_v = tentative_g + h_v;
153
154 heap.push(Reverse((f_v.to_bits(), v)));
155 }
156 }
157 }
158
159 AStarResult {
161 distance: None,
162 path: None,
163 visited_count,
164 }
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algo::test_utils::build_test_graph;

    /// Chain 0 -> 1 -> 2 with a heuristic entry for every vertex: the search
    /// must report a distance of 2 and the full chain as the path.
    #[test]
    fn test_astar_simple() {
        let graph = build_test_graph(
            vec![Vid::from(0), Vid::from(1), Vid::from(2)],
            vec![(Vid::from(0), Vid::from(1)), (Vid::from(1), Vid::from(2))],
        );

        let heuristic = vec![
            (Vid::from(0), 2.0),
            (Vid::from(1), 1.0),
            (Vid::from(2), 0.0),
        ]
        .into_iter()
        .collect::<HashMap<_, _>>();

        let AStarResult { distance, path, .. } = AStar::run(
            &graph,
            AStarConfig {
                source: Vid::from(0),
                target: Vid::from(2),
                heuristic,
            },
        );

        assert_eq!(distance, Some(2.0));
        assert_eq!(
            path,
            Some(vec![Vid::from(0), Vid::from(1), Vid::from(2)])
        );
    }

    /// Diamond 0 -> {1, 2} -> 3 with different heuristic values on the two
    /// branches: either branch costs 2, so the reported distance must be 2.
    #[test]
    fn test_astar_heuristic_guides() {
        let graph = build_test_graph(
            vec![Vid::from(0), Vid::from(1), Vid::from(2), Vid::from(3)],
            vec![
                (Vid::from(0), Vid::from(1)),
                (Vid::from(1), Vid::from(3)),
                (Vid::from(0), Vid::from(2)),
                (Vid::from(2), Vid::from(3)),
            ],
        );

        let heuristic = vec![
            (Vid::from(1), 0.5),
            (Vid::from(2), 0.1),
            (Vid::from(3), 0.0),
        ]
        .into_iter()
        .collect::<HashMap<_, _>>();

        let outcome = AStar::run(
            &graph,
            AStarConfig {
                source: Vid::from(0),
                target: Vid::from(3),
                heuristic,
            },
        );

        assert_eq!(outcome.distance, Some(2.0));
    }
}