//! Fuzzy-match block-permutation optimizer (`algo.rs`): reorders text blocks
//! so their concatenation best matches a ground-truth string, trading off
//! Levenshtein distance against out-of-order jump penalties.

use std::cmp;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;

use rapidfuzz::distance::levenshtein;
use rayon::prelude::*;
use tracing::info;

/// Configuration for the fuzzy-match block-permutation optimizer.
#[derive(Debug, Clone)]
pub struct AlgoConfig {
    /// Fixed penalty added for every out-of-order adjacency between blocks,
    /// regardless of how far apart the blocks originally were.
    pub jump_static_penalty: f64,

    /// Per-unit penalty proportional to the index distance of an out-of-order
    /// adjacency. The total jump cost for one gap is
    /// `jump_static_penalty + distance * jump_offset_penalty`.
    pub jump_offset_penalty: f64,

    /// Set of block sizes to try, from coarse to fine.
    /// Larger blocks capture structural moves; smaller blocks refine detail.
    /// Processed in descending order internally.
    pub block_sizes: Vec<usize>,

    /// Divisor that controls candidate step size for each block level:
    /// `step_size = block_size.div_ceil(block_step_factor)`.
    /// Higher values yield a denser (slower, more thorough) search.
    pub block_step_factor: usize,

    /// Minimum shift window radius (in positions) around a block's current
    /// location. Limits how far a single move can relocate a block.
    pub base_block_shift: usize,

    /// The shift window grows linearly with block size:
    /// `max_shift = base_block_shift + block_size * shift_per_block_size`.
    /// Larger blocks may need to travel farther to find their correct position.
    pub shift_per_block_size: usize,
}

impl Default for AlgoConfig {
    /// General-purpose defaults: moderate jump penalties, block levels from
    /// single blocks up to runs of 15, and a 30-position base shift window.
    fn default() -> Self {
        let block_sizes = vec![1, 2, 3, 4, 6, 9, 12, 15];
        AlgoConfig {
            jump_static_penalty: 2.0,
            jump_offset_penalty: 0.2,
            block_sizes,
            block_step_factor: 5,
            base_block_shift: 30,
            shift_per_block_size: 8,
        }
    }
}

/// A block of text paired with its original position in the input slice.
type IndexedBlock<'a> = (usize, &'a str);

/// Decomposed loss for one candidate ordering; lower is better.
#[derive(Debug, Clone)]
pub struct MatchingLoss {
    /// Fuzzy (Levenshtein) distance component.
    pub fuzzy: f64,
    /// Accumulated penalty for out-of-order adjacencies.
    pub jump_distance: f64,
    /// Geometric component (set to zero by the current scorer).
    pub geo_distance: f64,
}

impl MatchingLoss {
    /// Total scalar loss: the three components summed left to right.
    pub fn abs(&self) -> f64 {
        [self.fuzzy, self.jump_distance, self.geo_distance]
            .into_iter()
            .sum()
    }
}

/// Convert a non-negative f64 to a u64 preserving ordering.
///
/// For non-negative IEEE 754 values the raw bit pattern increases
/// monotonically with the numeric value, so the bit patterns compare the
/// same way the floats do.
#[inline]
fn f64_to_ord_u64(value: f64) -> u64 {
    debug_assert!(value >= 0.0);
    f64::to_bits(value)
}

/// Inverse of `f64_to_ord_u64`: reinterpret the stored bits as an f64.
#[inline]
fn ord_u64_to_f64(bits: u64) -> f64 {
    f64::from_bits(bits)
}

