//! text_block_permutation_optimizer — algo.rs
//!
//! Reorders a list of text blocks so their concatenation best matches a
//! ground-truth string, balancing fuzzy (indel) distance against a penalty
//! for shuffling blocks out of their original order.
use std::cmp;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;

use rapidfuzz::distance::indel;
use rayon::prelude::*;
use tracing::info;
8
// Levenshtein-style (indel) distance weighting:
// how strong should shuffling be penalized.
// Static penalty applies once per out-of-order adjacency; offset penalty
// scales with how far the neighbor is from its expected original index.
const JUMP_STATIC_PENALTY: f64 = 2.0;
const JUMP_OFFSET_PENALTY: f64 = 0.2;

// Bigger blocks are slower, but enhance quality of output.
// Levels are processed largest-first (see `optimize`).
const BLOCK_SIZES: &[usize] = &[1, 2, 3, 4, 6, 9, 12, 15];
// Base shift window; scaled up proportionally for larger blocks
const BASE_BLOCK_SHIFT: usize = 30;
// Scaling factor: max_shift = BASE_BLOCK_SHIFT + block_size * SHIFT_PER_BLOCK_SIZE
const SHIFT_PER_BLOCK_SIZE: usize = 8;

// A text block paired with its original position in the input slice.
type IndexedBlock<'a> = (usize, &'a str);
22
/// Decomposed loss of one candidate ordering: fuzzy text distance,
/// reordering ("jump") penalty, and a geometric-distance component.
#[derive(Debug, Clone)]
pub struct MatchingLoss {
    pub fuzzy: f64,
    pub jump_distance: f64,
    pub geo_distance: f64,
}

