// screeps_pathfinding/algorithms/dijkstra.rs

// https://en.wikipedia.org/wiki/Dijkstra%27s_algorithm

// Sample code pulled (and modified) from: https://doc.rust-lang.org/nightly/std/collections/binary_heap/index.html#examples

5use std::cmp::Ordering;
6use std::collections::{BinaryHeap, HashMap};
7use std::hash::Hash;
8
9use screeps::local::{Position, RoomXY};
10
/// Marker trait bundling every bound a node type needs in order to be searched
/// with Dijkstra's Algorithm: nodes are copied freely, used as keys in the
/// search's hash maps, and totally ordered so that frontier ties break
/// deterministically.
///
/// Implemented automatically for every type that meets those bounds.
pub trait DijkstraNode: Copy + Eq + Hash + Ord {}

impl<T: Copy + Eq + Hash + Ord> DijkstraNode for T {}
15
/// The outcome of a Dijkstra search: the path that was found (if any) plus
/// bookkeeping about how much work the search performed.
///
/// The struct itself places no bounds on `T`; only the search functions need
/// node types to implement `DijkstraNode`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DijkstraSearchResults<T> {
    ops_used: u32,
    cost: u32,
    incomplete: bool,
    path: Vec<T>,
}

impl<T> DijkstraSearchResults<T> {
    /// The number of search operations consumed.
    pub fn ops(&self) -> u32 {
        self.ops_used
    }

    /// The movement cost of the result path.
    ///
    /// For an incomplete result this is instead the cost of the last node the
    /// search examined before stopping.
    pub fn cost(&self) -> u32 {
        self.cost
    }

    /// Whether the search failed to produce a complete path to a goal.
    pub fn incomplete(&self) -> bool {
        self.incomplete
    }

    /// A shortest path from a start node (first element) to the goal node (last
    /// element); empty when the result is incomplete.
    pub fn path(&self) -> &[T] {
        &self.path
    }
}
48
/// A frontier entry: a node together with the total cost of the best route to
/// it that was known when the entry was queued.
#[derive(Copy, Clone, Eq, PartialEq)]
struct State<T: Ord> {
    cost: u32,
    position: T,
}

// `BinaryHeap` is a max-heap, so the comparison on `cost` is deliberately
// inverted: the cheapest entry compares as the greatest and is popped first.
// Equal costs fall back to the node's own ordering so that `Ord` only reports
// `Equal` for entries the derived `PartialEq` also considers equal.
impl<T: Ord> Ord for State<T> {
    fn cmp(&self, other: &Self) -> Ordering {
        match other.cost.cmp(&self.cost) {
            Ordering::Equal => self.position.cmp(&other.position),
            by_cost => by_cost,
        }
    }
}

