text_block_permutation_optimizer/algo.rs

use std::cmp;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;

use rapidfuzz::distance::levenshtein;
use rayon::prelude::*;
use tracing::info;

// Jump penalties (added on top of the Levenshtein distance):
// how strongly shuffling should be penalized.
const JUMP_STATIC_PENALTY: f64 = 2.0;
const JUMP_OFFSET_PENALTY: f64 = 0.2;

// Bigger blocks are slower to try but improve output quality.
const BLOCK_SIZES: &[usize] = &[1, 2, 3, 4, 6, 9, 12, 15];
// Base shift window; scaled up proportionally for larger blocks.
const BASE_BLOCK_SHIFT: usize = 30;
// Scaling factor: max_shift = BASE_BLOCK_SHIFT + block_size * SHIFT_PER_BLOCK_SIZE
const SHIFT_PER_BLOCK_SIZE: usize = 8;

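// Worked example (derived from the constants above): block_size = 1 gives
// max_shift = 30 + 1 * 8 = 38; block_size = 15 gives 30 + 15 * 8 = 150.
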
type IndexedBlock<'a> = (usize, &'a str);

#[derive(Debug, Clone)]
pub struct MatchingLoss {
    pub fuzzy: f64,
    pub jump_distance: f64,
    pub geo_distance: f64,
}

impl MatchingLoss {
    /// Total loss: the sum of all loss components.
    pub fn abs(&self) -> f64 {
        self.fuzzy + self.jump_distance + self.geo_distance
    }
}

/// Convert a non-negative f64 to u64, preserving ordering.
/// Positive IEEE 754 floats have the same ordering as their bit patterns.
#[inline]
fn f64_to_ord_u64(v: f64) -> u64 {
    debug_assert!(v >= 0.0);
    v.to_bits()
}

/// Inverse of `f64_to_ord_u64`.
#[inline]
fn ord_u64_to_f64(v: u64) -> f64 {
    f64::from_bits(v)
}

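// Illustration (not part of the algorithm): because non-negative floats order
// like their bit patterns, an AtomicU64 can serve as an atomic f64 minimum,
// which is how `shared_cutoff` is used in `optimize` below:
//   assert!(f64_to_ord_u64(1.5) < f64_to_ord_u64(2.0));
//   cutoff.fetch_min(f64_to_ord_u64(loss_val), Ordering::Relaxed);
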
/// Core function: Optimize the matching order of a list of strings against ground truth.
///
/// # Arguments
///
/// * `ground` - The ground truth string to match against
/// * `blocks` - Slice of strings to be matched/ordered
///
/// # Returns
///
/// Tuple of (`MatchingLoss`, Vec of (index, text) tuples representing the optimal order)
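///
/// # Example
///
/// A minimal sketch of the intended use (inputs are illustrative):
///
/// ```ignore
/// let ground = "one two three four";
/// let blocks = &["three four", "one two"];
/// let (loss, ordered) = optimize(ground, blocks);
/// // The optimizer should place "one two" (original index 1) first.
/// assert_eq!(ordered[0].0, 1);
/// ```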
pub fn optimize<'a>(ground: &str, blocks: &[&'a str]) -> (MatchingLoss, Vec<IndexedBlock<'a>>) {
    // normalize_whitespace already collapses runs of whitespace to single spaces.
    let ground_concat = normalize_whitespace(ground);

    // Precompute the BatchComparator for the ground truth (builds pattern match vector once).
    let scorer = levenshtein::BatchComparator::new(ground_concat.bytes());
    let ground_len = ground_concat.len();

    // Pre-compute total byte length for buffer allocation
    let total_block_bytes: usize =
        blocks.iter().map(|b| b.len()).sum::<usize>() + blocks.len().saturating_sub(1);
    let mut main_buf = Vec::<u8>::with_capacity(total_block_bytes);

    // Convert blocks slice to indexed tuples
    let mut matching: Vec<IndexedBlock<'a>> =
        blocks.iter().enumerate().map(|(a, b)| (a, *b)).collect();

    let mut optimum_loss = score_matching(&scorer, ground_len, &matching, &mut main_buf, None);
    info!("Start fuzzy matching with loss of {:?}", optimum_loss);

    let shift_count = AtomicU64::new(0);
    let start_time = Instant::now();
    let n = matching.len();

    for &block_size in BLOCK_SIZES.iter().rev() {
        let level_start = Instant::now();
        let mut level_improvements = 0u64;
        let max_shift = BASE_BLOCK_SHIFT + block_size * SHIFT_PER_BLOCK_SIZE;
        info!(
            "Optimizing on block level {} (max_shift={})",
            block_size, max_shift
        );

        loop {
            // Shared atomic cutoff: when any thread finds a better solution,
            // all other threads immediately use the tighter cutoff for pruning.
            let shared_cutoff = AtomicU64::new(f64_to_ord_u64(optimum_loss.abs()));
            let snapshot = matching.clone();
            let best = (0..n)
                .into_par_iter()
                .map_init(
                    || {
                        (
                            snapshot.clone(),
                            Vec::<u8>::with_capacity(total_block_bytes),
                        )
                    },
                    |(local_matching, buf), start_pos| {
                        let end_pos = cmp::min(start_pos + block_size, n);
                        let actual_bs = end_pos - start_pos;

                        if !is_consecutive_block(local_matching, start_pos, actual_bs) {
                            return None;
                        }

                        let lo = start_pos.saturating_sub(max_shift);
                        let hi = cmp::min(n - actual_bs + 1, start_pos + max_shift);

                        let mut best_local: Option<(f64, usize)> = None;
                        let mut local_count = 0u64;

                        for target in lo..hi {
                            if target == start_pos {
                                continue;
                            }
                            local_count += 1;

                            // Read the shared cutoff; it tightens as other threads find improvements.
                            let cutoff = ord_u64_to_f64(shared_cutoff.load(Ordering::Relaxed));

                            apply_shift(local_matching, start_pos, actual_bs, target);

                            let loss = score_matching(
                                &scorer,
                                ground_len,
                                local_matching,
                                buf,
                                Some(cutoff),
                            );

                            let loss_val = loss.abs();
                            if loss_val < cutoff {
                                best_local = Some((loss_val, target));
                                // Publish the tighter cutoff to all threads.
                                shared_cutoff
                                    .fetch_min(f64_to_ord_u64(loss_val), Ordering::Relaxed);
                            }

                            undo_shift(local_matching, start_pos, actual_bs, target);
                        }

                        shift_count.fetch_add(local_count, Ordering::Relaxed);
                        best_local.map(|(loss, target)| (loss, start_pos, target))
                    },
                )
                .flatten()
                // Losses here are finite and never NaN, so partial_cmp cannot fail.
                .min_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

            if let Some((_best_loss, best_start, best_target)) = best {
                let end_pos = cmp::min(best_start + block_size, n);
                let actual_bs = end_pos - best_start;
                apply_shift(&mut matching, best_start, actual_bs, best_target);
                let old = optimum_loss.abs();
                optimum_loss = score_matching(&scorer, ground_len, &matching, &mut main_buf, None);
                level_improvements += 1;
                info!(
                    "  bs={:02} loss={:.1} -> {:.1} | {} trials, {:.2?}",
                    block_size,
                    old,
                    optimum_loss.abs(),
                    shift_count.load(Ordering::Relaxed),
                    start_time.elapsed()
                );
            } else {
                break;
            }

            // Global trial budget: stop refining this level once the search has become too expensive.
            if shift_count.load(Ordering::Relaxed) > 4_000_000 {
                break;
            }
        }
        info!(
            "  level {} done: {} improvements, {:.2?} elapsed (total {:.2?})",
            block_size,
            level_improvements,
            level_start.elapsed(),
            start_time.elapsed()
        );
    }

    info!(
        "Returning optimum with {:?} after {} trials ({:.2?})",
        optimum_loss,
        shift_count.load(Ordering::Relaxed),
        start_time.elapsed()
    );
    (optimum_loss, matching)
}

/// Check that the original indices in a block form a consecutive ascending sequence.
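/// For example, original indices `[4, 5, 6]` are consecutive; `[4, 6, 7]` are not.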
#[inline]
fn is_consecutive_block(matching: &[IndexedBlock], start: usize, len: usize) -> bool {
    for i in 1..len {
        if matching[start + i].0 != matching[start + i - 1].0 + 1 {
            return false;
        }
    }
    true
}

/// Move the block `slice[src..src + len]` so that it starts at index `dst`, in place.
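///
/// For example, `apply_shift(&mut [0, 1, 2, 3, 4], 1, 2, 3)` turns the slice
/// into `[0, 3, 4, 1, 2]`: the block `[1, 2]` now starts at index 3.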
#[inline]
fn apply_shift<T>(slice: &mut [T], src: usize, len: usize, dst: usize) {
    if dst < src {
        slice[dst..src + len].rotate_right(len);
    } else {
        slice[src..dst + len].rotate_left(len);
    }
}

/// Undo a shift (the inverse of `apply_shift`).
#[inline]
fn undo_shift<T>(slice: &mut [T], src: usize, len: usize, dst: usize) {
    if dst < src {
        slice[dst..src + len].rotate_left(len);
    } else {
        slice[src..dst + len].rotate_right(len);
    }
}

/// Score a candidate matching against the ground truth.
fn score_matching(
    scorer: &levenshtein::BatchComparator<u8>,
    ground_len: usize,
    matching: &[IndexedBlock],
    buf: &mut Vec<u8>,
    best_known: Option<f64>,
) -> MatchingLoss {
    // Compute jump distance first (very cheap pre-filter).
    let jump_distance = calculate_jump_distance(matching);
    if let Some(best) = best_known {
        if jump_distance >= best {
            return MatchingLoss {
                fuzzy: f64::MAX,
                jump_distance,
                geo_distance: 0.0,
            };
        }
    }

    // Build concatenated bytes into the reusable buffer.
    buf.clear();
    for (i, (_, text)) in matching.iter().enumerate() {
        if i > 0 {
            buf.push(b' ');
        }
        buf.extend_from_slice(text.as_bytes());
    }

    // Use score_cutoff for early termination in the Levenshtein computation.
    let lev_dist = if let Some(best) = best_known {
        let max_fuzzy = (best - jump_distance).max(0.0) as usize;
        let args = levenshtein::Args::default()
            .score_hint(max_fuzzy)
            .score_cutoff(max_fuzzy);
        match scorer.distance_with_args(buf.iter().copied(), &args) {
            Some(d) => d as f64,
            None => {
                // Cutoff exceeded: return an upper bound so the candidate is discarded.
                return MatchingLoss {
                    fuzzy: (ground_len + buf.len()) as f64,
                    jump_distance,
                    geo_distance: 0.0,
                };
            }
        }
    } else {
        scorer.distance(buf.iter().copied()) as f64
    };

    MatchingLoss {
        fuzzy: lev_dist,
        jump_distance,
        geo_distance: 0.0,
    }
}

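// Cutoff example (values illustrative): with `best_known = Some(10.0)` and a
// jump distance of 4.0, `max_fuzzy` becomes 6, so the Levenshtein pass may
// stop as soon as the distance provably exceeds 6 and the candidate is discarded.
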
fn calculate_jump_distance(matching: &[IndexedBlock]) -> f64 {
    let mut j = 0.0;
    for i in 1..matching.len() {
        let left = matching[i - 1];
        let right = matching[i];
        let expected_right = left.0 + 1;
        // unsigned_abs avoids the narrowing i32 cast for very large indices.
        let distance = (expected_right as isize - right.0 as isize).unsigned_abs();
        if distance == 0 {
            continue;
        }
        j += JUMP_STATIC_PENALTY + (distance as f64) * JUMP_OFFSET_PENALTY;
    }
    j
}

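// Worked example (follows from the penalty constants): for the order
// [(1, _), (0, _)], the expected successor of block 1 is block 2 but block 0
// follows, a distance of 2, so the jump distance is 2.0 + 2 * 0.2 = 2.4.
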
fn normalize_whitespace(text: &str) -> String {
    text.split_whitespace().collect::<Vec<_>>().join(" ")
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_levenshtein_distance() {
        assert_eq!(levenshtein::distance("abc".chars(), "abc".chars()), 0);
        assert_eq!(levenshtein::distance("abc".chars(), "ab".chars()), 1);
        assert_eq!(levenshtein::distance("ab".chars(), "abc".chars()), 1);
        assert_eq!(levenshtein::distance("abc".chars(), "def".chars()), 3);
        assert_eq!(levenshtein::distance("".chars(), "abc".chars()), 3);
        assert_eq!(levenshtein::distance("abc".chars(), "".chars()), 3);
    }

    #[test]
    fn test_normalize_whitespace() {
        assert_eq!(normalize_whitespace("  hello   world  "), "hello world");
        assert_eq!(normalize_whitespace("a\nb\tc"), "a b c");
    }

    #[test]
    fn test_optimize_simple() {
        let ground = "hello world";
        let blocks = &["world", "hello"];
        let (loss, result) = optimize(ground, blocks);
        assert!(loss.abs() >= 0.0);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_apply_undo_shift() {
        let mut v = vec![0, 1, 2, 3, 4, 5, 6, 7];
        apply_shift(&mut v, 2, 2, 5);
        assert_eq!(v, vec![0, 1, 4, 5, 6, 2, 3, 7]);
        undo_shift(&mut v, 2, 2, 5);
        assert_eq!(v, vec![0, 1, 2, 3, 4, 5, 6, 7]);

        apply_shift(&mut v, 5, 2, 1);
        assert_eq!(v, vec![0, 5, 6, 1, 2, 3, 4, 7]);
        undo_shift(&mut v, 5, 2, 1);
        assert_eq!(v, vec![0, 1, 2, 3, 4, 5, 6, 7]);
    }

    #[test]
    fn test_f64_ord_roundtrip() {
        let vals = [0.0, 1.0, 100.5, 315.6, 1000.0];
        for &v in &vals {
            assert_eq!(v, ord_u64_to_f64(f64_to_ord_u64(v)));
        }
        assert!(f64_to_ord_u64(1.0) < f64_to_ord_u64(2.0));
        assert!(f64_to_ord_u64(100.0) < f64_to_ord_u64(315.6));
    }
}