impl MatchingLoss {
    /// Scalar total loss — the sum of all three components.
    pub fn abs(&self) -> f64 {
        let MatchingLoss {
            fuzzy,
            jump_distance,
            geo_distance,
        } = *self;
        fuzzy + jump_distance + geo_distance
    }
}
35
/// Map a non-negative f64 to a u64 whose integer ordering matches the
/// float ordering. Valid because non-negative IEEE 754 doubles compare
/// the same way as their raw bit patterns.
#[inline]
fn f64_to_ord_u64(v: f64) -> u64 {
    debug_assert!(v >= 0.0);
    f64::to_bits(v)
}
43
/// Inverse of [`f64_to_ord_u64`]: reinterpret the stored bits as an f64.
#[inline]
fn ord_u64_to_f64(v: u64) -> f64 {
    f64::from_bits(v)
}
48
/// Core function: Optimize the matching order of a list of strings against ground truth.
///
/// Runs a coarse-to-fine local search: for each block size (largest first) it
/// repeatedly evaluates, in parallel, moving every `block_size`-wide window of
/// the current ordering to nearby target positions, then applies the single
/// best improving move. A level ends when no move improves the loss or the
/// global trial budget is exhausted.
///
/// # Arguments
///
/// * `ground` - The ground truth string to match against
/// * `blocks` - Slice of strings to be matched/ordered
///
/// # Returns
///
/// Tuple of (MatchingLoss, Vec of (index, text) tuples representing optimal order)
pub fn optimize<'a>(ground: &str, blocks: &[&'a str]) -> (MatchingLoss, Vec<IndexedBlock<'a>>) {
    let ground_normalized = normalize_whitespace(ground);
    let ground_words: Vec<&str> = ground_normalized.split_whitespace().collect();
    let ground_concat = ground_words.join(" ");

    // Precompute the BatchComparator for the ground truth (builds pattern match vector once).
    let scorer = indel::BatchComparator::new(ground_concat.bytes());
    let ground_len = ground_concat.len();

    // Pre-compute total byte length (texts + separating spaces) for buffer allocation
    let total_block_bytes: usize =
        blocks.iter().map(|b| b.len()).sum::<usize>() + blocks.len().saturating_sub(1);
    let mut main_buf = Vec::<u8>::with_capacity(total_block_bytes);

    // Convert blocks slice to indexed tuples; the usize records each block's
    // original position so jump penalties can be computed after reordering.
    let mut matching: Vec<IndexedBlock<'a>> =
        blocks.iter().enumerate().map(|(a, b)| (a, *b)).collect();

    let mut optimum_loss = score_matching(&scorer, ground_len, &matching, &mut main_buf, None);
    info!("Start fuzzy matching with loss of {:?}", optimum_loss);

    // Global trial counter, shared across rayon workers; doubles as budget.
    let shift_count = AtomicU64::new(0);
    let start_time = Instant::now();
    let n = matching.len();

    // Largest block sizes first: big structural moves, then fine-grained ones.
    for &block_size in BLOCK_SIZES.iter().rev() {
        let level_start = Instant::now();
        let mut level_improvements = 0u64;
        let max_shift = BASE_BLOCK_SHIFT + block_size * SHIFT_PER_BLOCK_SIZE;
        info!(
            "Optimizing on block level {} (max_shift={})",
            block_size, max_shift
        );

        loop {
            // Shared atomic cutoff: when any thread finds a better solution,
            // all other threads immediately use the tighter cutoff for pruning.
            // Storing the f64 as order-preserving bits makes fetch_min valid.
            let shared_cutoff = AtomicU64::new(f64_to_ord_u64(optimum_loss.abs()));
            let snapshot = matching.clone();

            let best = (0..n)
                .into_par_iter()
                .map_init(
                    // Per-thread scratch: a mutable copy of the ordering and a
                    // reusable concatenation buffer (avoids per-trial allocs).
                    || {
                        (
                            snapshot.clone(),
                            Vec::<u8>::with_capacity(total_block_bytes),
                        )
                    },
                    |(local_matching, buf), start_pos| {
                        let end_pos = cmp::min(start_pos + block_size, n);
                        let actual_bs = end_pos - start_pos;
                        let lo = start_pos.saturating_sub(max_shift);
                        // `hi` is clamped so target + actual_bs never exceeds n,
                        // keeping apply_shift's slice indexing in bounds.
                        let hi = cmp::min(n - actual_bs + 1, start_pos + max_shift);

                        let mut best_local: Option<(f64, usize)> = None;
                        let mut local_count = 0u64;

                        for target in lo..hi {
                            if target == start_pos {
                                continue;
                            }
                            local_count += 1;

                            // Read shared cutoff — tightens as other threads find improvements
                            let cutoff = ord_u64_to_f64(shared_cutoff.load(Ordering::Relaxed));

                            apply_shift(local_matching, start_pos, actual_bs, target);

                            let loss = score_matching(
                                &scorer,
                                ground_len,
                                local_matching,
                                buf,
                                Some(cutoff),
                            );

                            let loss_val = loss.abs();
                            if loss_val < cutoff {
                                best_local = Some((loss_val, target));
                                // Publish tighter cutoff to all threads
                                shared_cutoff
                                    .fetch_min(f64_to_ord_u64(loss_val), Ordering::Relaxed);
                            }

                            // Restore the snapshot ordering before the next trial.
                            undo_shift(local_matching, start_pos, actual_bs, target);
                        }

                        shift_count.fetch_add(local_count, Ordering::Relaxed);
                        best_local.map(|(loss, target)| (loss, start_pos, target))
                    },
                )
                .flatten()
                // Losses are finite non-NaN here, so partial_cmp cannot fail.
                .min_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

            if let Some((_best_loss, best_start, best_target)) = best {
                // Apply only the single globally-best move, then rescore from
                // scratch (without cutoff) to get the exact new optimum.
                let end_pos = cmp::min(best_start + block_size, n);
                let actual_bs = end_pos - best_start;
                apply_shift(&mut matching, best_start, actual_bs, best_target);
                let old = optimum_loss.abs();
                optimum_loss = score_matching(&scorer, ground_len, &matching, &mut main_buf, None);
                level_improvements += 1;
                info!(
                    "  bs={:02} loss={:.1} -> {:.1} | {} trials, {:.2?}",
                    block_size,
                    old,
                    optimum_loss.abs(),
                    shift_count.load(Ordering::Relaxed),
                    start_time.elapsed()
                );
            } else {
                // No improving move at this block size — level converged.
                break;
            }

            // NOTE(review): this budget breaks only the current level's loop;
            // remaining (smaller) block sizes still run — confirm intentional.
            if shift_count.load(Ordering::Relaxed) > 4_000_000 {
                break;
            }
        }
        info!(
            "  level {} done: {} improvements, {:.2?} elapsed (total {:.2?})",
            block_size,
            level_improvements,
            level_start.elapsed(),
            start_time.elapsed()
        );
    }

    info!(
        "Returning optimum with {:?} after {} trials ({:.2?})",
        optimum_loss,
        shift_count.load(Ordering::Relaxed),
        start_time.elapsed()
    );
    (optimum_loss, matching)
}
194
/// Move the `len`-element block starting at `src` so that it begins at `dst`.
///
/// Implemented as one rotation of the affected span, so every element outside
/// the moved block keeps its relative order. Requires `dst + len <= slice.len()`
/// when moving right and `src + len <= slice.len()` when moving left.
#[inline]
fn apply_shift<T>(slice: &mut [T], src: usize, len: usize, dst: usize) {
    match dst.cmp(&src) {
        // Moving left: rotate [dst, src+len) right so the block lands at dst.
        cmp::Ordering::Less => slice[dst..src + len].rotate_right(len),
        // Moving right (or dst == src, where the rotation is a no-op).
        _ => slice[src..dst + len].rotate_left(len),
    }
}
204
/// Reverse a prior [`apply_shift`] with the same `(src, len, dst)` arguments,
/// restoring the slice to its pre-shift ordering via the opposite rotation.
#[inline]
fn undo_shift<T>(slice: &mut [T], src: usize, len: usize, dst: usize) {
    match dst.cmp(&src) {
        // apply_shift rotated this span right, so rotate it left to undo.
        cmp::Ordering::Less => slice[dst..src + len].rotate_left(len),
        _ => slice[src..dst + len].rotate_right(len),
    }
}
214
215/// Score a candidate matching against the ground truth.
216fn score_matching(
217    scorer: &indel::BatchComparator<u8>,
218    ground_len: usize,
219    matching: &[IndexedBlock],
220    buf: &mut Vec<u8>,
221    best_known: Option<f64>,
222) -> MatchingLoss {
223    // Compute jump distance first (very cheap pre-filter).
224    let jump_distance = calculate_jump_distance(matching);
225    if let Some(best) = best_known {
226        if jump_distance >= best {
227            return MatchingLoss {
228                fuzzy: f64::MAX,
229                jump_distance,
230                geo_distance: 0.0,
231            };
232        }
233    }
234
235    // Build concatenated bytes into reusable buffer
236    buf.clear();
237    for (i, (_, text)) in matching.iter().enumerate() {
238        if i > 0 {
239            buf.push(b' ');
240        }
241        buf.extend_from_slice(text.as_bytes());
242    }
243
244    // Use score_cutoff for early termination in indel computation
245    let indel_dist = if let Some(best) = best_known {
246        let max_fuzzy = (best - jump_distance).max(0.0) as usize;
247        let args = indel::Args::default()
248            .score_hint(max_fuzzy)
249            .score_cutoff(max_fuzzy);
250        match scorer.distance_with_args(buf.iter().copied(), &args) {
251            Some(d) => d as f64,
252            None => {
253                return MatchingLoss {
254                    fuzzy: (ground_len + buf.len()) as f64,
255                    jump_distance,
256                    geo_distance: 0.0,
257                };
258            }
259        }
260    } else {
261        scorer.distance(buf.iter().copied()) as f64
262    };
263
264    MatchingLoss {
265        fuzzy: indel_dist,
266        jump_distance,
267        geo_distance: 0.0,
268    }
269}
270
271fn calculate_jump_distance(matching: &[IndexedBlock]) -> f64 {
272    let mut j = 0.0;
273    for i in 1..matching.len() {
274        let left = matching[i - 1];
275        let right = matching[i];
276        let expected_right = left.0 + 1;
277        let distance = (expected_right as i32 - right.0 as i32).abs() as usize;
278        if distance == 0 {
279            continue;
280        }
281        j += JUMP_STATIC_PENALTY + (distance as f64) * JUMP_OFFSET_PENALTY;
282    }
283    j
284}
285
/// Collapse every run of whitespace to a single space and trim both ends.
fn normalize_whitespace(text: &str) -> String {
    let mut out = String::with_capacity(text.len());
    for word in text.split_whitespace() {
        if !out.is_empty() {
            out.push(' ');
        }
        out.push_str(word);
    }
    out
}
289
#[cfg(test)]
mod tests {
    use super::*;

