Skip to main content

waremax_map/
routing.rs

1//! Routing algorithms for finding paths in the warehouse map
2
3use crate::graph::WarehouseMap;
4use crate::traffic::TrafficManager;
5use std::cmp::Ordering;
6use std::collections::{BinaryHeap, HashMap, HashSet};
7use waremax_core::{EdgeId, NodeId};
8
/// v1: Routing algorithm selection.
///
/// Selects which shortest-path search a `Router` runs. The default is
/// [`RoutingAlgorithm::Dijkstra`].
///
/// The enum carries no data, so it is `Copy` and can be compared, hashed and
/// passed around by value.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum RoutingAlgorithm {
    /// Dijkstra's algorithm (no heuristic).
    #[default]
    Dijkstra,
    /// A* search.
    AStar,
}
16
17/// A computed route through the warehouse
18#[derive(Clone, Debug)]
19pub struct Route {
20    pub path: Vec<NodeId>,
21    pub total_distance: f64,
22}
23
24impl Route {
25    pub fn empty(start: NodeId) -> Self {
26        Self {
27            path: vec![start],
28            total_distance: 0.0,
29        }
30    }
31
32    pub fn is_empty(&self) -> bool {
33        self.path.len() <= 1
34    }
35
36    pub fn len(&self) -> usize {
37        self.path.len()
38    }
39}
40
41/// Cache for computed routes
42pub struct RouteCache {
43    cache: HashMap<(NodeId, NodeId), Route>,
44    max_size: usize,
45}
46
47impl RouteCache {
48    pub fn new(max_size: usize) -> Self {
49        Self {
50            cache: HashMap::new(),
51            max_size,
52        }
53    }
54
55    pub fn get(&self, from: NodeId, to: NodeId) -> Option<&Route> {
56        self.cache.get(&(from, to))
57    }
58
59    pub fn insert(&mut self, from: NodeId, to: NodeId, route: Route) {
60        if self.cache.len() >= self.max_size {
61            let keys: Vec<_> = self.cache.keys().take(self.max_size / 2).copied().collect();
62            for key in keys {
63                self.cache.remove(&key);
64            }
65        }
66        self.cache.insert((from, to), route);
67    }
68
69    pub fn invalidate(&mut self) {
70        self.cache.clear();
71    }
72}
73
74/// Router for finding paths in the warehouse
75pub struct Router {
76    cache: RouteCache,
77    cache_enabled: bool,
78    /// v1: Routing algorithm
79    algorithm: RoutingAlgorithm,
80    /// v1: Congestion weight for congestion-aware routing (0.0 = no penalty)
81    congestion_weight: f64,
82}
83
84impl Router {
85    pub fn new(cache_enabled: bool) -> Self {
86        Self {
87            cache: RouteCache::new(10000),
88            cache_enabled,
89            algorithm: RoutingAlgorithm::default(),
90            congestion_weight: 0.0,
91        }
92    }
93
94    /// v1: Create router with specific algorithm
95    pub fn with_algorithm(cache_enabled: bool, algorithm: RoutingAlgorithm) -> Self {
96        Self {
97            cache: RouteCache::new(10000),
98            cache_enabled,
99            algorithm,
100            congestion_weight: 0.0,
101        }
102    }
103
104    /// v1: Set congestion weight for congestion-aware routing
105    pub fn set_congestion_weight(&mut self, weight: f64) {
106        self.congestion_weight = weight;
107    }
108
109    /// v1: Get current routing algorithm
110    pub fn algorithm(&self) -> &RoutingAlgorithm {
111        &self.algorithm
112    }
113
114    pub fn find_route(&mut self, map: &WarehouseMap, from: NodeId, to: NodeId) -> Option<Route> {
115        if from == to {
116            return Some(Route::empty(from));
117        }
118
119        if self.cache_enabled {
120            if let Some(route) = self.cache.get(from, to) {
121                return Some(route.clone());
122            }
123        }
124
125        let route = match self.algorithm {
126            RoutingAlgorithm::Dijkstra => self.dijkstra(map, from, to, None),
127            RoutingAlgorithm::AStar => self.astar(map, from, to, None),
128        }?;
129
130        if self.cache_enabled {
131            self.cache.insert(from, to, route.clone());
132        }
133
134        Some(route)
135    }
136
137    /// v1: Find route with congestion awareness
138    pub fn find_route_with_traffic(
139        &mut self,
140        map: &WarehouseMap,
141        from: NodeId,
142        to: NodeId,
143        traffic: &TrafficManager,
144    ) -> Option<Route> {
145        if from == to {
146            return Some(Route::empty(from));
147        }
148
149        // Don't use cache when congestion-aware (traffic state changes)
150        match self.algorithm {
151            RoutingAlgorithm::Dijkstra => self.dijkstra(map, from, to, Some(traffic)),
152            RoutingAlgorithm::AStar => self.astar(map, from, to, Some(traffic)),
153        }
154    }
155
156    /// v1: Find route avoiding specific edges
157    pub fn find_route_avoiding(
158        &mut self,
159        map: &WarehouseMap,
160        from: NodeId,
161        to: NodeId,
162        avoid_edges: &[EdgeId],
163        traffic: Option<&TrafficManager>,
164    ) -> Option<Route> {
165        if from == to {
166            return Some(Route::empty(from));
167        }
168
169        let avoid_set: HashSet<EdgeId> = avoid_edges.iter().copied().collect();
170        self.dijkstra_avoiding(map, from, to, &avoid_set, traffic)
171    }
172
173    /// Calculate edge cost with optional congestion penalty and speed multiplier
174    fn edge_cost(
175        &self,
176        map: &WarehouseMap,
177        length: f64,
178        edge_id: EdgeId,
179        traffic: Option<&TrafficManager>,
180    ) -> f64 {
181        // Apply speed multiplier (v2: fast lanes/express paths)
182        let speed_multiplier = map
183            .get_edge(edge_id)
184            .map(|e| e.speed_multiplier)
185            .unwrap_or(1.0);
186        let base_cost = length * speed_multiplier;
187
188        // Apply congestion penalty if enabled
189        if self.congestion_weight > 0.0 {
190            if let Some(tm) = traffic {
191                let occupancy = tm.get_edge_occupancy(edge_id);
192                return base_cost * (1.0 + self.congestion_weight * occupancy as f64);
193            }
194        }
195        base_cost
196    }
197
198    fn dijkstra(
199        &self,
200        map: &WarehouseMap,
201        from: NodeId,
202        to: NodeId,
203        traffic: Option<&TrafficManager>,
204    ) -> Option<Route> {
205        #[derive(Clone, PartialEq)]
206        struct State {
207            cost: f64,
208            node: NodeId,
209        }
210
211        impl Eq for State {}
212
213        impl Ord for State {
214            fn cmp(&self, other: &Self) -> Ordering {
215                other
216                    .cost
217                    .partial_cmp(&self.cost)
218                    .unwrap_or(Ordering::Equal)
219            }
220        }
221
222        impl PartialOrd for State {
223            fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
224                Some(self.cmp(other))
225            }
226        }
227
228        let mut dist: HashMap<NodeId, f64> = HashMap::new();
229        let mut prev: HashMap<NodeId, NodeId> = HashMap::new();
230        let mut heap = BinaryHeap::new();
231
232        dist.insert(from, 0.0);
233        heap.push(State {
234            cost: 0.0,
235            node: from,
236        });
237
238        while let Some(State { cost, node }) = heap.pop() {
239            if node == to {
240                let mut path = vec![to];
241                let mut current = to;
242
243                while let Some(&prev_node) = prev.get(&current) {
244                    path.push(prev_node);
245                    current = prev_node;
246                }
247
248                path.reverse();
249                return Some(Route {
250                    path,
251                    total_distance: cost,
252                });
253            }
254
255            if let Some(&d) = dist.get(&node) {
256                if cost > d {
257                    continue;
258                }
259            }
260
261            for (neighbor, edge_id, length) in map.neighbors(node) {
262                let edge_cost = self.edge_cost(map, length, edge_id, traffic);
263                let next_cost = cost + edge_cost;
264
265                if dist.get(&neighbor).is_none_or(|&d| next_cost < d) {
266                    dist.insert(neighbor, next_cost);
267                    prev.insert(neighbor, node);
268                    heap.push(State {
269                        cost: next_cost,
270                        node: neighbor,
271                    });
272                }
273            }
274        }
275
276        None
277    }
278
279    /// v1: Dijkstra avoiding specific edges
280    fn dijkstra_avoiding(
281        &self,
282        map: &WarehouseMap,
283        from: NodeId,
284        to: NodeId,
285        avoid_edges: &HashSet<EdgeId>,
286        traffic: Option<&TrafficManager>,
287    ) -> Option<Route> {
288        #[derive(Clone, PartialEq)]
289        struct State {
290            cost: f64,
291            node: NodeId,
292        }
293
294        impl Eq for State {}
295
296        impl Ord for State {
297            fn cmp(&self, other: &Self) -> Ordering {
298                other
299                    .cost
300                    .partial_cmp(&self.cost)
301                    .unwrap_or(Ordering::Equal)
302            }
303        }
304
305        impl PartialOrd for State {
306            fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
307                Some(self.cmp(other))
308            }
309        }
310
311        let mut dist: HashMap<NodeId, f64> = HashMap::new();
312        let mut prev: HashMap<NodeId, NodeId> = HashMap::new();
313        let mut heap = BinaryHeap::new();
314
315        dist.insert(from, 0.0);
316        heap.push(State {
317            cost: 0.0,
318            node: from,
319        });
320
321        while let Some(State { cost, node }) = heap.pop() {
322            if node == to {
323                let mut path = vec![to];
324                let mut current = to;
325
326                while let Some(&prev_node) = prev.get(&current) {
327                    path.push(prev_node);
328                    current = prev_node;
329                }
330
331                path.reverse();
332                return Some(Route {
333                    path,
334                    total_distance: cost,
335                });
336            }
337
338            if let Some(&d) = dist.get(&node) {
339                if cost > d {
340                    continue;
341                }
342            }
343
344            for (neighbor, edge_id, length) in map.neighbors(node) {
345                // Skip avoided edges
346                if avoid_edges.contains(&edge_id) {
347                    continue;
348                }
349
350                let edge_cost = self.edge_cost(map, length, edge_id, traffic);
351                let next_cost = cost + edge_cost;
352
353                if dist.get(&neighbor).is_none_or(|&d| next_cost < d) {
354                    dist.insert(neighbor, next_cost);
355                    prev.insert(neighbor, node);
356                    heap.push(State {
357                        cost: next_cost,
358                        node: neighbor,
359                    });
360                }
361            }
362        }
363
364        None
365    }
366
367    /// v1: A* algorithm with euclidean heuristic
368    fn astar(
369        &self,
370        map: &WarehouseMap,
371        from: NodeId,
372        to: NodeId,
373        traffic: Option<&TrafficManager>,
374    ) -> Option<Route> {
375        #[derive(Clone, PartialEq)]
376        struct State {
377            f_cost: f64, // f = g + h
378            g_cost: f64, // actual cost from start
379            node: NodeId,
380        }
381
382        impl Eq for State {}
383
384        impl Ord for State {
385            fn cmp(&self, other: &Self) -> Ordering {
386                other
387                    .f_cost
388                    .partial_cmp(&self.f_cost)
389                    .unwrap_or(Ordering::Equal)
390            }
391        }
392
393        impl PartialOrd for State {
394            fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
395                Some(self.cmp(other))
396            }
397        }
398
399        let mut g_score: HashMap<NodeId, f64> = HashMap::new();
400        let mut prev: HashMap<NodeId, NodeId> = HashMap::new();
401        let mut heap = BinaryHeap::new();
402
403        g_score.insert(from, 0.0);
404        let h = map.euclidean_distance(from, to);
405        heap.push(State {
406            f_cost: h,
407            g_cost: 0.0,
408            node: from,
409        });
410
411        while let Some(State { g_cost, node, .. }) = heap.pop() {
412            if node == to {
413                let mut path = vec![to];
414                let mut current = to;
415
416                while let Some(&prev_node) = prev.get(&current) {
417                    path.push(prev_node);
418                    current = prev_node;
419                }
420
421                path.reverse();
422                return Some(Route {
423                    path,
424                    total_distance: g_cost,
425                });
426            }
427
428            if let Some(&g) = g_score.get(&node) {
429                if g_cost > g {
430                    continue;
431                }
432            }
433
434            for (neighbor, edge_id, length) in map.neighbors(node) {
435                let edge_cost = self.edge_cost(map, length, edge_id, traffic);
436                let tentative_g = g_cost + edge_cost;
437
438                if g_score.get(&neighbor).is_none_or(|&g| tentative_g < g) {
439                    g_score.insert(neighbor, tentative_g);
440                    prev.insert(neighbor, node);
441
442                    let h = map.euclidean_distance(neighbor, to);
443                    heap.push(State {
444                        f_cost: tentative_g + h,
445                        g_cost: tentative_g,
446                        node: neighbor,
447                    });
448                }
449            }
450        }
451
452        None
453    }
454
455    pub fn invalidate_cache(&mut self) {
456        self.cache.invalidate();
457    }
458}