1use crate::bridge::OccupancyGrid;
11
12use serde::{Deserialize, Serialize};
13use std::cmp::Ordering;
14use std::collections::{BinaryHeap, HashMap, HashSet};
15
/// A grid cell address, destructured as `(x, y)` and passed to `OccupancyGrid::get(x, y)`.
pub type Cell = (usize, usize);
18
/// A path through the occupancy grid, as returned by [`astar`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GridPath {
    /// Cells in travel order; for paths from [`astar`] this runs from the start cell to the
    /// goal cell, both included.
    pub cells: Vec<Cell>,
    /// Accumulated traversal cost; [`astar`] charges `1.0` per orthogonal step and `SQRT_2`
    /// per diagonal step.
    pub cost: f64,
}
27
/// Reasons [`astar`] can fail to produce a path.
#[derive(Debug, thiserror::Error)]
pub enum PlanningError {
    /// The start cell `(x, y)` is not free: the grid has no value for it, or the value is not
    /// below `OCCUPIED_THRESHOLD`.
    #[error("start cell ({0}, {1}) is out of bounds or occupied")]
    InvalidStart(usize, usize),
    /// The goal cell `(x, y)` is not free, under the same rule as `InvalidStart`.
    #[error("goal cell ({0}, {1}) is out of bounds or occupied")]
    InvalidGoal(usize, usize),
    /// The search expanded every reachable cell without reaching the goal.
    #[error("no feasible path found")]
    NoPath,
}
38
/// Result type for the planners in this module, with [`PlanningError`] as the error.
pub type Result<T> = std::result::Result<T, PlanningError>;
40
/// Grid values strictly below this are treated as free. Values at or above it, and NaN (for
/// which `<` is false), are treated as occupied.
/// NOTE(review): presumably grid values are occupancy probabilities in `[0, 1]` — confirm.
const OCCUPIED_THRESHOLD: f32 = 0.5;
47
48#[derive(PartialEq)]
49struct AStarEntry {
50 cell: Cell,
51 f: f64,
52}
53impl Eq for AStarEntry {}
54impl Ord for AStarEntry {
55 fn cmp(&self, other: &Self) -> Ordering {
56 other.f.partial_cmp(&self.f).unwrap_or(Ordering::Equal)
57 }
58}
59impl PartialOrd for AStarEntry {
60 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
61 Some(self.cmp(other))
62 }
63}
64
65pub fn astar(
70 grid: &OccupancyGrid,
71 start: Cell,
72 goal: Cell,
73) -> Result<GridPath> {
74 if !cell_free(grid, start) {
75 return Err(PlanningError::InvalidStart(start.0, start.1));
76 }
77 if !cell_free(grid, goal) {
78 return Err(PlanningError::InvalidGoal(goal.0, goal.1));
79 }
80 if start == goal {
81 return Ok(GridPath { cells: vec![start], cost: 0.0 });
82 }
83
84 let mut g_score: HashMap<Cell, f64> = HashMap::with_capacity(128);
85 let mut came_from: HashMap<Cell, Cell> = HashMap::with_capacity(128);
86 let mut open = BinaryHeap::new();
87 let mut closed: HashSet<Cell> = HashSet::with_capacity(128);
88 let mut neighbor_buf: Vec<(usize, usize, f64)> = Vec::with_capacity(8);
89
90 g_score.insert(start, 0.0);
91 open.push(AStarEntry { cell: start, f: heuristic(start, goal) });
92
93 while let Some(AStarEntry { cell, .. }) = open.pop() {
94 if cell == goal {
95 return Ok(reconstruct_path(&came_from, goal, &g_score));
96 }
97
98 if !closed.insert(cell) {
100 continue;
101 }
102
103 let current_g = g_score[&cell];
104
105 neighbors_into(grid, cell, &mut neighbor_buf);
106 for &(nx, ny, step_cost) in &neighbor_buf {
107 let neighbor = (nx, ny);
108 if closed.contains(&neighbor) {
109 continue;
110 }
111 let tentative_g = current_g + step_cost;
112 if tentative_g < *g_score.get(&neighbor).unwrap_or(&f64::INFINITY) {
113 g_score.insert(neighbor, tentative_g);
114 came_from.insert(neighbor, cell);
115 open.push(AStarEntry {
116 cell: neighbor,
117 f: tentative_g + heuristic(neighbor, goal),
118 });
119 }
120 }
121 }
122
123 Err(PlanningError::NoPath)
124}
125
126#[inline]
127fn cell_free(grid: &OccupancyGrid, (x, y): Cell) -> bool {
128 grid.get(x, y).is_some_and(|v| v < OCCUPIED_THRESHOLD)
129}
130
131#[inline]
132fn heuristic(a: Cell, b: Cell) -> f64 {
133 let dx = (a.0 as f64 - b.0 as f64).abs();
134 let dy = (a.1 as f64 - b.1 as f64).abs();
135 let (min, max) = if dx < dy { (dx, dy) } else { (dy, dx) };
137 min * std::f64::consts::SQRT_2 + (max - min)
138}
139
140#[inline]
143fn neighbors_into(grid: &OccupancyGrid, (cx, cy): Cell, out: &mut Vec<(usize, usize, f64)>) {
144 out.clear();
145 for dx in [-1_i64, 0, 1] {
146 for dy in [-1_i64, 0, 1] {
147 if dx == 0 && dy == 0 {
148 continue;
149 }
150 let nx = cx as i64 + dx;
151 let ny = cy as i64 + dy;
152 if nx < 0 || ny < 0 {
153 continue;
154 }
155 let (nx, ny) = (nx as usize, ny as usize);
156 if cell_free(grid, (nx, ny)) {
157 let cost = if dx != 0 && dy != 0 {
158 std::f64::consts::SQRT_2
159 } else {
160 1.0
161 };
162 out.push((nx, ny, cost));
163 }
164 }
165 }
166}
167
168fn reconstruct_path(
169 came_from: &HashMap<Cell, Cell>,
170 goal: Cell,
171 g_score: &HashMap<Cell, f64>,
172) -> GridPath {
173 let mut cells = vec![goal];
174 let mut current = goal;
175 while let Some(&prev) = came_from.get(¤t) {
176 cells.push(prev);
177 current = prev;
178 }
179 cells.reverse();
180 let cost = g_score.get(&goal).copied().unwrap_or(0.0);
181 GridPath { cells, cost }
182}
183
/// Velocity command returned by [`potential_field`]; its magnitude is capped at the configured
/// `max_speed`.
/// NOTE(review): units follow whatever units the positions and gains use — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct VelocityCommand {
    /// X component.
    pub vx: f64,
    /// Y component.
    pub vy: f64,
    /// Z component.
    pub vz: f64,
}
195
/// Tuning parameters for [`potential_field`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PotentialFieldConfig {
    /// Gain applied to the attractive term `goal - robot`. Default `1.0`.
    pub attractive_gain: f64,
    /// Gain applied to the repulsive term of each nearby obstacle. Default `100.0`.
    pub repulsive_gain: f64,
    /// Distance below which an obstacle produces repulsion; obstacles at or beyond it are
    /// ignored. Default `3.0` (same length units as the positions — TODO confirm).
    pub obstacle_influence: f64,
    /// Upper bound on the magnitude of the returned velocity command. Default `2.0`.
    pub max_speed: f64,
}