    // Sanity-check the third-party indel metric: insertions/deletions cost 1
    // each, so replacing a char costs 2 and disjoint strings cost len_a+len_b.
    #[test]
    fn test_indel_distance() {
        assert_eq!(indel::distance("abc".chars(), "abc".chars()), 0);
        assert_eq!(indel::distance("abc".chars(), "ab".chars()), 1);
        assert_eq!(indel::distance("ab".chars(), "abc".chars()), 1);
        assert_eq!(indel::distance("abc".chars(), "def".chars()), 6);
        assert_eq!(indel::distance("".chars(), "abc".chars()), 3);
        assert_eq!(indel::distance("abc".chars(), "".chars()), 3);
    }

    // Whitespace runs collapse to single spaces; leading/trailing trimmed.
    #[test]
    fn test_normalize_whitespace() {
        assert_eq!(normalize_whitespace("  hello   world  "), "hello world");
        assert_eq!(normalize_whitespace("a\nb\tc"), "a b c");
    }

    // Smoke test: optimize runs end-to-end and returns all input blocks.
    #[test]
    fn test_optimize_simple() {
        let ground = "hello world";
        let blocks = &["world", "hello"];
        let (loss, result) = optimize(ground, blocks);
        assert!(loss.abs() >= 0.0);
        assert_eq!(result.len(), 2);
    }

