1use std::cmp;
2use std::sync::atomic::{AtomicU64, Ordering};
3use std::time::Instant;
4
5use rapidfuzz::distance::levenshtein;
6use rayon::prelude::*;
7use tracing::info;
8
/// Tunable parameters for the block-reordering search in `optimize`.
#[derive(Debug, Clone)]
pub struct AlgoConfig {
    /// Flat penalty charged whenever two adjacent blocks in the candidate
    /// ordering are not consecutive in the original order.
    pub jump_static_penalty: f64,

    /// Additional penalty per index of distance between the expected
    /// successor and the actual successor block.
    pub jump_offset_penalty: f64,

    /// Run lengths of consecutive blocks the search tries to move; traversed
    /// in reverse, so with the default ascending list large runs move first.
    pub block_sizes: Vec<usize>,

    /// Divisor for the stride between candidate target positions:
    /// `step = ceil(block_size / block_step_factor)`.
    pub block_step_factor: usize,

    /// Minimum shift radius searched around a run's current position.
    pub base_block_shift: usize,

    /// Extra shift radius granted per unit of run length.
    pub shift_per_block_size: usize,
}

impl Default for AlgoConfig {
    fn default() -> Self {
        AlgoConfig {
            block_sizes: vec![1, 2, 3, 4, 6, 9, 12, 15],
            block_step_factor: 5,
            base_block_shift: 30,
            shift_per_block_size: 8,
            jump_static_penalty: 2.0,
            jump_offset_penalty: 0.2,
        }
    }
}
53
/// A text block paired with its position in the original input slice.
type IndexedBlock<'a> = (usize, &'a str);
55
/// Decomposed loss of a candidate ordering: fuzzy edit distance, reordering
/// (jump) penalty, and a geometric term (currently always zero).
#[derive(Debug, Clone)]
pub struct MatchingLoss {
    pub fuzzy: f64,
    pub jump_distance: f64,
    pub geo_distance: f64,
}