// The heap also requires `PartialOrd`; it simply delegates to the total order.
impl<T: Ord> PartialOrd for State<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
85
86/// Calculates a shortest path from `start` to `goal` using Dijkstra's Algorithm.
87///
88/// The algorithm itself doesn't care what type the nodes are, as long as you provide
89/// a cost function that can convert a node of that type into a u32 cost to move
90/// to that node, and a neighbors function that can generate a slice of nodes to explore.
91///
92/// The cost function should return u32::MAX for unpassable tiles.
93///
94/// # Example
95/// ```rust
96/// use screeps::{LocalRoomTerrain, RoomXY};
97/// use screeps_pathfinding::utils::goals::goal_exact_node;
98///
99/// let start = RoomXY::checked_new(24, 18).unwrap();
100/// let goal = RoomXY::checked_new(34, 40).unwrap();
101/// let room_terrain = LocalRoomTerrain::new_from_bits(Box::new([0; 2500])); // Terrain that's all plains
102/// let plain_cost = 1;
103/// let swamp_cost = 5;
104/// let costs = screeps_pathfinding::utils::movement_costs::get_movement_cost_lcm_from_terrain(&room_terrain, plain_cost, swamp_cost);
105/// let costs_fn = screeps_pathfinding::utils::movement_costs::movement_costs_from_lcm(&costs);
106/// let neighbors_fn = screeps_pathfinding::utils::neighbors::room_xy_neighbors;
107/// let max_ops = 2000;
108/// let max_cost = 2000;
109///
110/// let search_results = screeps_pathfinding::algorithms::dijkstra::shortest_path_generic(
111///     &[start],
112///     &goal_exact_node(goal),
113///     costs_fn,
114///     neighbors_fn,
115///     max_ops,
116///     max_cost,
117/// );
118///
119/// if !search_results.incomplete() {
120///   let path = search_results.path();
121///   println!("Path: {:?}", path);
122/// }
123/// else {
124///   println!("Could not find Dijkstra shortest path.");
125///   println!("Search Results: {:?}", search_results);
126/// }
127/// ```
128///
129/// Reference: <https://en.wikipedia.org/wiki/Dijkstra%27s_algorithm>
130pub fn shortest_path_generic<T: DijkstraNode, P, G, N, I>(
131    start: &[T],
132    goal_fn: &P,
133    cost_fn: G,
134    neighbors: N,
135    max_ops: u32,
136    max_cost: u32,
137) -> DijkstraSearchResults<T>
138where
139    P: Fn(T) -> bool,
140    G: Fn(T) -> u32,
141    N: Fn(T) -> I,
142    I: IntoIterator<Item = T, IntoIter: Iterator<Item = T>>,
143{
144    // Start at `start` and use `dist` to track the current shortest distance
145    // to each node. This implementation isn't memory-efficient as it may leave duplicate
146    // nodes in the queue. It also uses `u32::MAX` as a sentinel value,
147    // for a simpler implementation.
148
149    let mut remaining_ops: u32 = max_ops;
150    let mut last_examined_cost: u32 = 0;
151
152    // dist[node] = current shortest distance from `start` to `node`
153    let mut dist: HashMap<T, u32> = HashMap::new();
154    let mut parents: HashMap<T, T> = HashMap::new();
155
156    let mut heap = BinaryHeap::new();
157
158    // We're at `start`, with a zero cost
159    for s in start {
160        dist.insert(*s, 0);
161        heap.push(State {
162            cost: 0,
163            position: *s,
164        });
165    }
166
167    // Examine the frontier with lower cost nodes first (min-heap)
168    while let Some(State { cost, position }) = heap.pop() {
169        // We found the goal state, return the search results
170        if goal_fn(position) {
171            let path_opt = get_path_from_parents(&parents, position);
172            return DijkstraSearchResults {
173                ops_used: max_ops - remaining_ops,
174                cost,
175                incomplete: path_opt.is_none(),
176                path: path_opt.unwrap_or_else(|| Vec::new()),
177            };
178        }
179
180        remaining_ops -= 1;
181
182        // Stop searching if we've run out of remaining ops we're allowed to perform
183        if remaining_ops == 0 {
184            break;
185        }
186
187        // Stop searching if our current cost is greater than what we're willing to pay
188        // Note: Because our heap is sorted by cost, no later nodes can have a smaller
189        // cost, so we can safely break here
190        last_examined_cost = cost;
191        if cost >= max_cost {
192            break;
193        }
194
195        // Important as we may have already found a better way
196        let current_cost = match dist.get(&position) {
197            Some(c) => c,
198            None => &u32::MAX,
199        };
200        if cost > *current_cost {
201            continue;
202        }
203
204        // For each node we can reach, see if we can find a way with
205        // a lower cost going through this node
206        for p in neighbors(position) {
207            let next_tile_cost = cost_fn(p);
208
209            // u32::MAX is our sentinel value for unpassable, skip this neighbor
210            if next_tile_cost == u32::MAX {
211                continue;
212            }
213
214            let next = State {
215                cost: cost + next_tile_cost,
216                position: p,
217            };
218
219            // If so, add it to the frontier and continue
220            let current_next_cost = match dist.get(&next.position) {
221                Some(c) => c,
222                None => &u32::MAX,
223            };
224            if next.cost < *current_next_cost {
225                heap.push(next);
226
227                // Relaxation, we have now found a better way
228                if let Some(c) = dist.get_mut(&next.position) {
229                    *c = next.cost;
230                } else {
231                    dist.insert(next.position, next.cost);
232                }
233
234                if let Some(parent_node) = parents.get_mut(&next.position) {
235                    *parent_node = position;
236                } else {
237                    parents.insert(next.position, position);
238                }
239            }
240        }
241    }
242
243    // Goal not reachable
244    DijkstraSearchResults {
245        ops_used: max_ops - remaining_ops,
246        cost: last_examined_cost,
247        incomplete: true,
248        path: Vec::new(),
249    }
250}
251
252fn get_path_from_parents<T: DijkstraNode>(parents: &HashMap<T, T>, end: T) -> Option<Vec<T>> {
253    let mut path = Vec::new();
254
255    let mut current_pos = end;
256
257    path.push(end);
258
259    let mut parent_opt = parents.get(&current_pos);
260    while parent_opt.is_some() {
261        let parent = parent_opt.unwrap();
262        path.push(*parent);
263        current_pos = *parent;
264        parent_opt = parents.get(&current_pos);
265    }
266
267    Some(path.into_iter().rev().collect())
268}
269
270pub fn shortest_path_roomxy<P, G>(
271    start: RoomXY,
272    goal_fn: &P,
273    cost_fn: G,
274) -> DijkstraSearchResults<RoomXY>
275where
276    P: Fn(RoomXY) -> bool,
277    G: Fn(RoomXY) -> u32,
278{
279    shortest_path_roomxy_multistart(&[start], goal_fn, cost_fn)
280}
281
282pub fn shortest_path_roomxy_multistart<P, G>(
283    start_nodes: &[RoomXY],
284    goal_fn: &P,
285    cost_fn: G,
286) -> DijkstraSearchResults<RoomXY>
287where
288    P: Fn(RoomXY) -> bool,
289    G: Fn(RoomXY) -> u32,
290{
291    let neighbors_fn = crate::utils::neighbors::room_xy_neighbors;
292    let max_ops = 2000;
293    let max_cost = 2000;
294    shortest_path_generic(
295        start_nodes,
296        goal_fn,
297        cost_fn,
298        neighbors_fn,
299        max_ops,
300        max_cost,
301    )
302}
303
304pub fn shortest_path_position<P, G>(
305    start: Position,
306    goal_fn: &P,
307    cost_fn: G,
308) -> DijkstraSearchResults<Position>
309where
310    P: Fn(Position) -> bool,
311    G: Fn(Position) -> u32,
312{
313    shortest_path_position_multistart(&[start], goal_fn, cost_fn)
314}
315
316pub fn shortest_path_position_multistart<P, G>(
317    start_nodes: &[Position],
318    goal_fn: &P,
319    cost_fn: G,
320) -> DijkstraSearchResults<Position>
321where
322    P: Fn(Position) -> bool,
323    G: Fn(Position) -> u32,
324{
325    let neighbors_fn = crate::utils::neighbors::position_neighbors;
326    let max_ops = 2000;
327    let max_cost = 2000;
328    shortest_path_generic(
329        start_nodes,
330        goal_fn,
331        cost_fn,
332        neighbors_fn,
333        max_ops,
334        max_cost,
335    )
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::goals::goal_exact_node;
    use screeps::constants::Direction;
    use screeps::local::{Position, RoomCoordinate, RoomXY};

    // Helper Functions

    /// Builds a `RoomXY`, panicking on out-of-range test coordinates instead of
    /// relying on `unsafe` unchecked construction.
    fn room_xy(x: u8, y: u8) -> RoomXY {
        RoomXY::checked_new(x, y).expect("test coordinates must be inside the room")
    }

    fn new_position(room_name: &str, x: u8, y: u8) -> Position {
        Position::new(
            RoomCoordinate::try_from(x).unwrap(),
            RoomCoordinate::try_from(y).unwrap(),
            room_name.parse().unwrap(),
        )
    }

    fn all_tiles_are_plains_costs<T>(_node: T) -> u32 {
        1
    }

    fn all_tiles_are_swamps_costs<T>(_node: T) -> u32 {
        5
    }

    fn room_xy_neighbors(node: RoomXY) -> Vec<RoomXY> {
        node.neighbors()
    }

    fn position_neighbors(node: Position) -> Vec<Position> {
        Direction::iter()
            .filter_map(|dir| node.checked_add_direction(*dir).ok())
            .collect()
    }

    // Testing function where all tiles are reachable except for (10, 12)
    fn roomxy_unreachable_tile_costs(node: RoomXY) -> u32 {
        if node.x.u8() == 10 && node.y.u8() == 12 {
            u32::MAX
        } else {
            1
        }
    }

    // Testing function where all tiles are reachable except for (10, 12)
    fn position_unreachable_tile_costs(node: Position) -> u32 {
        if node.x().u8() == 10 && node.y().u8() == 12 {
            u32::MAX
        } else {
            1
        }
    }

    // Test Cases

    #[test]
    fn simple_linear_path_roomxy() {
        let start = room_xy(10, 10);
        let goal = room_xy(10, 12);
        let goal_fn = goal_exact_node(goal);
        let search_results = shortest_path_generic(
            &[start],
            &goal_fn,
            all_tiles_are_plains_costs,
            room_xy_neighbors,
            2000,
            2000,
        );

        assert!(!search_results.incomplete());
        assert_eq!(search_results.cost(), 2);
        assert!(search_results.ops() < 2000);

        let path = search_results.path();
        assert_eq!(path.len(), 3);
        assert_eq!(path[0], start);
        assert_eq!(path[2], goal);

        // Every one of these tiles is adjacent to both start and goal, so each lies
        // on a shortest path; we can't guarantee which one the search picks.
        let middle_candidates = [room_xy(9, 11), room_xy(10, 11), room_xy(11, 11)];
        assert!(middle_candidates.contains(&path[1]));
    }

    #[test]
    fn simple_linear_path_position() {
        let room_name = "E5N6";
        let start = new_position(room_name, 10, 10);
        let goal = new_position(room_name, 10, 12);
        let goal_fn = goal_exact_node(goal);
        let search_results = shortest_path_generic(
            &[start],
            &goal_fn,
            all_tiles_are_plains_costs,
            position_neighbors,
            2000,
            2000,
        );

        assert!(!search_results.incomplete());
        assert_eq!(search_results.cost(), 2);
        assert!(search_results.ops() < 2000);

        let path = search_results.path();
        assert_eq!(path.len(), 3);
        assert_eq!(path[0], start);
        assert_eq!(path[2], goal);

        // Every one of these tiles is adjacent to both start and goal, so each lies
        // on a shortest path; we can't guarantee which one the search picks.
        let middle_candidates = [
            new_position(room_name, 9, 11),
            new_position(room_name, 10, 11),
            new_position(room_name, 11, 11),
        ];
        assert!(middle_candidates.contains(&path[1]));
    }

    #[test]
    fn start_is_goal_roomxy() {
        let start = room_xy(25, 25);
        let goal_fn = goal_exact_node(start);
        let search_results = shortest_path_generic(
            &[start],
            &goal_fn,
            all_tiles_are_plains_costs,
            room_xy_neighbors,
            2000,
            2000,
        );

        assert!(!search_results.incomplete());
        assert_eq!(search_results.cost(), 0);
        assert_eq!(search_results.ops(), 0);
        assert_eq!(search_results.path(), &[start]);
    }

    #[test]
    fn multiple_starts_roomxy() {
        let far_start = room_xy(10, 10);
        let near_start = room_xy(40, 40);
        let goal = room_xy(40, 42);
        let goal_fn = goal_exact_node(goal);
        let search_results = shortest_path_generic(
            &[far_start, near_start],
            &goal_fn,
            all_tiles_are_plains_costs,
            room_xy_neighbors,
            2000,
            2000,
        );

        assert!(!search_results.incomplete());
        assert_eq!(search_results.cost(), 2);

        let path = search_results.path();
        assert_eq!(path.len(), 3);
        assert_eq!(path.first(), Some(&near_start));
        assert_eq!(path.last(), Some(&goal));
    }

    #[test]
    fn roomxy_wrapper_finds_path() {
        let start = room_xy(10, 10);
        let goal_fn = goal_exact_node(room_xy(10, 12));
        let search_results = shortest_path_roomxy(start, &goal_fn, all_tiles_are_plains_costs);

        assert!(!search_results.incomplete());
        assert_eq!(search_results.cost(), 2);
        assert_eq!(search_results.path().len(), 3);
    }

    #[test]
    fn unreachable_target_roomxy() {
        let start = room_xy(10, 10);
        let goal_fn = goal_exact_node(room_xy(10, 12));
        let search_results = shortest_path_generic(
            &[start],
            &goal_fn,
            roomxy_unreachable_tile_costs,
            room_xy_neighbors,
            2000,
            2000,
        );

        println!("{:?}", search_results);

        assert!(search_results.incomplete());
        assert!(search_results.cost() > 0);
        assert_eq!(search_results.ops(), 2000);
        assert!(search_results.path().is_empty());
    }

    #[test]
    fn unreachable_target_position() {
        let room_name = "E5N6";
        let start = new_position(room_name, 10, 10);
        let goal_fn = goal_exact_node(new_position(room_name, 10, 12));
        let search_results = shortest_path_generic(
            &[start],
            &goal_fn,
            position_unreachable_tile_costs,
            position_neighbors,
            2000,
            2000,
        );

        println!("{:?}", search_results);

        assert!(search_results.incomplete());
        assert!(search_results.cost() > 0);
        assert_eq!(search_results.ops(), 2000);
        assert!(search_results.path().is_empty());
    }

    #[test]
    fn max_ops_halt_roomxy() {
        let max_ops_failure = 5;
        let max_ops_success = 100;
        let start = room_xy(10, 10);
        let goal_fn = goal_exact_node(room_xy(10, 12)); // This target generally takes ~11 ops to find

        // Failure case
        let search_results = shortest_path_generic(
            &[start],
            &goal_fn,
            all_tiles_are_plains_costs,
            room_xy_neighbors,
            max_ops_failure,
            2000,
        );

        assert!(search_results.incomplete());
        assert!(search_results.cost() > 0);
        assert_eq!(search_results.ops(), max_ops_failure);
        assert!(search_results.path().is_empty());

        // Success case
        let search_results = shortest_path_generic(
            &[start],
            &goal_fn,
            all_tiles_are_plains_costs,
            room_xy_neighbors,
            max_ops_success,
            2000,
        );

        assert!(!search_results.incomplete());
        assert!(search_results.cost() > 0);
        assert!(search_results.ops() < max_ops_success);
        assert_eq!(search_results.path().len(), 3);
    }

    #[test]
    fn max_ops_halt_position() {
        let max_ops_failure = 5;
        let max_ops_success = 100;
        let room_name = "E5N6";
        let start = new_position(room_name, 10, 10);
        let goal_fn = goal_exact_node(new_position(room_name, 10, 12)); // This target generally takes ~11 ops to find

        // Failure case
        let search_results = shortest_path_generic(
            &[start],
            &goal_fn,
            all_tiles_are_plains_costs,
            position_neighbors,
            max_ops_failure,
            2000,
        );

        assert!(search_results.incomplete());
        assert!(search_results.cost() > 0);
        assert_eq!(search_results.ops(), max_ops_failure);
        assert!(search_results.path().is_empty());

        // Success case
        let search_results = shortest_path_generic(
            &[start],
            &goal_fn,
            all_tiles_are_plains_costs,
            position_neighbors,
            max_ops_success,
            2000,
        );

        assert!(!search_results.incomplete());
        assert!(search_results.cost() > 0);
        assert!(search_results.ops() < max_ops_success);
        assert_eq!(search_results.path().len(), 3);
    }

    #[test]
    fn max_cost_halt_roomxy() {
        let max_cost_failure = 5;
        let max_cost_success = 100;
        let start = room_xy(10, 10);
        let goal_fn = goal_exact_node(room_xy(10, 12)); // This target will cost 10 to move to

        // Failure case
        let search_results = shortest_path_generic(
            &[start],
            &goal_fn,
            all_tiles_are_swamps_costs,
            room_xy_neighbors,
            2000,
            max_cost_failure,
        );

        assert!(search_results.incomplete());
        assert!(search_results.cost() >= max_cost_failure);
        assert!(search_results.ops() < 2000);
        assert!(search_results.path().is_empty());

        // Success case
        let search_results = shortest_path_generic(
            &[start],
            &goal_fn,
            all_tiles_are_swamps_costs,
            room_xy_neighbors,
            2000,
            max_cost_success,
        );

        assert!(!search_results.incomplete());
        assert!(search_results.cost() < max_cost_success);
        assert!(search_results.ops() < 2000);
        assert_eq!(search_results.path().len(), 3);
    }

    #[test]
    fn max_cost_halt_position() {
        let max_cost_failure = 5;
        let max_cost_success = 100;
        let room_name = "E5N6";
        let start = new_position(room_name, 10, 10);
        let goal_fn = goal_exact_node(new_position(room_name, 10, 12)); // This target will cost 10 to move to

        // Failure case
        let search_results = shortest_path_generic(
            &[start],
            &goal_fn,
            all_tiles_are_swamps_costs,
            position_neighbors,
            2000,
            max_cost_failure,
        );

        assert!(search_results.incomplete());
        assert!(search_results.cost() >= max_cost_failure);
        assert!(search_results.ops() < 2000);
        assert!(search_results.path().is_empty());

        // Success case
        let search_results = shortest_path_generic(
            &[start],
            &goal_fn,
            all_tiles_are_swamps_costs,
            position_neighbors,
            2000,
            max_cost_success,
        );

        assert!(!search_results.incomplete());
        assert!(search_results.cost() < max_cost_success);
        assert!(search_results.ops() < 2000);
        assert_eq!(search_results.path().len(), 3);
    }
}