/// Core function: Optimize the matching order of a list of strings against ground truth.
///
/// Performs a coarse-to-fine local search: for each block size in
/// `config.block_sizes` (largest first), every possible relocation of a
/// contiguous run of blocks within a bounded shift window is scored in
/// parallel; the single best improving move is applied and the level repeats
/// until no move improves the loss (or the global trial budget is hit).
///
/// # Arguments
///
/// * `ground` - The ground truth string to match against
/// * `blocks` - Slice of strings to be matched/ordered
/// * `config` - Parameters for the algorithm
///
/// # Returns
///
/// Tuple of (MatchingLoss, Vec of (index, text) tuples representing optimal order)
pub fn optimize<'a>(
    ground: &str,
    blocks: &[&'a str],
    config: &AlgoConfig,
) -> (MatchingLoss, Vec<IndexedBlock<'a>>) {
    // Collapse all whitespace in the ground truth to single spaces.
    // NOTE(review): the split/join below reproduces `ground_normalized`
    // verbatim (normalize_whitespace already joins with single spaces).
    let ground_normalized = normalize_whitespace(ground);
    let ground_words: Vec<&str> = ground_normalized.split_whitespace().collect();
    let ground_concat = ground_words.join(" ");

    // Precompute the BatchComparator for the ground truth (builds pattern match vector once).
    let scorer = levenshtein::BatchComparator::new(ground_concat.bytes());
    let ground_len = ground_concat.len();

    // Pre-compute total byte length for buffer allocation:
    // all block bytes plus one separating space between each adjacent pair.
    let total_block_bytes: usize =
        blocks.iter().map(|b| b.len()).sum::<usize>() + blocks.len().saturating_sub(1);
    let mut main_buf = Vec::<u8>::with_capacity(total_block_bytes);

    // Convert blocks slice to indexed tuples
    let mut matching: Vec<IndexedBlock<'a>> =
        blocks.iter().enumerate().map(|(a, b)| (a, *b)).collect();

    // Baseline loss of the initial (input) ordering; no cutoff on this scoring.
    let mut optimum_loss =
        score_matching(&scorer, ground_len, &matching, &mut main_buf, None, config);
    info!("Start fuzzy matching with loss of {:?}", optimum_loss);

    let shift_count = AtomicU64::new(0);
    let start_time = Instant::now();
    let n = matching.len();

    // `block_sizes` is configured ascending; iterate in reverse so the search
    // goes from coarse (large structural moves) to fine (small refinements).
    for &block_size in config.block_sizes.iter().rev() {
        let level_start = Instant::now();
        let mut level_improvements = 0u64;
        let max_shift = config.base_block_shift + block_size * config.shift_per_block_size;
        let step_size: usize = block_size.div_ceil(config.block_step_factor);
        info!(
            "Optimizing on block level {} (max_shift={}, step_size={})",
            block_size, max_shift, step_size
        );

        loop {
            // Shared atomic cutoff: when any thread finds a better solution,
            // all other threads immediately use the tighter cutoff for pruning.
            let shared_cutoff = AtomicU64::new(f64_to_ord_u64(optimum_loss.abs()));
            let snapshot = matching.clone();

            // Each rayon worker gets its own mutable copy of the current
            // ordering plus a reusable scratch buffer via `map_init`.
            let best = (0..n)
                .into_par_iter()
                .map_init(
                    || {
                        (
                            snapshot.clone(),
                            Vec::<u8>::with_capacity(total_block_bytes),
                        )
                    },
                    |(local_matching, buf), start_pos| {
                        let end_pos = cmp::min(start_pos + block_size, n);
                        let actual_bs = end_pos - start_pos;

                        // Only move runs whose original indexes are already
                        // consecutive; this keeps coherent chunks together.
                        if !is_consecutive_block(local_matching, start_pos, actual_bs) {
                            return None;
                        }

                        // Candidate target window around the current position;
                        // `hi` is an exclusive bound keeping the block in range.
                        let lo = start_pos.saturating_sub(max_shift);
                        let hi = cmp::min(n - actual_bs + 1, start_pos + max_shift);

                        let mut best_local: Option<(f64, usize)> = None;
                        let mut local_count = 0u64;

                        for target in (lo..hi).step_by(step_size) {
                            if target == start_pos {
                                continue;
                            }
                            local_count += 1;

                            // Read shared cutoff — tightens as other threads find improvements
                            let cutoff = ord_u64_to_f64(shared_cutoff.load(Ordering::Relaxed));

                            // Trial move, score, then restore the local copy.
                            apply_shift(local_matching, start_pos, actual_bs, target);

                            let loss = score_matching(
                                &scorer,
                                ground_len,
                                local_matching,
                                buf,
                                Some(cutoff),
                                config,
                            );

                            let loss_val = loss.abs();
                            if loss_val < cutoff {
                                best_local = Some((loss_val, target));
                                // Publish tighter cutoff to all threads
                                shared_cutoff
                                    .fetch_min(f64_to_ord_u64(loss_val), Ordering::Relaxed);
                            }

                            undo_shift(local_matching, start_pos, actual_bs, target);
                        }

                        shift_count.fetch_add(local_count, Ordering::Relaxed);
                        best_local.map(|(loss, target)| (loss, start_pos, target))
                    },
                )
                .flatten()
                // Losses here are finite (each beat a finite cutoff), so
                // partial_cmp cannot encounter NaN.
                .min_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

            if let Some((_best_loss, best_start, best_target)) = best {
                // Apply the single best move to the master ordering and
                // rescore it exactly (no cutoff).
                let end_pos = cmp::min(best_start + block_size, n);
                let actual_bs = end_pos - best_start;
                apply_shift(&mut matching, best_start, actual_bs, best_target);
                let old = optimum_loss.abs();
                optimum_loss =
                    score_matching(&scorer, ground_len, &matching, &mut main_buf, None, config);
                level_improvements += 1;
                info!(
                    "  bs={:02} loss={:.1} -> {:.1} | {} trials, {:.2?}",
                    block_size,
                    old,
                    optimum_loss.abs(),
                    shift_count.load(Ordering::Relaxed),
                    start_time.elapsed()
                );
            } else {
                // No improving move at this level — move on to the next size.
                break;
            }

            // Global trial budget: stop refining this level once exceeded.
            if shift_count.load(Ordering::Relaxed) > 4_000_000 {
                break;
            }
        }
        info!(
            "  level {} done: {} improvements, {:.2?} elapsed (total {:.2?})",
            block_size,
            level_improvements,
            level_start.elapsed(),
            start_time.elapsed()
        );
    }

    info!(
        "Returning optimum with {:?} after {} trials ({:.2?})",
        optimum_loss,
        shift_count.load(Ordering::Relaxed),
        start_time.elapsed()
    );
    (optimum_loss, matching)
}

