1#[cfg(test)]
2#[path = "../../../tests/unit/models/problem/costs_test.rs"]
3mod costs_test;
4
5use crate::models::common::*;
6use crate::models::solution::{Activity, Route};
7use rosomaxa::prelude::{Float, GenericError, GenericResult};
8use rosomaxa::utils::CollectGroupBy;
9use std::collections::HashMap;
10use std::sync::Arc;
11
12#[derive(Copy, Clone)]
14pub enum TravelTime {
15 Arrival(Timestamp),
17 Departure(Timestamp),
19}
20
21pub trait ActivityCost: Send + Sync {
23 fn cost(&self, route: &Route, activity: &Activity, arrival: Timestamp) -> Cost {
25 let actor = route.actor.as_ref();
26
27 let waiting = if activity.place.time.start > arrival { activity.place.time.start - arrival } else { 0. };
28 let service = activity.place.duration;
29
30 waiting * (actor.driver.costs.per_waiting_time + actor.vehicle.costs.per_waiting_time)
31 + service * (actor.driver.costs.per_service_time + actor.vehicle.costs.per_service_time)
32 }
33
34 fn estimate_departure(&self, route: &Route, activity: &Activity, arrival: Timestamp) -> Timestamp;
36
37 fn estimate_arrival(&self, route: &Route, activity: &Activity, departure: Timestamp) -> Timestamp;
39}
40
41#[derive(Default)]
43pub struct SimpleActivityCost {}
44
45impl ActivityCost for SimpleActivityCost {
46 fn estimate_departure(&self, _: &Route, activity: &Activity, arrival: Timestamp) -> Timestamp {
47 arrival.max(activity.place.time.start) + activity.place.duration
48 }
49
50 fn estimate_arrival(&self, _: &Route, activity: &Activity, departure: Timestamp) -> Timestamp {
51 activity.place.time.end.min(departure - activity.place.duration)
52 }
53}
54
55pub trait TransportCost: Send + Sync {
57 fn cost(&self, route: &Route, from: Location, to: Location, travel_time: TravelTime) -> Cost {
59 let actor = route.actor.as_ref();
60
61 let distance = self.distance(route, from, to, travel_time);
62 let duration = self.duration(route, from, to, travel_time);
63
64 distance * (actor.driver.costs.per_distance + actor.vehicle.costs.per_distance)
65 + duration * (actor.driver.costs.per_driving_time + actor.vehicle.costs.per_driving_time)
66 }
67
68 fn duration_approx(&self, profile: &Profile, from: Location, to: Location) -> Duration;
70
71 fn distance_approx(&self, profile: &Profile, from: Location, to: Location) -> Distance;
73
74 fn duration(&self, route: &Route, from: Location, to: Location, travel_time: TravelTime) -> Duration;
76
77 fn distance(&self, route: &Route, from: Location, to: Location, travel_time: TravelTime) -> Distance;
79}
80
81pub struct SimpleTransportCost {
84 durations: Vec<Duration>,
85 distances: Vec<Distance>,
86 size: usize,
87}
88
89impl SimpleTransportCost {
90 pub fn new(durations: Vec<Duration>, distances: Vec<Distance>) -> GenericResult<Self> {
92 let size = (durations.len() as Float).sqrt().round() as usize;
93
94 if (distances.len() as Float).sqrt().round() as usize != size {
95 return Err("distance-duration lengths don't match".into());
96 }
97
98 Ok(Self { durations, distances, size })
99 }
100}
101
102impl TransportCost for SimpleTransportCost {
103 fn duration_approx(&self, _: &Profile, from: Location, to: Location) -> Duration {
104 self.durations.get(from * self.size + to).copied().unwrap_or(0.)
105 }
106
107 fn distance_approx(&self, _: &Profile, from: Location, to: Location) -> Distance {
108 self.distances.get(from * self.size + to).copied().unwrap_or(0.)
109 }
110
111 fn duration(&self, route: &Route, from: Location, to: Location, _: TravelTime) -> Duration {
112 self.duration_approx(&route.actor.vehicle.profile, from, to)
113 }
114
115 fn distance(&self, route: &Route, from: Location, to: Location, _: TravelTime) -> Distance {
116 self.distance_approx(&route.actor.vehicle.profile, from, to)
117 }
118}
119
120pub struct MatrixData {
122 pub index: usize,
124 pub timestamp: Option<Timestamp>,
126 pub durations: Vec<Duration>,
128 pub distances: Vec<Distance>,
130}
131
132impl MatrixData {
133 pub fn new(index: usize, timestamp: Option<Timestamp>, durations: Vec<Duration>, distances: Vec<Distance>) -> Self {
135 Self { index, timestamp, durations, distances }
136 }
137}
138
139pub trait TransportFallback: Send + Sync {
141 fn duration(&self, profile: &Profile, from: Location, to: Location) -> Duration;
143
144 fn distance(&self, profile: &Profile, from: Location, to: Location) -> Distance;
146}
147
148struct NoFallback;
150
151impl TransportFallback for NoFallback {
152 fn duration(&self, profile: &Profile, from: Location, to: Location) -> Duration {
153 panic!("cannot get duration for {from}->{to} for {profile:?}")
154 }
155
156 fn distance(&self, profile: &Profile, from: Location, to: Location) -> Distance {
157 panic!("cannot get distance for {from}->{to} for {profile:?}")
158 }
159}
160
161pub fn create_matrix_transport_cost(costs: Vec<MatrixData>) -> GenericResult<Arc<dyn TransportCost>> {
164 create_matrix_transport_cost_with_fallback(costs, NoFallback)
165}
166
167pub fn create_matrix_transport_cost_with_fallback<T: TransportFallback + 'static>(
170 costs: Vec<MatrixData>,
171 fallback: T,
172) -> GenericResult<Arc<dyn TransportCost>> {
173 if costs.is_empty() {
174 return Err("no matrix data found".into());
175 }
176
177 let size = (costs.first().unwrap().durations.len() as Float).sqrt().round() as usize;
178
179 if costs.iter().any(|matrix| matrix.distances.len() != matrix.durations.len()) {
180 return Err("distance and duration collections have different length".into());
181 }
182
183 if costs.iter().any(|matrix| (matrix.distances.len() as Float).sqrt().round() as usize != size) {
184 return Err("distance lengths don't match".into());
185 }
186
187 if costs.iter().any(|matrix| (matrix.durations.len() as Float).sqrt().round() as usize != size) {
188 return Err("duration lengths don't match".into());
189 }
190
191 Ok(if costs.iter().any(|costs| costs.timestamp.is_some()) {
192 Arc::new(TimeAwareMatrixTransportCost::new(costs, size, fallback)?)
193 } else {
194 Arc::new(TimeAgnosticMatrixTransportCost::new(costs, size, fallback)?)
195 })
196}
197
198struct TimeAgnosticMatrixTransportCost<T: TransportFallback> {
200 durations: Vec<Vec<Duration>>,
201 distances: Vec<Vec<Distance>>,
202 size: usize,
203 fallback: T,
204}
205
206impl<T: TransportFallback> TimeAgnosticMatrixTransportCost<T> {
207 pub fn new(costs: Vec<MatrixData>, size: usize, fallback: T) -> Result<Self, GenericError> {
209 let mut costs = costs;
210 costs.sort_by(|a, b| a.index.cmp(&b.index));
211
212 if costs.iter().any(|costs| costs.timestamp.is_some()) {
213 return Err("time aware routing".into());
214 }
215
216 if (0..).zip(costs.iter().map(|c| &c.index)).any(|(a, &b)| a != b) {
217 return Err("duplicate profiles can be passed only for time aware routing".into());
218 }
219
220 let (durations, distances) = costs.into_iter().fold((vec![], vec![]), |mut acc, data| {
221 acc.0.push(data.durations);
222 acc.1.push(data.distances);
223
224 acc
225 });
226
227 Ok(Self { durations, distances, size, fallback })
228 }
229}
230
231impl<T: TransportFallback> TransportCost for TimeAgnosticMatrixTransportCost<T> {
232 fn duration_approx(&self, profile: &Profile, from: Location, to: Location) -> Duration {
233 self.durations
234 .get(profile.index)
235 .unwrap()
236 .get(from * self.size + to)
237 .copied()
238 .unwrap_or_else(|| self.fallback.duration(profile, from, to))
239 * profile.scale
240 }
241
242 fn distance_approx(&self, profile: &Profile, from: Location, to: Location) -> Distance {
243 self.distances
244 .get(profile.index)
245 .unwrap()
246 .get(from * self.size + to)
247 .copied()
248 .unwrap_or_else(|| self.fallback.distance(profile, from, to))
249 }
250
251 fn duration(&self, route: &Route, from: Location, to: Location, _: TravelTime) -> Duration {
252 self.duration_approx(&route.actor.vehicle.profile, from, to)
253 }
254
255 fn distance(&self, route: &Route, from: Location, to: Location, _: TravelTime) -> Distance {
256 self.distance_approx(&route.actor.vehicle.profile, from, to)
257 }
258}
259
260struct TimeAwareMatrixTransportCost<T: TransportFallback> {
262 costs: HashMap<usize, (Vec<u64>, Vec<MatrixData>)>,
263 size: usize,
264 fallback: T,
265}
266
267impl<T: TransportFallback> TimeAwareMatrixTransportCost<T> {
268 fn new(costs: Vec<MatrixData>, size: usize, fallback: T) -> Result<Self, GenericError> {
270 if costs.iter().any(|matrix| matrix.timestamp.is_none()) {
271 return Err("time-aware routing requires all matrices to have timestamp".into());
272 }
273
274 let costs = costs.into_iter().collect_group_by_key(|matrix| matrix.index);
275
276 if costs.iter().any(|(_, matrices)| matrices.len() == 1) {
277 return Err("should not use time aware matrix routing with single matrix".into());
278 }
279
280 let costs = costs
281 .into_iter()
282 .map(|(profile, mut matrices)| {
283 matrices.sort_by(|a, b| (a.timestamp.unwrap() as u64).cmp(&(b.timestamp.unwrap() as u64)));
284 let timestamps = matrices.iter().map(|matrix| matrix.timestamp.unwrap() as u64).collect();
285
286 (profile, (timestamps, matrices))
287 })
288 .collect();
289
290 Ok(Self { costs, size, fallback })
291 }
292
293 fn interpolate_duration(
294 &self,
295 profile: &Profile,
296 from: Location,
297 to: Location,
298 travel_time: TravelTime,
299 ) -> Duration {
300 let timestamp = match travel_time {
301 TravelTime::Arrival(arrival) => arrival,
302 TravelTime::Departure(departure) => departure,
303 };
304
305 let (timestamps, matrices) = self.costs.get(&profile.index).unwrap();
306 let data_idx = from * self.size + to;
307
308 let duration = match timestamps.binary_search(&(timestamp as u64)) {
309 Ok(matrix_idx) => matrices.get(matrix_idx).unwrap().durations.get(data_idx).copied(),
310 Err(0) => matrices.first().unwrap().durations.get(data_idx).copied(),
311 Err(matrix_idx) if matrix_idx == matrices.len() => {
312 matrices.last().unwrap().durations.get(data_idx).copied()
313 }
314 Err(matrix_idx) => {
315 let left_matrix = matrices.get(matrix_idx - 1).unwrap();
316 let right_matrix = matrices.get(matrix_idx).unwrap();
317
318 matrices
319 .get(matrix_idx - 1)
320 .unwrap()
321 .durations
322 .get(data_idx)
323 .zip(matrices.get(matrix_idx).unwrap().durations.get(data_idx))
324 .map(|(&left_value, &right_value)| {
325 let ratio = (timestamp - left_matrix.timestamp.unwrap())
327 / (right_matrix.timestamp.unwrap() - left_matrix.timestamp.unwrap());
328
329 left_value + ratio * (right_value - left_value)
330 })
331 }
332 }
333 .unwrap_or_else(|| self.fallback.duration(profile, from, to));
334
335 duration * profile.scale
336 }
337
338 fn interpolate_distance(
339 &self,
340 profile: &Profile,
341 from: Location,
342 to: Location,
343 travel_time: TravelTime,
344 ) -> Distance {
345 let timestamp = match travel_time {
346 TravelTime::Arrival(arrival) => arrival,
347 TravelTime::Departure(departure) => departure,
348 };
349
350 let (timestamps, matrices) = self.costs.get(&profile.index).unwrap();
351 let data_idx = from * self.size + to;
352
353 match timestamps.binary_search(&(timestamp as u64)) {
354 Ok(matrix_idx) => matrices.get(matrix_idx).unwrap().distances.get(data_idx),
355 Err(0) => matrices.first().unwrap().distances.get(data_idx),
356 Err(matrix_idx) if matrix_idx == matrices.len() => matrices.last().unwrap().distances.get(data_idx),
357 Err(matrix_idx) => matrices.get(matrix_idx - 1).unwrap().distances.get(data_idx),
358 }
359 .copied()
360 .unwrap_or_else(|| self.fallback.distance(profile, from, to))
361 }
362}
363
364impl<T: TransportFallback> TransportCost for TimeAwareMatrixTransportCost<T> {
365 fn duration_approx(&self, profile: &Profile, from: Location, to: Location) -> Duration {
366 self.interpolate_duration(profile, from, to, TravelTime::Departure(0.))
367 }
368
369 fn distance_approx(&self, profile: &Profile, from: Location, to: Location) -> Distance {
370 self.interpolate_distance(profile, from, to, TravelTime::Departure(0.))
371 }
372
373 fn duration(&self, route: &Route, from: Location, to: Location, travel_time: TravelTime) -> Duration {
374 self.interpolate_duration(&route.actor.vehicle.profile, from, to, travel_time)
375 }
376
377 fn distance(&self, route: &Route, from: Location, to: Location, travel_time: TravelTime) -> Distance {
378 self.interpolate_distance(&route.actor.vehicle.profile, from, to, travel_time)
379 }
380}