impl Default for PotentialFieldConfig {
    /// Returns the default gains documented on each field.
    fn default() -> Self {
        Self {
            attractive_gain: 1.0,
            repulsive_gain: 100.0,
            obstacle_influence: 3.0,
            max_speed: 2.0,
        }
    }
}
219
220pub fn potential_field(
226 robot: &[f64; 3],
227 goal: &[f64; 3],
228 obstacles: &[[f64; 3]],
229 config: &PotentialFieldConfig,
230) -> VelocityCommand {
231 let mut fx = config.attractive_gain * (goal[0] - robot[0]);
233 let mut fy = config.attractive_gain * (goal[1] - robot[1]);
234 let mut fz = config.attractive_gain * (goal[2] - robot[2]);
235
236 for obs in obstacles {
238 let dx = robot[0] - obs[0];
239 let dy = robot[1] - obs[1];
240 let dz = robot[2] - obs[2];
241 let dist = (dx * dx + dy * dy + dz * dz).sqrt().max(0.01);
242
243 if dist < config.obstacle_influence {
244 let strength =
245 config.repulsive_gain * (1.0 / dist - 1.0 / config.obstacle_influence) / (dist * dist);
246 fx += strength * dx / dist;
247 fy += strength * dy / dist;
248 fz += strength * dz / dist;
249 }
250 }
251
252 let speed = (fx * fx + fy * fy + fz * fz).sqrt();
254 if speed > config.max_speed {
255 let s = config.max_speed / speed;
256 fx *= s;
257 fy *= s;
258 fz *= s;
259 }
260
261 VelocityCommand { vx: fx, vy: fy, vz: fz }
262}
263
264pub fn path_to_waypoints(path: &GridPath, resolution: f64, origin: &[f64; 3]) -> Vec<[f64; 3]> {
271 path.cells
272 .iter()
273 .map(|&(x, y)| {
274 [
275 origin[0] + x as f64 * resolution,
276 origin[1] + y as f64 * resolution,
277 origin[2],
278 ]
279 })
280 .collect()
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::SQRT_2;

    /// Builds a `w` x `h` grid for the planner tests.
    ///
    /// NOTE(review): the A* tests assume `OccupancyGrid::new` leaves every cell free. The
    /// meaning of the third argument (`1.0`) is defined in `crate::bridge` — confirm.
    fn free_grid(w: usize, h: usize) -> OccupancyGrid {
        OccupancyGrid::new(w, h, 1.0)
    }

    /// `true` when `a` and `b` differ by less than `tol`.
    fn close(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    /// Potential-field command from the origin toward `(5, 0, 0)` with default tuning.
    fn command_toward_x5(obstacles: &[[f64; 3]]) -> VelocityCommand {
        potential_field(
            &[0.0, 0.0, 0.0],
            &[5.0, 0.0, 0.0],
            obstacles,
            &PotentialFieldConfig::default(),
        )
    }

    #[test]
    fn test_astar_straight_line() {
        let path = astar(&free_grid(10, 10), (0, 0), (5, 0)).unwrap();
        assert_eq!(path.cells.first(), Some(&(0, 0)));
        assert_eq!(path.cells.last(), Some(&(5, 0)));
        assert!(close(path.cost, 5.0, 1e-6));
    }

    #[test]
    fn test_astar_diagonal() {
        let path = astar(&free_grid(10, 10), (0, 0), (3, 3)).unwrap();
        assert_eq!(path.cells.last(), Some(&(3, 3)));
        assert!(close(path.cost, 3.0 * SQRT_2, 1e-6));
    }

    #[test]
    fn test_astar_same_cell() {
        let path = astar(&free_grid(5, 5), (2, 2), (2, 2)).unwrap();
        assert_eq!(path.cells.len(), 1);
        assert!(close(path.cost, 0.0, 1e-9));
    }

    #[test]
    fn test_astar_around_wall() {
        // Cells (3, y) for y in 0..5 are set to 1.0, so the path has to detour around them.
        let mut grid = free_grid(10, 10);
        for row in 0..5 {
            grid.set(3, row, 1.0);
        }
        let path = astar(&grid, (1, 2), (5, 2)).unwrap();
        assert_eq!(path.cells.last(), Some(&(5, 2)));
        // The detour must cost more than the unobstructed straight line (4.0).
        assert!(path.cost > 4.0);
    }

    #[test]
    fn test_astar_blocked() {
        // Cells (2, y) for every y in 0..5 are set to 1.0, cutting start off from goal.
        let mut grid = free_grid(5, 5);
        for row in 0..5 {
            grid.set(2, row, 1.0);
        }
        assert!(astar(&grid, (0, 2), (4, 2)).is_err());
    }

    #[test]
    fn test_astar_invalid_start() {
        // (10, 10) lies outside a 5x5 grid.
        assert!(astar(&free_grid(5, 5), (10, 10), (2, 2)).is_err());
    }

    #[test]
    fn test_potential_field_towards_goal() {
        let cmd = command_toward_x5(&[]);
        assert!(cmd.vx > 0.0);
        assert!(close(cmd.vy, 0.0, 1e-9));
    }

    #[test]
    fn test_potential_field_obstacle_repulsion() {
        // An obstacle between robot and goal must reduce the forward velocity.
        let with_obstacle = command_toward_x5(&[[1.0, 0.0, 0.0]]);
        let without_obstacle = command_toward_x5(&[]);
        assert!(with_obstacle.vx < without_obstacle.vx);
    }

    #[test]
    fn test_potential_field_max_speed() {
        let config = PotentialFieldConfig { max_speed: 1.0, ..Default::default() };
        let cmd = potential_field(&[0.0, 0.0, 0.0], &[100.0, 100.0, 0.0], &[], &config);
        let speed = (cmd.vx * cmd.vx + cmd.vy * cmd.vy + cmd.vz * cmd.vz).sqrt();
        assert!(close(speed, 1.0, 1e-9));
    }

    #[test]
    fn test_path_to_waypoints() {
        let path = GridPath { cells: vec![(0, 0), (1, 0), (2, 0)], cost: 2.0 };
        let waypoints = path_to_waypoints(&path, 0.5, &[0.0, 0.0, 0.0]);
        assert_eq!(waypoints.len(), 3);
        assert!(close(waypoints[1][0], 0.5, 1e-9));
        assert!(close(waypoints[2][0], 1.0, 1e-9));
    }
}
402}