/// Check that the original indexes in a block form a consecutive ascending sequence.
///
/// Empty and single-element spans are vacuously consecutive.
#[inline]
fn is_consecutive_block(matching: &[(usize, &str)], start: usize, len: usize) -> bool {
    matching[start..start + len]
        .windows(2)
        .all(|pair| pair[0].0 + 1 == pair[1].0)
}

/// Shift block at `[src..src+len]` to position `dst` in-place.
///
/// A single slice rotation over the span between old and new positions moves
/// the block while sliding the skipped elements past it.
#[inline]
fn apply_shift<T>(slice: &mut [T], src: usize, len: usize, dst: usize) {
    if src <= dst {
        // Moving right (or staying put): rotate the block toward the end.
        slice[src..dst + len].rotate_left(len);
    } else {
        // Moving left: rotate the block toward the front.
        slice[dst..src + len].rotate_right(len);
    }
}

/// Undo a shift (reverse of apply_shift)
///
/// Rotates the same span as `apply_shift` in the opposite direction, which
/// restores the original element order.
#[inline]
fn undo_shift<T>(slice: &mut [T], src: usize, len: usize, dst: usize) {
    if src <= dst {
        slice[src..dst + len].rotate_right(len);
    } else {
        slice[dst..src + len].rotate_left(len);
    }
}

/// Score a candidate matching against the ground truth.
///
/// The loss combines `jump_distance` (ordering penalties) with `fuzzy` (the
/// byte-level Levenshtein distance between the concatenated blocks and the
/// precomputed ground truth in `scorer`). `geo_distance` is always 0 here.
///
/// When `best_known` is given, two early exits avoid the expensive
/// Levenshtein computation: the jump distance alone already exceeding the
/// bound, and rapidfuzz's own `score_cutoff` pruning.
fn score_matching(
    scorer: &levenshtein::BatchComparator<u8>,
    ground_len: usize,
    matching: &[IndexedBlock],
    buf: &mut Vec<u8>,
    best_known: Option<f64>,
    config: &AlgoConfig,
) -> MatchingLoss {
    // Compute jump distance first (very cheap pre-filter).
    let jump_distance = calculate_jump_distance(matching, config);
    if let Some(best) = best_known {
        if jump_distance >= best {
            // f64::MAX sentinel guarantees this candidate loses every
            // `loss < cutoff` comparison in the caller.
            return MatchingLoss {
                fuzzy: f64::MAX,
                jump_distance,
                geo_distance: 0.0,
            };
        }
    }

    // Build concatenated bytes into reusable buffer
    // (blocks joined by single spaces, mirroring the ground-truth format).
    buf.clear();
    for (i, (_, text)) in matching.iter().enumerate() {
        if i > 0 {
            buf.push(b' ');
        }
        buf.extend_from_slice(text.as_bytes());
    }

    // Use score_cutoff for early termination in levenshtein computation
    let lev_dist = if let Some(best) = best_known {
        // Truncating to usize is safe: distances are integers, so any
        // distance strictly below `best - jump_distance` is <= the floor.
        let max_fuzzy = (best - jump_distance).max(0.0) as usize;
        let args = levenshtein::Args::default()
            .score_hint(max_fuzzy)
            .score_cutoff(max_fuzzy);
        match scorer.distance_with_args(buf.iter().copied(), &args) {
            Some(d) => d as f64,
            None => {
                // Cutoff exceeded: report an upper bound on the distance
                // (sum of both string lengths) so the candidate is rejected.
                return MatchingLoss {
                    fuzzy: (ground_len + buf.len()) as f64,
                    jump_distance,
                    geo_distance: 0.0,
                };
            }
        }
    } else {
        // No bound available: compute the exact distance.
        scorer.distance(buf.iter().copied()) as f64
    };

    MatchingLoss {
        fuzzy: lev_dist,
        jump_distance,
        geo_distance: 0.0,
    }
}

