// rustsim_pathfinding/astar.rs

//! A* pathfinding -- generic and grid-specific variants.
//!
//! Provides:
//! - [`astar`] -- fully generic A* over any graph/node type
//! - [`astar_grid2d`] -- simple convenience wrapper for 2D grids
//! - [`astar_grid2d_opts`] -- full-featured grid A* with cost metrics,
//!   periodic boundaries, admissibility, and walkability maps

use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};

use crate::metrics::CostMetric;

/// Result of a successful A* search.
///
/// Returned by [`astar`] and by the grid wrappers that delegate to it.
#[derive(Debug, Clone, PartialEq)]
pub struct AStarResult<N> {
    /// The sequence of nodes from start to goal (both endpoints included).
    pub path: Vec<N>,
    /// Total cost of the path, i.e. the accumulated edge costs along `path`.
    pub cost: f64,
}

/// Entry of the A* open set (priority queue).
///
/// Entries are ordered by their `f` value only, reversed so that
/// [`BinaryHeap`](std::collections::BinaryHeap) (a max-heap) pops the entry
/// with the *lowest* `f` first. Equality is defined through the same
/// ordering, so `a == b` holds exactly when `a.cmp(&b)` is `Equal`, as the
/// `Ord` contract requires; `node` and `g` take no part in comparisons.
#[derive(Clone)]
struct OpenEntry<N> {
    /// The graph node this entry refers to.
    node: N,
    /// Priority: `g` plus the heuristic estimate to the goal.
    f: f64,
    /// Cost of the best known route from the start to `node` when queued.
    g: f64,
}

impl<N> PartialEq for OpenEntry<N> {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl<N> Eq for OpenEntry<N> {}

impl<N> PartialOrd for OpenEntry<N> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<N> Ord for OpenEntry<N> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands are swapped on purpose: a smaller `f` is a higher priority.
        match other.f.partial_cmp(&self.f) {
            Some(ordering) => ordering,
            // At least one side is NaN. Keep the order total by giving NaN the
            // lowest priority (popped last); two NaNs compare equal.
            None => other.f.is_nan().cmp(&self.f.is_nan()),
        }
    }
}