impl MatchingLoss {
    /// Total loss: the sum of all three components.
    pub fn abs(&self) -> f64 {
        [self.fuzzy, self.jump_distance, self.geo_distance]
            .into_iter()
            .sum()
    }
}
68
/// Reinterprets a non-negative `f64` as a `u64` whose integer ordering matches
/// the float ordering, so loss values can be tightened lock-free with
/// `AtomicU64::fetch_min`.
///
/// For same-sign, non-negative IEEE-754 floats the unsigned integer ordering
/// of the bit patterns equals the numeric ordering; the `debug_assert` guards
/// the sign restriction (negative values would compare inverted).
#[inline]
fn f64_to_ord_u64(v: f64) -> u64 {
    debug_assert!(v >= 0.0);
    v.to_bits()
}
76
/// Inverse of `f64_to_ord_u64`: recovers the `f64` from its raw bit pattern.
#[inline]
fn ord_u64_to_f64(v: u64) -> f64 {
    f64::from_bits(v)
}
81
/// Reorders `blocks` via coarse-to-fine parallel local search so that their
/// space-joined concatenation best matches `ground`.
///
/// Returns the final [`MatchingLoss`] together with the reordered blocks,
/// each paired with its original index.
pub fn optimize<'a>(
    ground: &str,
    blocks: &[&'a str],
    config: &AlgoConfig,
) -> (MatchingLoss, Vec<IndexedBlock<'a>>) {
    // Canonicalize whitespace in the ground text before comparing.
    let ground_normalized = normalize_whitespace(ground);
    let ground_words: Vec<&str> = ground_normalized.split_whitespace().collect();
    let ground_concat = ground_words.join(" ");

    // The batch comparator preprocesses the ground bytes once; it is then
    // reused for every candidate scoring below.
    let scorer = levenshtein::BatchComparator::new(ground_concat.bytes());
    let ground_len = ground_concat.len();

    // Capacity bound for the joined candidate: all block bytes plus one
    // separator space between each adjacent pair.
    let total_block_bytes: usize =
        blocks.iter().map(|b| b.len()).sum::<usize>() + blocks.len().saturating_sub(1);
    let mut main_buf = Vec::<u8>::with_capacity(total_block_bytes);

    // Start from the identity ordering; each entry carries its original index.
    let mut matching: Vec<IndexedBlock<'a>> =
        blocks.iter().enumerate().map(|(a, b)| (a, *b)).collect();

    let mut optimum_loss =
        score_matching(&scorer, ground_len, &matching, &mut main_buf, None, config);
    info!("Start fuzzy matching with loss of {:?}", optimum_loss);

    let shift_count = AtomicU64::new(0);
    let start_time = Instant::now();
    let n = matching.len();

    // Coarse-to-fine: block_sizes is traversed in reverse, so with the default
    // ascending list large runs are moved first and small ones refine later.
    for &block_size in config.block_sizes.iter().rev() {
        let level_start = Instant::now();
        let mut level_improvements = 0u64;
        // Larger runs are allowed to travel further per trial.
        let max_shift = config.base_block_shift + block_size * config.shift_per_block_size;
        let step_size: usize = block_size.div_ceil(config.block_step_factor);
        info!(
            "Optimizing on block level {} (max_shift={}, step_size={})",
            block_size, max_shift, step_size
        );

        // Repeat sweeps until one finds no improving move (or the budget runs out).
        loop {
            // Current best total loss, encoded as an ordered u64 so workers
            // can tighten it lock-free via fetch_min.
            let shared_cutoff = AtomicU64::new(f64_to_ord_u64(optimum_loss.abs()));
            let snapshot = matching.clone();

            // Each rayon worker gets its own mutable copy of the ordering and
            // a scratch buffer; it evaluates all shifts of the run starting at
            // its start_pos and reports its best candidate, if any.
            let best = (0..n)
                .into_par_iter()
                .map_init(
                    || {
                        (
                            snapshot.clone(),
                            Vec::<u8>::with_capacity(total_block_bytes),
                        )
                    },
                    |(local_matching, buf), start_pos| {
                        let end_pos = cmp::min(start_pos + block_size, n);
                        let actual_bs = end_pos - start_pos;

                        // Only move runs whose original indices are still
                        // consecutive; torn runs are handled at finer levels.
                        if !is_consecutive_block(local_matching, start_pos, actual_bs) {
                            return None;
                        }

                        // Candidate target window around the current position.
                        let lo = start_pos.saturating_sub(max_shift);
                        let hi = cmp::min(n - actual_bs + 1, start_pos + max_shift);

                        let mut best_local: Option<(f64, usize)> = None;
                        let mut local_count = 0u64;

                        for target in (lo..hi).step_by(step_size) {
                            if target == start_pos {
                                continue;
                            }
                            local_count += 1;

                            // Re-read the shared cutoff so improvements found
                            // by other workers prune this worker's scoring.
                            let cutoff = ord_u64_to_f64(shared_cutoff.load(Ordering::Relaxed));

                            // Trial move: shift, score against the cutoff, undo.
                            apply_shift(local_matching, start_pos, actual_bs, target);

                            let loss = score_matching(
                                &scorer,
                                ground_len,
                                local_matching,
                                buf,
                                Some(cutoff),
                                config,
                            );

                            let loss_val = loss.abs();
                            if loss_val < cutoff {
                                best_local = Some((loss_val, target));
                                shared_cutoff
                                    .fetch_min(f64_to_ord_u64(loss_val), Ordering::Relaxed);
                            }

                            undo_shift(local_matching, start_pos, actual_bs, target);
                        }

                        shift_count.fetch_add(local_count, Ordering::Relaxed);
                        best_local.map(|(loss, target)| (loss, start_pos, target))
                    },
                )
                .flatten()
                // Losses are finite non-negative sums, so partial_cmp cannot
                // see NaN here and the unwrap is safe.
                .min_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

            // Commit the single best move of this sweep to the real ordering
            // and re-score it exactly (no cutoff).
            if let Some((_best_loss, best_start, best_target)) = best {
                let end_pos = cmp::min(best_start + block_size, n);
                let actual_bs = end_pos - best_start;
                apply_shift(&mut matching, best_start, actual_bs, best_target);
                let old = optimum_loss.abs();
                optimum_loss =
                    score_matching(&scorer, ground_len, &matching, &mut main_buf, None, config);
                level_improvements += 1;
                info!(
                    "  bs={:02} loss={:.1} -> {:.1} | {} trials, {:.2?}",
                    block_size,
                    old,
                    optimum_loss.abs(),
                    shift_count.load(Ordering::Relaxed),
                    start_time.elapsed()
                );
            } else {
                // No sweep-wide improvement: this level has converged.
                break;
            }

            // Global trial budget: stop refining this level once the total
            // number of evaluated shifts becomes excessive.
            if shift_count.load(Ordering::Relaxed) > 4_000_000 {
                break;
            }
        }
        info!(
            "  level {} done: {} improvements, {:.2?} elapsed (total {:.2?})",
            block_size,
            level_improvements,
            level_start.elapsed(),
            start_time.elapsed()
        );
    }

    info!(
        "Returning optimum with {:?} after {} trials ({:.2?})",
        optimum_loss,
        shift_count.load(Ordering::Relaxed),
        start_time.elapsed()
    );
    (optimum_loss, matching)
}
241
/// Returns `true` when the `len` entries starting at `start` carry strictly
/// consecutive indices in their first tuple component (i.e. the run has not
/// been torn apart by earlier moves). Runs of length 0 or 1 are trivially
/// consecutive.
///
/// Generalized over the payload type `T`: only the index component is read,
/// so the helper is no longer coupled to `IndexedBlock` and can be tested in
/// isolation. Callers passing `&[IndexedBlock]` are unaffected.
///
/// Panics if `start + len` exceeds `matching.len()`.
#[inline]
fn is_consecutive_block<T>(matching: &[(usize, T)], start: usize, len: usize) -> bool {
    matching[start..start + len]
        .windows(2)
        .all(|pair| pair[1].0 == pair[0].0 + 1)
}
252
/// Moves the `len`-element run starting at `src` so that it begins at `dst`,
/// sliding the displaced elements over to fill the gap.
///
/// Implemented as a single slice rotation over the span between the two
/// positions; `undo_shift` with the same arguments restores the original.
#[inline]
fn apply_shift<T>(slice: &mut [T], src: usize, len: usize, dst: usize) {
    match dst.cmp(&src) {
        // Moving left: rotate the covered span right by the run length.
        cmp::Ordering::Less => slice[dst..src + len].rotate_right(len),
        // Moving right (or not at all): rotate the span left instead.
        cmp::Ordering::Equal | cmp::Ordering::Greater => {
            slice[src..dst + len].rotate_left(len)
        }
    }
}
262
/// Exact inverse of `apply_shift` for the same `(src, len, dst)` arguments:
/// rotates the identical span in the opposite direction.
#[inline]
fn undo_shift<T>(slice: &mut [T], src: usize, len: usize, dst: usize) {
    if src <= dst {
        // The forward move rotated left; rotate right to reverse it.
        slice[src..dst + len].rotate_right(len);
    } else {
        // The forward move rotated right; rotate left to reverse it.
        slice[dst..src + len].rotate_left(len);
    }
}
272
273fn score_matching(
275 scorer: &levenshtein::BatchComparator<u8>,
276 ground_len: usize,
277 matching: &[IndexedBlock],
278 buf: &mut Vec<u8>,
279 best_known: Option<f64>,
280 config: &AlgoConfig,
281) -> MatchingLoss {
282 let jump_distance = calculate_jump_distance(matching, config);
284 if let Some(best) = best_known {
285 if jump_distance >= best {
286 return MatchingLoss {
287 fuzzy: f64::MAX,
288 jump_distance,
289 geo_distance: 0.0,
290 };
291 }
292 }
293
294 buf.clear();
296 for (i, (_, text)) in matching.iter().enumerate() {
297 if i > 0 {
298 buf.push(b' ');
299 }
300 buf.extend_from_slice(text.as_bytes());
301 }
302
303 let lev_dist = if let Some(best) = best_known {
305 let max_fuzzy = (best - jump_distance).max(0.0) as usize;
306 let args = levenshtein::Args::default()
307 .score_hint(max_fuzzy)
308 .score_cutoff(max_fuzzy);
309 match scorer.distance_with_args(buf.iter().copied(), &args) {
310 Some(d) => d as f64,
311 None => {
312 return MatchingLoss {
313 fuzzy: (ground_len + buf.len()) as f64,
314 jump_distance,
315 geo_distance: 0.0,
316 };
317 }
318 }
319 } else {
320 scorer.distance(buf.iter().copied()) as f64
321 };
322
323 MatchingLoss {
324 fuzzy: lev_dist,
325 jump_distance,
326 geo_distance: 0.0,
327 }
328}
329
330fn calculate_jump_distance(matching: &[IndexedBlock], config: &AlgoConfig) -> f64 {
331 let mut j = 0.0;
332 for i in 1..matching.len() {
333 let left = matching[i - 1];
334 let right = matching[i];
335 let expected_right = left.0 + 1;
336 let distance = (expected_right as i32 - right.0 as i32).abs() as usize;
337 if distance == 0 {
338 continue;
339 }
340 j += config.jump_static_penalty + (distance as f64) * config.jump_offset_penalty;
341 }
342 j
343}
344
/// Collapses every run of whitespace to a single ASCII space and trims both
/// ends, returning the canonicalized copy.
fn normalize_whitespace(text: &str) -> String {
    // Preallocate: the result is never longer than the input.
    let mut out = String::with_capacity(text.len());
    for word in text.split_whitespace() {
        if !out.is_empty() {
            out.push(' ');
        }
        out.push_str(word);
    }
    out
}
348
#[cfg(test)]
mod tests {
    use super::*;

