1use std::{
2 cmp::max,
3 sync::{atomic::AtomicI32, Arc},
4};
5
6use itertools::Itertools;
7use log::debug;
8
9use crate::{
10 dinic::Dinic,
11 edge::{InputEdge, TrivialEdge},
12 geometry::primitives::FPCoordinate,
13 max_flow::{MaxFlow, ResidualEdgeData},
14 renumbering_table::RenumberingTable,
15};
16
/// Sort-key projections used to order nodes along four directions.
///
/// Each entry receives `(lat, lon)` of a node and returns its key:
/// index 0 uses the latitude, 1 the longitude, 2 the sum `lat + lon`
/// and 3 the difference `lat - lon` (computed as `lat + (-lon)`).
const ROTATED_COMPARATORS: [fn(i32, i32) -> i32; 4] = [
    |lat, _lon| lat,
    |_lat, lon| lon,
    |lat, lon| lat + lon,
    |lat, lon| lat + -lon,
];
23
/// Errors that can occur while computing an inertial-flow sub step.
#[derive(Debug)]
pub enum FlowError {
    /// The requested sorting axis has no matching comparator.
    AxisOutOfBounds,
    /// Free-form error message, e.g. one reported by the max-flow solver.
    String(String),
}

impl std::fmt::Display for FlowError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            FlowError::AxisOutOfBounds => write!(f, "sorting axis out of bounds"),
            FlowError::String(message) => write!(f, "{message}"),
        }
    }
}