50/// Generic A* pathfinding.
51///
52/// - `start`: starting node
53/// - `goal`: target node
54/// - `heuristic`: estimates cost from a node to the goal (must be admissible)
55/// - `neighbors`: returns neighbors and the edge cost to reach each
56///
57/// Returns `Some(AStarResult)` with the shortest path and total cost,
58/// or `None` if no path exists.
59pub fn astar<N, FH, FN, I>(
60    start: N,
61    goal: N,
62    mut heuristic: FH,
63    mut neighbors: FN,
64) -> Option<AStarResult<N>>
65where
66    N: Clone + Eq + std::hash::Hash,
67    FH: FnMut(&N, &N) -> f64,
68    FN: FnMut(&N) -> I,
69    I: IntoIterator<Item = (N, f64)>,
70{
71    let mut open = BinaryHeap::new();
72    let mut g_scores: HashMap<N, f64> = HashMap::new();
73    let mut came_from: HashMap<N, N> = HashMap::new();
74    let mut closed: HashSet<N> = HashSet::new();
75
76    let h = heuristic(&start, &goal);
77    g_scores.insert(start.clone(), 0.0);
78    open.push(OpenEntry {
79        node: start.clone(),
80        f: h,
81        g: 0.0,
82    });
83
84    while let Some(current) = open.pop() {
85        if current.node == goal {
86            let mut path = Vec::new();
87            let mut cur = goal.clone();
88            loop {
89                path.push(cur.clone());
90                match came_from.get(&cur) {
91                    Some(prev) => cur = prev.clone(),
92                    None => break,
93                }
94            }
95            path.reverse();
96            return Some(AStarResult {
97                path,
98                cost: current.g,
99            });
100        }
101
102        if !closed.insert(current.node.clone()) {
103            continue;
104        }
105
106        for (neighbor, edge_cost) in neighbors(&current.node) {
107            if closed.contains(&neighbor) {
108                continue;
109            }
110
111            let tentative_g = current.g + edge_cost;
112            let prev_g = g_scores.get(&neighbor).copied().unwrap_or(f64::INFINITY);
113
114            if tentative_g < prev_g {
115                g_scores.insert(neighbor.clone(), tentative_g);
116                came_from.insert(neighbor.clone(), current.node.clone());
117                let h = heuristic(&neighbor, &goal);
118                open.push(OpenEntry {
119                    node: neighbor,
120                    f: tentative_g + h,
121                    g: tentative_g,
122                });
123            }
124        }
125    }
126
127    None
128}
129
130/// Convenience: A* on a 2D grid with a walkability map.
131///
132/// Coordinates are `(x, y)` with `0 <= x < width`, `0 <= y < height`.
133/// `walkable` returns `true` if the cell can be entered.
134/// Uses Chebyshev distance as the heuristic (diagonal movement allowed).
135pub fn astar_grid2d(
136    start: (usize, usize),
137    goal: (usize, usize),
138    width: usize,
139    height: usize,
140    walkable: &dyn Fn(usize, usize) -> bool,
141    diagonal: bool,
142) -> Option<AStarResult<(usize, usize)>> {
143    let heuristic = |a: &(usize, usize), b: &(usize, usize)| -> f64 {
144        let dx = (a.0 as f64 - b.0 as f64).abs();
145        let dy = (a.1 as f64 - b.1 as f64).abs();
146        if diagonal {
147            dx.max(dy) // Chebyshev
148        } else {
149            dx + dy // Manhattan
150        }
151    };
152
153    let neighbors = |node: &(usize, usize)| -> Vec<((usize, usize), f64)> {
154        let (x, y) = *node;
155        let mut result = Vec::new();
156
157        let deltas: &[(i32, i32)] = if diagonal {
158            &[
159                (-1, -1),
160                (-1, 0),
161                (-1, 1),
162                (0, -1),
163                (0, 1),
164                (1, -1),
165                (1, 0),
166                (1, 1),
167            ]
168        } else {
169            &[(-1, 0), (1, 0), (0, -1), (0, 1)]
170        };
171
172        for &(dx, dy) in deltas {
173            let nx = x as i32 + dx;
174            let ny = y as i32 + dy;
175            if nx >= 0 && ny >= 0 && (nx as usize) < width && (ny as usize) < height {
176                let nx = nx as usize;
177                let ny = ny as usize;
178                if walkable(nx, ny) {
179                    let cost = if dx != 0 && dy != 0 {
180                        std::f64::consts::SQRT_2
181                    } else {
182                        1.0
183                    };
184                    result.push(((nx, ny), cost));
185                }
186            }
187        }
188
189        result
190    };
191
192    astar(start, goal, heuristic, neighbors)
193}
194
195/// Options for the full-featured grid A* pathfinder.
196///
197/// Mirrors the keyword arguments of Agents.jl `Pathfinding.AStar`.
198pub struct GridAStarOpts<'a> {
199    /// Grid width.
200    pub width: usize,
201    /// Grid height.
202    pub height: usize,
203    /// Allow diagonal movement (8-connected vs 4-connected).
204    pub diagonal: bool,
205    /// Periodic (toroidal) boundary wrapping.
206    pub periodic: bool,
207    /// Admissibility factor. `0.0` gives optimal paths. Higher values
208    /// allow suboptimal paths for faster search: the algorithm uses
209    /// `f = g + (1 + admissibility) * h`.
210    ///
211    /// Mirrors Agents.jl `admissibility` keyword.
212    pub admissibility: f64,
213    /// Walkability predicate: `walkable(x, y)` returns `true` if the
214    /// cell is passable. Defaults to all-walkable if `None`.
215    pub walkable: Option<&'a dyn Fn(usize, usize) -> bool>,
216    /// Cost metric for heuristic and edge cost estimation.
217    /// Defaults to [`DirectDistance`](crate::metrics::DirectDistance) if `None`.
218    pub cost_metric: Option<&'a dyn CostMetric>,
219}
220
221impl<'a> GridAStarOpts<'a> {
222    /// Create default options for a grid of the given dimensions.
223    pub fn new(width: usize, height: usize) -> Self {
224        Self {
225            width,
226            height,
227            diagonal: true,
228            periodic: false,
229            admissibility: 0.0,
230            walkable: None,
231            cost_metric: None,
232        }
233    }
234}
235
236impl<'a> std::fmt::Debug for GridAStarOpts<'a> {
237    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238        f.debug_struct("GridAStarOpts")
239            .field("width", &self.width)
240            .field("height", &self.height)
241            .field("diagonal", &self.diagonal)
242            .field("periodic", &self.periodic)
243            .field("admissibility", &self.admissibility)
244            .field("walkable", &self.walkable.is_some())
245            .field("cost_metric", &self.cost_metric.is_some())
246            .finish()
247    }
248}
249
250/// Full-featured A* on a 2D grid with cost metrics, periodic boundaries,
251/// admissibility, and walkability maps.
252///
253/// This is the advanced variant of [`astar_grid2d`]. Use [`GridAStarOpts`]
254/// to configure all parameters.
255///
256/// # Admissibility
257///
258/// Setting `admissibility > 0.0` trades optimality for speed. The algorithm
259/// inflates the heuristic by `(1 + admissibility)`, which causes it to
260/// explore fewer nodes at the cost of finding paths that may be up to
261/// `(1 + admissibility)` times the optimal length.
262///
263/// # Example
264///
265/// ```
266/// use rustsim_pathfinding::astar::{astar_grid2d_opts, GridAStarOpts};
267/// use rustsim_pathfinding::metrics::PenaltyMap;
268/// use rustsim_pathfinding::metrics::DirectDistance;
269///
270/// let mut opts = GridAStarOpts::new(50, 50);
271/// opts.diagonal = true;
272/// opts.periodic = true;
273/// opts.admissibility = 0.5;
274/// opts.walkable = Some(&|x, y| !(x == 25 && y != 25));
275///
276/// let penalties = vec![0i32; 2500];
277/// let metric = PenaltyMap::new(penalties, 50, 50, DirectDistance::new()).unwrap();
278/// opts.cost_metric = Some(&metric);
279///
280/// let result = astar_grid2d_opts((0, 0), (49, 49), &opts);
281/// assert!(result.is_some());
282/// ```
283pub fn astar_grid2d_opts(
284    start: (usize, usize),
285    goal: (usize, usize),
286    opts: &GridAStarOpts<'_>,
287) -> Option<AStarResult<(usize, usize)>> {
288    let width = opts.width;
289    let height = opts.height;
290    let diagonal = opts.diagonal;
291    let periodic = opts.periodic;
292    let admissibility = opts.admissibility;
293
294    let default_metric = crate::metrics::DirectDistance::new();
295    let metric: &dyn CostMetric = opts.cost_metric.unwrap_or(&default_metric);
296
297    let always_walkable = |_: usize, _: usize| true;
298    let walkable: &dyn Fn(usize, usize) -> bool = match &opts.walkable {
299        Some(f) => *f,
300        None => &always_walkable,
301    };
302
303    // Verify start and goal are walkable
304    let start_n = normalize_grid_pos(start, periodic, width, height)?;
305    let goal_n = normalize_grid_pos(goal, periodic, width, height)?;
306    if !walkable(start_n.0, start_n.1) || !walkable(goal_n.0, goal_n.1) {
307        return None;
308    }
309
310    let heuristic = |a: &(usize, usize), b: &(usize, usize)| -> f64 {
311        (1.0 + admissibility) * metric.delta_cost(*a, *b, periodic, width, height, diagonal)
312    };
313
314    let neighbors = |node: &(usize, usize)| -> Vec<((usize, usize), f64)> {
315        let (x, y) = *node;
316        let mut result = Vec::new();
317
318        let deltas: &[(i32, i32)] = if diagonal {
319            &[
320                (-1, -1),
321                (-1, 0),
322                (-1, 1),
323                (0, -1),
324                (0, 1),
325                (1, -1),
326                (1, 0),
327                (1, 1),
328            ]
329        } else {
330            &[(-1, 0), (1, 0), (0, -1), (0, 1)]
331        };
332
333        for &(dx, dy) in deltas {
334            let nx = x as i32 + dx;
335            let ny = y as i32 + dy;
336
337            let neighbor = if periodic {
338                let px = ((nx % width as i32) + width as i32) % width as i32;
339                let py = ((ny % height as i32) + height as i32) % height as i32;
340                Some((px as usize, py as usize))
341            } else if nx >= 0 && ny >= 0 && (nx as usize) < width && (ny as usize) < height {
342                Some((nx as usize, ny as usize))
343            } else {
344                None
345            };
346
347            if let Some(n) = neighbor {
348                if walkable(n.0, n.1) {
349                    let cost = metric.delta_cost(*node, n, periodic, width, height, diagonal);
350                    result.push((n, cost));
351                }
352            }
353        }
354
355        result
356    };
357
358    astar(start_n, goal_n, heuristic, neighbors)
359}
360
/// Maps `pos` onto a `width` x `height` grid.
///
/// - On a periodic grid each coordinate is wrapped with `%`, so any position
///   maps to a cell.
/// - On a non-periodic grid `pos` is returned unchanged when it lies inside
///   the grid and `None` otherwise.
///
/// An empty grid (`width == 0` or `height == 0`) has no cells, so `None` is
/// returned regardless of `periodic`.
fn normalize_grid_pos(
    pos: (usize, usize),
    periodic: bool,
    width: usize,
    height: usize,
) -> Option<(usize, usize)> {
    // Guard first: `%` by zero would panic in the periodic branch.
    if width == 0 || height == 0 {
        return None;
    }

    if periodic {
        Some((pos.0 % width, pos.1 % height))
    } else if pos.0 < width && pos.1 < height {
        Some(pos)
    } else {
        None
    }
}