    // Sanity-checks the rapidfuzz Levenshtein primitive the scorer builds on.
    #[test]
    fn test_levenshtein_distance() {
        assert_eq!(levenshtein::distance("abc".chars(), "abc".chars()), 0);
        assert_eq!(levenshtein::distance("abc".chars(), "ab".chars()), 1);
        assert_eq!(levenshtein::distance("ab".chars(), "abc".chars()), 1);
        assert_eq!(levenshtein::distance("abc".chars(), "def".chars()), 3);
        assert_eq!(levenshtein::distance("".chars(), "abc".chars()), 3);
        assert_eq!(levenshtein::distance("abc".chars(), "".chars()), 3);
    }

    // Whitespace runs collapse to single spaces; both ends are trimmed.
    #[test]
    fn test_normalize_whitespace() {
        assert_eq!(normalize_whitespace(" hello world "), "hello world");
        assert_eq!(normalize_whitespace("a\nb\tc"), "a b c");
    }

    // End-to-end smoke test on a tiny input.
    #[test]
    fn test_optimize_simple() {
        let ground = "hello world";
        let blocks = &["world", "hello"];
        let (loss, result) = optimize(ground, blocks, &AlgoConfig::default());
        // NOTE(review): `abs() >= 0.0` is trivially true for non-NaN values;
        // this mainly asserts that optimize terminates without panicking.
        assert!(loss.abs() >= 0.0);
        assert_eq!(result.len(), 2);
    }