// Lets `FlowError` travel through `?` into `Box<dyn Error>` and similar.
impl std::error::Error for FlowError {}
29
/// Result of one inertial-flow sub step: the flow value together with the
/// two-way partition of node ids it induces.
#[derive(Debug, Clone, PartialEq)]
pub struct Flow {
    /// Maximum flow value computed by the solver for this split.
    pub flow: i32,
    /// Size of the smaller side divided by the total number of partitioned
    /// nodes; at most 0.5, which means a perfectly even split.
    pub balance: f64,
    /// Node ids assigned to the left side of the cut.
    pub left_ids: Vec<usize>,
    /// Node ids assigned to the right side of the cut.
    pub right_ids: Vec<usize>,
}
37
38pub fn flow_cmp(a: &Flow, b: &Flow) -> std::cmp::Ordering {
39 if a.flow == b.flow {
40 return b.balance.partial_cmp(&a.balance).unwrap();
43 }
44 a.flow.cmp(&b.flow)
45}
46
47pub fn sub_step(
58 input_edges: &[TrivialEdge],
59 node_id_list: &[usize],
60 coordinates: &[FPCoordinate],
61 axis: usize,
62 balance_factor: f64,
63 upper_bound: Arc<AtomicI32>,
64) -> Result<Flow, FlowError> {
65 debug_assert!(axis < 4);
66 debug_assert!(balance_factor > 0.);
67 debug_assert!(balance_factor < 0.5);
68 debug_assert!(coordinates.len() > 2);
69
70 if axis >= 4 {
71 return Err(FlowError::AxisOutOfBounds);
72 }
73
74 let comparator = ROTATED_COMPARATORS[axis];
75 debug!("[{axis}] sorting cooefficient: {:?}", comparator);
76 let mut node_id_list = node_id_list.to_vec();
78 node_id_list
79 .sort_unstable_by_key(|a| -> i32 { comparator(coordinates[*a].lat, coordinates[*a].lon) });
80
81 let size_of_contraction = max(1, (node_id_list.len() as f64 * balance_factor) as usize);
82 let sources = &node_id_list[0..size_of_contraction];
83 let targets = &node_id_list[node_id_list.len() - size_of_contraction..];
84
85 debug_assert!(!sources.is_empty());
86 debug_assert!(!targets.is_empty());
87
88 debug!("[{axis}] renumbering of inertial flow graph");
89 let mut renumbering_table =
91 RenumberingTable::new_with_size_hint(coordinates.len(), node_id_list.len());
92 for s in sources {
96 renumbering_table.set(*s, 0);
97 }
98 for t in targets {
99 renumbering_table.set(*t, 1);
100 }
101
102 let mut edges = input_edges
104 .iter()
105 .map(|edge| -> InputEdge<ResidualEdgeData> {
106 InputEdge::<ResidualEdgeData> {
107 source: edge.source,
108 target: edge.target,
109 data: ResidualEdgeData::new(1),
110 }
111 })
112 .collect_vec();
113 let mut current_id = 2;
114
115 for e in &mut edges {
116 if !renumbering_table.contains_key(e.source) {
118 renumbering_table.set(e.source, current_id);
119 current_id += 1;
120 }
121 if !renumbering_table.contains_key(e.target) {
122 renumbering_table.set(e.target, current_id);
123 current_id += 1;
124 }
125 e.source = renumbering_table.get(e.source);
126 e.target = renumbering_table.get(e.target);
127 }
128 debug!("[{axis}] instantiating min-cut solver, epsilon 0.25");
129
130 let edge_count_before = edges.len();
132 edges.retain(|edge| edge.source != edge.target);
133 debug!(
134 "[{axis}] eigenloop removal - edge count before {edge_count_before}, after {}",
135 edges.len()
136 );
137 edges.shrink_to_fit();
138
139 debug!("[{axis}] instantiating min-cut solver, epsilon {balance_factor}");
140 let mut max_flow_solver = Dinic::from_edge_list(edges, 0, 1);
141 debug!("[{axis}] instantiated min-cut solver");
142 max_flow_solver.run_with_upper_bound(upper_bound);
143
144 let max_flow = max_flow_solver.max_flow();
145
146 if let Err(message) = max_flow {
147 return Err(FlowError::String(message));
149 }
150 let flow = max_flow.expect("max flow computation did not run");
151
152 debug!("[{axis}] computed max flow: {flow}");
153 let intermediate_assignment = max_flow_solver
154 .assignment(0)
155 .expect("max flow computation did not run");
156
157 let (left_ids, right_ids): (Vec<_>, Vec<_>) = node_id_list
159 .into_iter()
160 .filter(|id| renumbering_table.contains_key(*id))
161 .partition(|id| intermediate_assignment[renumbering_table.get(*id)]);
162
163 debug_assert!(!left_ids.is_empty());
164 debug_assert!(!right_ids.is_empty());
165
166 let balance = std::cmp::min(left_ids.len(), right_ids.len()) as f64
167 / (left_ids.len() + right_ids.len()) as f64;
168 debug!("[{axis}] balance: {balance}");
169
170 Ok(Flow {
171 flow,
172 balance,
173 left_ids,
174 right_ids,
175 })
176}
177
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use std::{
        cmp::Ordering,
        sync::{atomic::AtomicI32, Arc},
    };

    use crate::{
        geometry::primitives::FPCoordinate,
        inertial_flow::{flow_cmp, sub_step, Flow, TrivialEdge},
    };

    // Two triangles {0, 1, 2} and {3, 4, 5} joined by the single bridge
    // 2 <-> 4; every edge is listed in both directions.
    static EDGES: [TrivialEdge; 14] = [
        TrivialEdge {
            source: 0,
            target: 1,
        },
        TrivialEdge {
            source: 1,
            target: 0,
        },
        TrivialEdge {
            source: 0,
            target: 2,
        },
        TrivialEdge {
            source: 2,
            target: 0,
        },
        TrivialEdge {
            source: 1,
            target: 2,
        },
        TrivialEdge {
            source: 2,
            target: 1,
        },
        TrivialEdge {
            source: 2,
            target: 4,
        },
        TrivialEdge {
            source: 4,
            target: 2,
        },
        TrivialEdge {
            source: 3,
            target: 5,
        },
        TrivialEdge {
            source: 5,
            target: 3,
        },
        TrivialEdge {
            source: 4,
            target: 3,
        },
        TrivialEdge {
            source: 3,
            target: 4,
        },
        TrivialEdge {
            source: 4,
            target: 5,
        },
        TrivialEdge {
            source: 5,
            target: 4,
        },
    ];

    static COORDINATES: [FPCoordinate; 6] = [
        FPCoordinate::new(1, 0),
        FPCoordinate::new(2, 1),
        FPCoordinate::new(0, 1),
        FPCoordinate::new(2, 2),
        FPCoordinate::new(0, 2),
        FPCoordinate::new(1, 3),
    ];
    static NODE_ID_LIST: [usize; 6] = [0, 1, 2, 3, 4, 5];

    #[test]
    fn inertial_flow() {
        let upper_bound = Arc::new(AtomicI32::new(6));
        let result = sub_step(&EDGES, &NODE_ID_LIST, &COORDINATES, 3, 0.25, upper_bound)
            .expect("error should not happen");
        assert_eq!(result.flow, 1);
        assert_eq!(result.balance, 0.5);
        assert_eq!(result.left_ids.len(), 3);
        assert_eq!(result.left_ids, vec![4, 5, 3]);
        assert_eq!(result.right_ids.len(), 3);
        assert_eq!(result.right_ids, vec![2, 0, 1]);
    }

    #[test]
    fn inertial_flow_all_indices() {
        let upper_bound = Arc::new(AtomicI32::new(6));
        let result = (0..4)
            .map(|axis| {
                sub_step(
                    &EDGES,
                    &NODE_ID_LIST,
                    &COORDINATES,
                    axis,
                    0.25,
                    upper_bound.clone(),
                )
            })
            .collect_vec();
        assert_eq!(result.len(), 4);

        for r in &result {
            // `as_ref()` already yields a `Copy` `Result<&Flow, &FlowError>`;
            // no clone needed.
            let r = r.as_ref().expect("error should not happen");
            assert_eq!(r.flow, 1);
            assert_eq!(r.balance, 0.5);
            assert_eq!(r.left_ids.len(), 3);
            assert_eq!(r.right_ids.len(), 3);
        }

        let min_max = result.into_iter().map(|r| r.unwrap()).minmax_by(flow_cmp);
        let (min, max) = min_max.into_option().expect("minmax failed");
        assert_eq!(
            min,
            Flow {
                flow: 1,
                balance: 0.5,
                left_ids: vec![2, 0, 1],
                right_ids: vec![4, 5, 3]
            }
        );
        assert_eq!(
            max,
            Flow {
                flow: 1,
                balance: 0.5,
                left_ids: vec![4, 5, 3],
                right_ids: vec![2, 0, 1]
            }
        );
    }

    #[test]
    fn flow_cmp_prefers_smaller_flow_then_better_balance() {
        let make = |flow: i32, balance: f64| Flow {
            flow,
            balance,
            left_ids: vec![],
            right_ids: vec![],
        };
        assert_eq!(flow_cmp(&make(1, 0.1), &make(2, 0.5)), Ordering::Less);
        assert_eq!(flow_cmp(&make(3, 0.5), &make(2, 0.1)), Ordering::Greater);
        assert_eq!(flow_cmp(&make(1, 0.5), &make(1, 0.1)), Ordering::Less);
        assert_eq!(flow_cmp(&make(1, 0.3), &make(1, 0.3)), Ordering::Equal);
    }
}