    // apply_shift followed by undo_shift with identical args is the identity,
    // in both the move-right and move-left directions.
    #[test]
    fn test_apply_undo_shift() {
        let mut v = vec![0, 1, 2, 3, 4, 5, 6, 7];
        apply_shift(&mut v, 2, 2, 5);
        assert_eq!(v, vec![0, 1, 4, 5, 6, 2, 3, 7]);
        undo_shift(&mut v, 2, 2, 5);
        assert_eq!(v, vec![0, 1, 2, 3, 4, 5, 6, 7]);

        apply_shift(&mut v, 5, 2, 1);
        assert_eq!(v, vec![0, 5, 6, 1, 2, 3, 4, 7]);
        undo_shift(&mut v, 5, 2, 1);
        assert_eq!(v, vec![0, 1, 2, 3, 4, 5, 6, 7]);
    }

    // The bit-pattern trick round-trips and preserves ordering for the
    // non-negative floats used as loss values.
    #[test]
    fn test_f64_ord_roundtrip() {
        let vals = [0.0, 1.0, 100.5, 315.6, 1000.0];
        for &v in &vals {
            assert_eq!(v, ord_u64_to_f64(f64_to_ord_u64(v)));
        }
        assert!(f64_to_ord_u64(1.0) < f64_to_ord_u64(2.0));
        assert!(f64_to_ord_u64(100.0) < f64_to_ord_u64(315.6));
    }
}