rustsim_pathfinding/
route_choice.rs1use rand::Rng;
31
32pub fn mnl_select<R: Rng>(costs: &[f64], theta: f64, rng: &mut R) -> usize {
52 if costs.len() <= 1 {
53 return 0;
54 }
55
56 let max_util = costs
59 .iter()
60 .copied()
61 .map(|c| -theta * c)
62 .fold(f64::NEG_INFINITY, f64::max);
63
64 let exp_utils: Vec<f64> = costs
65 .iter()
66 .map(|&c| (-theta * c - max_util).exp())
67 .collect();
68 let sum: f64 = exp_utils.iter().sum();
69
70 if sum <= 0.0 || !sum.is_finite() {
71 return 0; }
73
74 let r: f64 = rng.gen::<f64>() * sum;
76 let mut cumulative = 0.0;
77 for (i, &u) in exp_utils.iter().enumerate() {
78 cumulative += u;
79 if r <= cumulative {
80 return i;
81 }
82 }
83 costs.len() - 1
84}
85
/// Returns the multinomial-logit (MNL) choice probability of every alternative.
///
/// `probs[i] = exp(-theta * costs[i]) / Σ_j exp(-theta * costs[j])`, computed with a
/// max-shift for numerical stability. The result has the same length as `costs`
/// and, when non-empty, sums to 1 up to rounding.
///
/// Alternatives whose utility `-theta * cost` is NaN (a NaN cost, or an infinite
/// cost with `theta == 0`) receive probability `0` instead of collapsing the whole
/// distribution.
///
/// # Returns
/// * An empty vector for empty `costs`; `[1.0]` for a single alternative.
/// * When no alternative has a usable weight (e.g. every cost is infinite), all mass
///   on the highest-utility alternative (the `theta -> inf` limit), or on index `0`
///   if every utility is NaN.
pub fn mnl_probabilities(costs: &[f64], theta: f64) -> Vec<f64> {
    match costs.len() {
        0 => return Vec::new(),
        1 => return vec![1.0],
        _ => {}
    }

    // `f64::max` returns the non-NaN operand, so NaN utilities never become the shift.
    let max_util = costs
        .iter()
        .map(|&c| -theta * c)
        .fold(f64::NEG_INFINITY, f64::max);

    // Unnormalized weights in [0, 1]. A NaN weight means "no usable weight" and is
    // mapped to 0 so one bad alternative cannot poison the normalizer.
    let mut probs: Vec<f64> = costs
        .iter()
        .map(|&c| {
            let w = (-theta * c - max_util).exp();
            if w.is_nan() { 0.0 } else { w }
        })
        .collect();
    let sum: f64 = probs.iter().sum();

    if sum <= 0.0 || !sum.is_finite() {
        // Degenerate choice set: a point mass on the highest-utility alternative.
        // Exact equality is intended because `max_util` is one of these very values.
        let winner = costs
            .iter()
            .position(|&c| -theta * c == max_util)
            .unwrap_or(0);
        probs.fill(0.0);
        probs[winner] = 1.0;
        return probs;
    }

    // Normalize in place; reuses the weight buffer instead of allocating again.
    probs.iter_mut().for_each(|p| *p /= sum);
    probs
}
122
/// Returns the MNL "logsum": the expected minimum perceived cost of a choice set.
///
/// Computes `-(1/theta) * ln(Σ_i exp(-theta * costs[i]))` with a max-shift for
/// numerical stability. Mathematically the result never exceeds the cheapest cost
/// and does not increase when alternatives are added.
///
/// Alternatives with a NaN cost are ignored, so a NaN never leaks into the result
/// and the answer does not depend on where the NaN sits in `costs`.
///
/// # Returns
/// * `f64::INFINITY` for empty `costs` or a non-positive `theta`.
/// * The minimum non-NaN cost when the log-sum-exp cannot be formed (e.g. only
///   infinite costs). This is `f64::INFINITY` when every cost is NaN, matching the
///   empty case.
pub fn mnl_logsum(costs: &[f64], theta: f64) -> f64 {
    if costs.is_empty() || theta <= 0.0 {
        return f64::INFINITY;
    }

    // `f64::max` returns the non-NaN operand, so NaN costs cannot become the shift.
    let max_util = costs
        .iter()
        .map(|&c| -theta * c)
        .fold(f64::NEG_INFINITY, f64::max);

    // NaN terms (NaN cost, or `inf - inf` when the shift is infinite) carry no usable
    // weight and are skipped.
    let sum_exp: f64 = costs
        .iter()
        .map(|&c| (-theta * c - max_util).exp())
        .filter(|w| !w.is_nan())
        .sum();

    if sum_exp <= 0.0 || !sum_exp.is_finite() {
        // Degenerate set: the theta -> inf limit of the logsum is the minimum cost.
        // `f64::min` skips NaN, so the result is independent of NaN positions.
        return costs.iter().copied().fold(f64::INFINITY, f64::min);
    }

    -(max_util + sum_exp.ln()) / theta
}
151
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// Fixed-seed RNG so every sampling test is reproducible.
    fn seeded_rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    #[test]
    fn single_cost_returns_zero() {
        let mut rng = seeded_rng();
        let chosen = mnl_select(&[10.0], 1.0, &mut rng);
        assert_eq!(chosen, 0);
    }

    #[test]
    fn empty_costs_returns_zero() {
        let mut rng = seeded_rng();
        let chosen = mnl_select(&[], 1.0, &mut rng);
        assert_eq!(chosen, 0);
    }

    #[test]
    fn high_theta_favors_cheapest() {
        let route_costs = [10.0, 20.0, 30.0];
        let mut rng = seeded_rng();
        // A very sharp theta should make the cheapest route win almost every draw.
        let count_cheapest = (0..1000)
            .filter(|_| mnl_select(&route_costs, 100.0, &mut rng) == 0)
            .count();
        assert!(
            count_cheapest > 990,
            "Expected nearly all cheapest, got {count_cheapest}/1000"
        );
    }

    #[test]
    fn low_theta_spreads_selection() {
        let route_costs = [10.0, 10.5, 11.0];
        let trials = 3000;
        let mut rng = seeded_rng();
        let mut counts = [0usize; 3];
        for _ in 0..trials {
            let pick = mnl_select(&route_costs, 0.1, &mut rng);
            counts[pick] += 1;
        }
        // With near-identical costs and a flat theta, every route must get a real share.
        let floor = trials / 10;
        assert!(
            counts.iter().all(|&c| c > floor),
            "Expected spread selection, got counts {:?}",
            counts
        );
    }

    #[test]
    fn probabilities_sum_to_one() {
        let probs = mnl_probabilities(&[10.0, 12.0, 15.0], 1.0);
        assert_eq!(probs.len(), 3);
        let total = probs.iter().sum::<f64>();
        assert!((total - 1.0).abs() < 1e-10);
    }

    #[test]
    fn probabilities_decrease_with_cost() {
        let probs = mnl_probabilities(&[5.0, 10.0, 20.0], 1.0);
        // Costs are strictly increasing, so each probability must beat the next one.
        assert!(probs.windows(2).all(|pair| pair[0] > pair[1]));
    }

    #[test]
    fn logsum_bounded_by_min_cost() {
        let ls = mnl_logsum(&[10.0, 15.0, 20.0], 1.0);
        assert!(ls <= 10.0 + 1e-6, "logsum={ls} should be <= 10.0");
    }

    #[test]
    fn logsum_decreases_with_more_alternatives() {
        let ls_2 = mnl_logsum(&[10.0, 15.0], 1.0);
        let ls_3 = mnl_logsum(&[10.0, 15.0, 12.0], 1.0);
        assert!(
            ls_3 <= ls_2 + 1e-6,
            "More alternatives should decrease logsum: {ls_3} vs {ls_2}"
        );
    }
}