1use crate::graph::WarehouseMap;
4use crate::traffic::TrafficManager;
5use std::cmp::Ordering;
6use std::collections::{BinaryHeap, HashMap, HashSet};
7use waremax_core::{EdgeId, NodeId};
8
/// Shortest-path search strategy used by [`Router`].
#[derive(Clone, Debug, Default, PartialEq)]
pub enum RoutingAlgorithm {
    /// Uniform-cost search that expands nodes purely by accumulated edge cost (the default).
    #[default]
    Dijkstra,
    /// Best-first search whose frontier is additionally guided by the straight-line
    /// (Euclidean) distance to the destination.
    AStar,
}
16
17#[derive(Clone, Debug)]
19pub struct Route {
20 pub path: Vec<NodeId>,
21 pub total_distance: f64,
22}
23
24impl Route {
25 pub fn empty(start: NodeId) -> Self {
26 Self {
27 path: vec![start],
28 total_distance: 0.0,
29 }
30 }
31
32 pub fn is_empty(&self) -> bool {
33 self.path.len() <= 1
34 }
35
36 pub fn len(&self) -> usize {
37 self.path.len()
38 }
39}
40
41pub struct RouteCache {
43 cache: HashMap<(NodeId, NodeId), Route>,
44 max_size: usize,
45}
46
47impl RouteCache {
48 pub fn new(max_size: usize) -> Self {
49 Self {
50 cache: HashMap::new(),
51 max_size,
52 }
53 }
54
55 pub fn get(&self, from: NodeId, to: NodeId) -> Option<&Route> {
56 self.cache.get(&(from, to))
57 }
58
59 pub fn insert(&mut self, from: NodeId, to: NodeId, route: Route) {
60 if self.cache.len() >= self.max_size {
61 let keys: Vec<_> = self.cache.keys().take(self.max_size / 2).copied().collect();
62 for key in keys {
63 self.cache.remove(&key);
64 }
65 }
66 self.cache.insert((from, to), route);
67 }
68
69 pub fn invalidate(&mut self) {
70 self.cache.clear();
71 }
72}
73
74pub struct Router {
76 cache: RouteCache,
77 cache_enabled: bool,
78 algorithm: RoutingAlgorithm,
80 congestion_weight: f64,
82}
83
84impl Router {
85 pub fn new(cache_enabled: bool) -> Self {
86 Self {
87 cache: RouteCache::new(10000),
88 cache_enabled,
89 algorithm: RoutingAlgorithm::default(),
90 congestion_weight: 0.0,
91 }
92 }
93
94 pub fn with_algorithm(cache_enabled: bool, algorithm: RoutingAlgorithm) -> Self {
96 Self {
97 cache: RouteCache::new(10000),
98 cache_enabled,
99 algorithm,
100 congestion_weight: 0.0,
101 }
102 }
103
104 pub fn set_congestion_weight(&mut self, weight: f64) {
106 self.congestion_weight = weight;
107 }
108
109 pub fn algorithm(&self) -> &RoutingAlgorithm {
111 &self.algorithm
112 }
113
114 pub fn find_route(&mut self, map: &WarehouseMap, from: NodeId, to: NodeId) -> Option<Route> {
115 if from == to {
116 return Some(Route::empty(from));
117 }
118
119 if self.cache_enabled {
120 if let Some(route) = self.cache.get(from, to) {
121 return Some(route.clone());
122 }
123 }
124
125 let route = match self.algorithm {
126 RoutingAlgorithm::Dijkstra => self.dijkstra(map, from, to, None),
127 RoutingAlgorithm::AStar => self.astar(map, from, to, None),
128 }?;
129
130 if self.cache_enabled {
131 self.cache.insert(from, to, route.clone());
132 }
133
134 Some(route)
135 }
136
137 pub fn find_route_with_traffic(
139 &mut self,
140 map: &WarehouseMap,
141 from: NodeId,
142 to: NodeId,
143 traffic: &TrafficManager,
144 ) -> Option<Route> {
145 if from == to {
146 return Some(Route::empty(from));
147 }
148
149 match self.algorithm {
151 RoutingAlgorithm::Dijkstra => self.dijkstra(map, from, to, Some(traffic)),
152 RoutingAlgorithm::AStar => self.astar(map, from, to, Some(traffic)),
153 }
154 }
155
156 pub fn find_route_avoiding(
158 &mut self,
159 map: &WarehouseMap,
160 from: NodeId,
161 to: NodeId,
162 avoid_edges: &[EdgeId],
163 traffic: Option<&TrafficManager>,
164 ) -> Option<Route> {
165 if from == to {
166 return Some(Route::empty(from));
167 }
168
169 let avoid_set: HashSet<EdgeId> = avoid_edges.iter().copied().collect();
170 self.dijkstra_avoiding(map, from, to, &avoid_set, traffic)
171 }
172
173 fn edge_cost(
175 &self,
176 map: &WarehouseMap,
177 length: f64,
178 edge_id: EdgeId,
179 traffic: Option<&TrafficManager>,
180 ) -> f64 {
181 let speed_multiplier = map
183 .get_edge(edge_id)
184 .map(|e| e.speed_multiplier)
185 .unwrap_or(1.0);
186 let base_cost = length * speed_multiplier;
187
188 if self.congestion_weight > 0.0 {
190 if let Some(tm) = traffic {
191 let occupancy = tm.get_edge_occupancy(edge_id);
192 return base_cost * (1.0 + self.congestion_weight * occupancy as f64);
193 }
194 }
195 base_cost
196 }
197
198 fn dijkstra(
199 &self,
200 map: &WarehouseMap,
201 from: NodeId,
202 to: NodeId,
203 traffic: Option<&TrafficManager>,
204 ) -> Option<Route> {
205 #[derive(Clone, PartialEq)]
206 struct State {
207 cost: f64,
208 node: NodeId,
209 }
210
211 impl Eq for State {}
212
213 impl Ord for State {
214 fn cmp(&self, other: &Self) -> Ordering {
215 other
216 .cost
217 .partial_cmp(&self.cost)
218 .unwrap_or(Ordering::Equal)
219 }
220 }
221
222 impl PartialOrd for State {
223 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
224 Some(self.cmp(other))
225 }
226 }
227
228 let mut dist: HashMap<NodeId, f64> = HashMap::new();
229 let mut prev: HashMap<NodeId, NodeId> = HashMap::new();
230 let mut heap = BinaryHeap::new();
231
232 dist.insert(from, 0.0);
233 heap.push(State {
234 cost: 0.0,
235 node: from,
236 });
237
238 while let Some(State { cost, node }) = heap.pop() {
239 if node == to {
240 let mut path = vec![to];
241 let mut current = to;
242
243 while let Some(&prev_node) = prev.get(¤t) {
244 path.push(prev_node);
245 current = prev_node;
246 }
247
248 path.reverse();
249 return Some(Route {
250 path,
251 total_distance: cost,
252 });
253 }
254
255 if let Some(&d) = dist.get(&node) {
256 if cost > d {
257 continue;
258 }
259 }
260
261 for (neighbor, edge_id, length) in map.neighbors(node) {
262 let edge_cost = self.edge_cost(map, length, edge_id, traffic);
263 let next_cost = cost + edge_cost;
264
265 if dist.get(&neighbor).is_none_or(|&d| next_cost < d) {
266 dist.insert(neighbor, next_cost);
267 prev.insert(neighbor, node);
268 heap.push(State {
269 cost: next_cost,
270 node: neighbor,
271 });
272 }
273 }
274 }
275
276 None
277 }
278
279 fn dijkstra_avoiding(
281 &self,
282 map: &WarehouseMap,
283 from: NodeId,
284 to: NodeId,
285 avoid_edges: &HashSet<EdgeId>,
286 traffic: Option<&TrafficManager>,
287 ) -> Option<Route> {
288 #[derive(Clone, PartialEq)]
289 struct State {
290 cost: f64,
291 node: NodeId,
292 }
293
294 impl Eq for State {}
295
296 impl Ord for State {
297 fn cmp(&self, other: &Self) -> Ordering {
298 other
299 .cost
300 .partial_cmp(&self.cost)
301 .unwrap_or(Ordering::Equal)
302 }
303 }
304
305 impl PartialOrd for State {
306 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
307 Some(self.cmp(other))
308 }
309 }
310
311 let mut dist: HashMap<NodeId, f64> = HashMap::new();
312 let mut prev: HashMap<NodeId, NodeId> = HashMap::new();
313 let mut heap = BinaryHeap::new();
314
315 dist.insert(from, 0.0);
316 heap.push(State {
317 cost: 0.0,
318 node: from,
319 });
320
321 while let Some(State { cost, node }) = heap.pop() {
322 if node == to {
323 let mut path = vec![to];
324 let mut current = to;
325
326 while let Some(&prev_node) = prev.get(¤t) {
327 path.push(prev_node);
328 current = prev_node;
329 }
330
331 path.reverse();
332 return Some(Route {
333 path,
334 total_distance: cost,
335 });
336 }
337
338 if let Some(&d) = dist.get(&node) {
339 if cost > d {
340 continue;
341 }
342 }
343
344 for (neighbor, edge_id, length) in map.neighbors(node) {
345 if avoid_edges.contains(&edge_id) {
347 continue;
348 }
349
350 let edge_cost = self.edge_cost(map, length, edge_id, traffic);
351 let next_cost = cost + edge_cost;
352
353 if dist.get(&neighbor).is_none_or(|&d| next_cost < d) {
354 dist.insert(neighbor, next_cost);
355 prev.insert(neighbor, node);
356 heap.push(State {
357 cost: next_cost,
358 node: neighbor,
359 });
360 }
361 }
362 }
363
364 None
365 }
366
367 fn astar(
369 &self,
370 map: &WarehouseMap,
371 from: NodeId,
372 to: NodeId,
373 traffic: Option<&TrafficManager>,
374 ) -> Option<Route> {
375 #[derive(Clone, PartialEq)]
376 struct State {
377 f_cost: f64, g_cost: f64, node: NodeId,
380 }
381
382 impl Eq for State {}
383
384 impl Ord for State {
385 fn cmp(&self, other: &Self) -> Ordering {
386 other
387 .f_cost
388 .partial_cmp(&self.f_cost)
389 .unwrap_or(Ordering::Equal)
390 }
391 }
392
393 impl PartialOrd for State {
394 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
395 Some(self.cmp(other))
396 }
397 }
398
399 let mut g_score: HashMap<NodeId, f64> = HashMap::new();
400 let mut prev: HashMap<NodeId, NodeId> = HashMap::new();
401 let mut heap = BinaryHeap::new();
402
403 g_score.insert(from, 0.0);
404 let h = map.euclidean_distance(from, to);
405 heap.push(State {
406 f_cost: h,
407 g_cost: 0.0,
408 node: from,
409 });
410
411 while let Some(State { g_cost, node, .. }) = heap.pop() {
412 if node == to {
413 let mut path = vec![to];
414 let mut current = to;
415
416 while let Some(&prev_node) = prev.get(¤t) {
417 path.push(prev_node);
418 current = prev_node;
419 }
420
421 path.reverse();
422 return Some(Route {
423 path,
424 total_distance: g_cost,
425 });
426 }
427
428 if let Some(&g) = g_score.get(&node) {
429 if g_cost > g {
430 continue;
431 }
432 }
433
434 for (neighbor, edge_id, length) in map.neighbors(node) {
435 let edge_cost = self.edge_cost(map, length, edge_id, traffic);
436 let tentative_g = g_cost + edge_cost;
437
438 if g_score.get(&neighbor).is_none_or(|&g| tentative_g < g) {
439 g_score.insert(neighbor, tentative_g);
440 prev.insert(neighbor, node);
441
442 let h = map.euclidean_distance(neighbor, to);
443 heap.push(State {
444 f_cost: tentative_g + h,
445 g_cost: tentative_g,
446 node: neighbor,
447 });
448 }
449 }
450 }
451
452 None
453 }
454
455 pub fn invalidate_cache(&mut self) {
456 self.cache.invalidate();
457 }
458}