tree_traversal/
gds.rs

1//! Greedy Search
2
3use num_traits::Bounded;
4
5use crate::bms::bms;
6
7/// Findthe leaf node with the lowest cost by using Greedy Search
8///
9/// - `start` is the start node.
10/// - `successor_fn` returns a list of successors for a given node.
11/// - `eval_fn` returns the approximated cost of a given node to sort and select k-best
12/// - `cost_fn` returns the final cost of a leaf node
13/// - `leaf_check_fn` check if a node is leaf or not
14/// - `max_ops` is the maximum number of search operations to perform
15///
16/// This function returns Some of a tuple of (cost, leaf node) if found, otherwise returns None
17pub fn gds<N, IN, FN, FC1, FC2, C, FR>(
18    start: N,
19    successor_fn: FN,
20    eval_fn: FC1,
21    cost_fn: FC2,
22    leaf_check_fn: FR,
23    max_ops: usize,
24) -> Option<(C, N)>
25where
26    N: Clone,
27    IN: IntoIterator<Item = N>,
28    FN: FnMut(&N) -> IN,
29    FC1: Fn(&N) -> Option<C>,
30    FC2: Fn(&N) -> Option<C>,
31    C: Ord + Copy + Bounded,
32    FR: Fn(&N) -> bool,
33{
34    bms(
35        start,
36        successor_fn,
37        eval_fn,
38        usize::MAX,
39        1,
40        cost_fn,
41        leaf_check_fn,
42        max_ops,
43    )
44}
45
#[cfg(test)]
mod test {
    //! Tests for [`gds`] on a 13-city travelling-salesman style instance.

    use super::gds;

    type CityId = usize;
    type Duration = u32;

    /// A partial tour in the search tree.
    ///
    /// - `city`: the city the tour currently ends at.
    /// - `parents`: the cities visited before `city`, in visiting order.
    /// - `children`: the cities not yet visited.
    /// - `t`: the accumulated travel time from the start city up to `city`.
    #[derive(Debug, PartialEq, Eq, Hash, Clone)]
    struct Node {
        pub city: CityId,
        pub parents: Vec<CityId>,
        pub children: Vec<CityId>,
        pub t: Duration,
    }

    impl Node {
        /// Creates a node from its raw parts.
        pub fn new(city: CityId, parents: Vec<CityId>, children: Vec<CityId>, t: Duration) -> Self {
            Self {
                city,
                parents,
                children,
                t,
            }
        }

        /// Builds the node reached by travelling from `parent` to `city`.
        ///
        /// `parent.city` is appended to the visited list, `city` is removed from
        /// the unvisited list (the order of the remaining cities is kept), and
        /// the travel time `time_func(parent.city, city)` is added to `t`.
        ///
        /// # Panics
        /// Panics if `city` is not one of `parent.children`.
        pub fn from_parent(
            parent: &Self,
            city: CityId,
            time_func: &dyn Fn(CityId, CityId) -> Duration,
        ) -> Self {
            let mut parents = parent.parents.clone();
            parents.push(parent.city);

            let mut children = parent.children.clone();
            let idx = children
                .iter()
                .position(|&c| c == city)
                .expect("`city` must be an unvisited child of `parent`");
            // `remove` (not `swap_remove`) keeps the remaining cities in order,
            // so successor generation order stays deterministic.
            children.remove(idx);

            let t = parent.t + time_func(parent.city, city);

            Self {
                city,
                parents,
                children,
                t,
            }
        }

        /// A node is a leaf once every city has been visited.
        pub fn is_leaf(&self) -> bool {
            self.children.is_empty()
        }

        /// Returns one child node per unvisited city, in `children` order.
        pub fn generate_child_nodes(
            &self,
            time_func: &dyn Fn(CityId, CityId) -> Duration,
        ) -> Vec<Self> {
            self.children
                .iter()
                .map(|&city| Self::from_parent(self, city, time_func))
                .collect()
        }
    }

    #[test]
    fn test_gds() {
        let distance_matrix = [
            [
                0, 2451, 713, 1018, 1631, 1374, 2408, 213, 2571, 875, 1420, 2145, 1972,
            ],
            [
                2451, 0, 1745, 1524, 831, 1240, 959, 2596, 403, 1589, 1374, 357, 579,
            ],
            [
                713, 1745, 0, 355, 920, 803, 1737, 851, 1858, 262, 940, 1453, 1260,
            ],
            [
                1018, 1524, 355, 0, 700, 862, 1395, 1123, 1584, 466, 1056, 1280, 987,
            ],
            [
                1631, 831, 920, 700, 0, 663, 1021, 1769, 949, 796, 879, 586, 371,
            ],
            [
                1374, 1240, 803, 862, 663, 0, 1681, 1551, 1765, 547, 225, 887, 999,
            ],
            [
                2408, 959, 1737, 1395, 1021, 1681, 0, 2493, 678, 1724, 1891, 1114, 701,
            ],
            [
                213, 2596, 851, 1123, 1769, 1551, 2493, 0, 2699, 1038, 1605, 2300, 2099,
            ],
            [
                2571, 403, 1858, 1584, 949, 1765, 678, 2699, 0, 1744, 1645, 653, 600,
            ],
            [
                875, 1589, 262, 466, 796, 547, 1724, 1038, 1744, 0, 679, 1272, 1162,
            ],
            [
                1420, 1374, 940, 1056, 879, 225, 1891, 1605, 1645, 679, 0, 1017, 1200,
            ],
            [
                2145, 357, 1453, 1280, 586, 887, 1114, 2300, 653, 1272, 1017, 0, 504,
            ],
            [
                1972, 579, 1260, 987, 371, 999, 701, 2099, 600, 1162, 1200, 504, 0,
            ],
        ];

        let n_cities = distance_matrix.len();

        let start = 0;
        let root_node = Node::new(start, vec![], (1..n_cities).collect(), 0);
        let time_func = |p: CityId, c: CityId| distance_matrix[p][c];

        let successor_fn = |n: &Node| n.generate_child_nodes(&time_func);
        let eval_fn = |n: &Node| Some(n.t);

        // The final cost closes the tour by returning to the start city.
        let cost_fn = |n: &Node| Some(n.t + time_func(n.city, start));
        let leaf_check_fn = |n: &Node| n.is_leaf();

        let max_ops = usize::MAX;

        let (cost, best_node) = gds(
            root_node,
            successor_fn,
            eval_fn,
            cost_fn,
            leaf_check_fn,
            max_ops,
        )
        .unwrap();

        // Greedy search is heuristic, so only an upper bound is checked.
        assert!(cost < 9000);
        assert!(best_node.is_leaf());
        assert_eq!(cost, best_node.t + time_func(best_node.city, start));

        // The tour must visit every city exactly once.
        let mut visited_cities = best_node.parents;
        visited_cities.push(best_node.city);
        visited_cities.sort_unstable();
        let all_cities: Vec<CityId> = (0..n_cities).collect();
        assert_eq!(visited_cities, all_cities);
    }
}