330fn calculate_jump_distance(matching: &[IndexedBlock], config: &AlgoConfig) -> f64 {
331    let mut j = 0.0;
332    for i in 1..matching.len() {
333        let left = matching[i - 1];
334        let right = matching[i];
335        let expected_right = left.0 + 1;
336        let distance = (expected_right as i32 - right.0 as i32).abs() as usize;
337        if distance == 0 {
338            continue;
339        }
340        j += config.jump_static_penalty + (distance as f64) * config.jump_offset_penalty;
341    }
342    j
343}
344
/// Collapse every run of whitespace to a single space and trim both ends.
fn normalize_whitespace(text: &str) -> String {
    let mut out = String::with_capacity(text.len());
    for word in text.split_whitespace() {
        if !out.is_empty() {
            out.push(' ');
        }
        out.push_str(word);
    }
    out
}

#[cfg(test)]
mod tests {
    use super::*;

    // Sanity-check the rapidfuzz levenshtein implementation on edge cases
    // (equal strings, insert/delete, full replacement, empty inputs).
    #[test]
    fn test_levenshtein_distance() {
        assert_eq!(levenshtein::distance("abc".chars(), "abc".chars()), 0);
        assert_eq!(levenshtein::distance("abc".chars(), "ab".chars()), 1);
        assert_eq!(levenshtein::distance("ab".chars(), "abc".chars()), 1);
        assert_eq!(levenshtein::distance("abc".chars(), "def".chars()), 3);
        assert_eq!(levenshtein::distance("".chars(), "abc".chars()), 3);
        assert_eq!(levenshtein::distance("abc".chars(), "".chars()), 3);
    }

    #[test]
    fn test_normalize_whitespace() {
        assert_eq!(normalize_whitespace("  hello   world  "), "hello world");
        assert_eq!(normalize_whitespace("a\nb\tc"), "a b c");
    }

    // End-to-end smoke test: the optimizer runs and preserves block count.
    #[test]
    fn test_optimize_simple() {
        let ground = "hello world";
        let blocks = &["world", "hello"];
        let (loss, result) = optimize(ground, blocks, &AlgoConfig::default());
        assert!(loss.abs() >= 0.0);
        assert_eq!(result.len(), 2);
    }

    // apply_shift followed by undo_shift must be the identity, for both
    // rightward and leftward moves.
    #[test]
    fn test_apply_undo_shift() {
        let mut v = vec![0, 1, 2, 3, 4, 5, 6, 7];
        apply_shift(&mut v, 2, 2, 5);
        assert_eq!(v, vec![0, 1, 4, 5, 6, 2, 3, 7]);
        undo_shift(&mut v, 2, 2, 5);
        assert_eq!(v, vec![0, 1, 2, 3, 4, 5, 6, 7]);

        apply_shift(&mut v, 5, 2, 1);
        assert_eq!(v, vec![0, 5, 6, 1, 2, 3, 4, 7]);
        undo_shift(&mut v, 5, 2, 1);
        assert_eq!(v, vec![0, 1, 2, 3, 4, 5, 6, 7]);
    }

    // Round-trip and order preservation of the f64 <-> u64 bit mapping used
    // by the shared atomic cutoff.
    #[test]
    fn test_f64_ord_roundtrip() {
        let vals = [0.0, 1.0, 100.5, 315.6, 1000.0];
        for &v in &vals {
            assert_eq!(v, ord_u64_to_f64(f64_to_ord_u64(v)));
        }
        assert!(f64_to_ord_u64(1.0) < f64_to_ord_u64(2.0));
        assert!(f64_to_ord_u64(100.0) < f64_to_ord_u64(315.6));
    }
}
401}