    // apply_shift moves a run in both directions; undo_shift restores exactly.
    #[test]
    fn test_apply_undo_shift() {
        let mut v = vec![0, 1, 2, 3, 4, 5, 6, 7];
        apply_shift(&mut v, 2, 2, 5);
        assert_eq!(v, vec![0, 1, 4, 5, 6, 2, 3, 7]);
        undo_shift(&mut v, 2, 2, 5);
        assert_eq!(v, vec![0, 1, 2, 3, 4, 5, 6, 7]);

        apply_shift(&mut v, 5, 2, 1);
        assert_eq!(v, vec![0, 5, 6, 1, 2, 3, 4, 7]);
        undo_shift(&mut v, 5, 2, 1);
        assert_eq!(v, vec![0, 1, 2, 3, 4, 5, 6, 7]);
    }

    // Bit-pattern encoding round-trips and preserves ordering for the
    // non-negative floats it is used with.
    #[test]
    fn test_f64_ord_roundtrip() {
        let vals = [0.0, 1.0, 100.5, 315.6, 1000.0];
        for &v in &vals {
            assert_eq!(v, ord_u64_to_f64(f64_to_ord_u64(v)));
        }
        assert!(f64_to_ord_u64(1.0) < f64_to_ord_u64(2.0));
        assert!(f64_to_ord_u64(100.0) < f64_to_ord_u64(315.6));
    }
}