1use std::cmp;
2use std::sync::atomic::{AtomicU64, Ordering};
3use std::time::Instant;
4
5use rapidfuzz::distance::levenshtein;
6use rayon::prelude::*;
7use tracing::info;
8
// Flat penalty charged for every discontinuity between adjacent blocks in a matching.
const JUMP_STATIC_PENALTY: f64 = 2.0;
// Additional penalty per position of offset at each discontinuity.
const JUMP_OFFSET_PENALTY: f64 = 0.2;

// Run lengths tried by the optimizer; iterated in reverse, so largest runs move first.
const BLOCK_SIZES: &[usize] = &[1, 2, 3, 4, 6, 9, 12, 15];
// Minimum distance (in positions) a run may be shifted, regardless of its size.
const BASE_BLOCK_SHIFT: usize = 30;
// Extra shift range granted per element in the moved run.
const SHIFT_PER_BLOCK_SIZE: usize = 8;

// A block's original position in the input paired with its text.
type IndexedBlock<'a> = (usize, &'a str);
22
/// Decomposed loss of a candidate matching: a fuzzy text-distance term,
/// a block-reordering penalty, and a (currently unused) geometric term.
#[derive(Debug, Clone)]
pub struct MatchingLoss {
    pub fuzzy: f64,
    pub jump_distance: f64,
    pub geo_distance: f64,
}

impl MatchingLoss {
    /// Scalar total loss: the sum of all three components.
    pub fn abs(&self) -> f64 {
        [self.fuzzy, self.jump_distance, self.geo_distance]
            .iter()
            .sum()
    }
}
35
/// Reinterprets a non-negative float as its IEEE-754 bit pattern. For
/// non-negative, non-NaN values the u64 ordering of bit patterns matches
/// the f64 ordering, which lets callers keep a running float minimum in an
/// `AtomicU64` via `fetch_min`. NaN trips the debug assertion (`NaN >= 0.0`
/// is false); negative values would break the ordering and are forbidden.
#[inline]
fn f64_to_ord_u64(v: f64) -> u64 {
    debug_assert!(v >= 0.0);
    v.to_bits()
}
43
/// Inverse of `f64_to_ord_u64`: recovers the float from its bit pattern.
#[inline]
fn ord_u64_to_f64(v: u64) -> f64 {
    f64::from_bits(v)
}
48
/// Reorders `blocks` to best match `ground`, minimizing Levenshtein distance
/// plus jump (reordering) penalties.
///
/// Strategy: coarse-to-fine local search. For each run length in
/// `BLOCK_SIZES` (largest first) it repeatedly scans, in parallel, every
/// consecutive run of that length and every nearby target position, and
/// applies the single best improving move per pass until no move improves
/// the loss.
///
/// Returns the final loss and the reordered `(original_index, text)` pairs.
pub fn optimize<'a>(ground: &str, blocks: &[&'a str]) -> (MatchingLoss, Vec<IndexedBlock<'a>>) {
    // Canonicalize the ground text so whitespace differences don't count
    // against the fuzzy score.
    let ground_normalized = normalize_whitespace(ground);
    let ground_words: Vec<&str> = ground_normalized.split_whitespace().collect();
    let ground_concat = ground_words.join(" ");

    // BatchComparator caches preprocessing of the ground string, making the
    // many repeated distance computations below cheaper.
    let scorer = levenshtein::BatchComparator::new(ground_concat.bytes());
    let ground_len = ground_concat.len();

    // Exact length of the space-joined candidate string: all block bytes
    // plus one separator between each adjacent pair.
    let total_block_bytes: usize =
        blocks.iter().map(|b| b.len()).sum::<usize>() + blocks.len().saturating_sub(1);
    let mut main_buf = Vec::<u8>::with_capacity(total_block_bytes);

    // Start from the identity matching: block i at position i.
    let mut matching: Vec<IndexedBlock<'a>> =
        blocks.iter().enumerate().map(|(a, b)| (a, *b)).collect();

    let mut optimum_loss = score_matching(&scorer, ground_len, &matching, &mut main_buf, None);
    info!("Start fuzzy matching with loss of {:?}", optimum_loss);

    let shift_count = AtomicU64::new(0);
    let start_time = Instant::now();
    let n = matching.len();

    // Coarse-to-fine: move the largest runs first, then refine.
    for &block_size in BLOCK_SIZES.iter().rev() {
        let level_start = Instant::now();
        let mut level_improvements = 0u64;
        // Larger runs are allowed to travel farther.
        let max_shift = BASE_BLOCK_SHIFT + block_size * SHIFT_PER_BLOCK_SIZE;
        info!(
            "Optimizing on block level {} (max_shift={})",
            block_size, max_shift
        );

        // One iteration = one full parallel scan; repeats until no move improves.
        loop {
            // Best loss seen so far, shared across workers as ordered bits
            // (valid because losses are non-negative; see f64_to_ord_u64).
            let shared_cutoff = AtomicU64::new(f64_to_ord_u64(optimum_loss.abs()));
            let snapshot = matching.clone();

            let best = (0..n)
                .into_par_iter()
                .map_init(
                    // Per-worker state: a private copy of the matching to
                    // mutate speculatively, plus a scratch buffer.
                    || {
                        (
                            snapshot.clone(),
                            Vec::<u8>::with_capacity(total_block_bytes),
                        )
                    },
                    |(local_matching, buf), start_pos| {
                        // Runs at the tail may be shorter than block_size.
                        let end_pos = cmp::min(start_pos + block_size, n);
                        let actual_bs = end_pos - start_pos;

                        // Only move runs that are contiguous in the original
                        // order; moving a scrambled run rarely helps.
                        if !is_consecutive_block(local_matching, start_pos, actual_bs) {
                            return None;
                        }

                        // Candidate target window, clamped to valid positions.
                        let lo = start_pos.saturating_sub(max_shift);
                        let hi = cmp::min(n - actual_bs + 1, start_pos + max_shift);

                        let mut best_local: Option<(f64, usize)> = None;
                        let mut local_count = 0u64;

                        for target in lo..hi {
                            if target == start_pos {
                                continue;
                            }
                            local_count += 1;

                            // Re-read the shared cutoff each trial so other
                            // workers' improvements prune our search too.
                            let cutoff = ord_u64_to_f64(shared_cutoff.load(Ordering::Relaxed));

                            // Speculatively apply the move, score it, undo it.
                            apply_shift(local_matching, start_pos, actual_bs, target);

                            let loss = score_matching(
                                &scorer,
                                ground_len,
                                local_matching,
                                buf,
                                Some(cutoff),
                            );

                            let loss_val = loss.abs();
                            if loss_val < cutoff {
                                // Safe to overwrite: cutoff only shrinks, so
                                // each accepted trial beats the previous one.
                                best_local = Some((loss_val, target));
                                shared_cutoff
                                    .fetch_min(f64_to_ord_u64(loss_val), Ordering::Relaxed);
                            }

                            undo_shift(local_matching, start_pos, actual_bs, target);
                        }

                        shift_count.fetch_add(local_count, Ordering::Relaxed);
                        best_local.map(|(loss, target)| (loss, start_pos, target))
                    },
                )
                .flatten()
                // Losses are finite non-NaN sums, so partial_cmp can't fail here.
                .min_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

            if let Some((_best_loss, best_start, best_target)) = best {
                // Re-derive the run length and apply the winning move to the
                // authoritative matching, then rescore without a cutoff.
                let end_pos = cmp::min(best_start + block_size, n);
                let actual_bs = end_pos - best_start;
                apply_shift(&mut matching, best_start, actual_bs, best_target);
                let old = optimum_loss.abs();
                optimum_loss = score_matching(&scorer, ground_len, &matching, &mut main_buf, None);
                level_improvements += 1;
                info!(
                    " bs={:02} loss={:.1} -> {:.1} | {} trials, {:.2?}",
                    block_size,
                    old,
                    optimum_loss.abs(),
                    shift_count.load(Ordering::Relaxed),
                    start_time.elapsed()
                );
            } else {
                // No improving move at this level: move to the next level.
                break;
            }

            // Global trial budget. NOTE(review): this only breaks the current
            // level's loop; subsequent levels still each run at least one full
            // scan after the budget is exhausted — confirm that is intended.
            if shift_count.load(Ordering::Relaxed) > 4_000_000 {
                break;
            }
        }
        info!(
            " level {} done: {} improvements, {:.2?} elapsed (total {:.2?})",
            block_size,
            level_improvements,
            level_start.elapsed(),
            start_time.elapsed()
        );
    }

    info!(
        "Returning optimum with {:?} after {} trials ({:.2?})",
        optimum_loss,
        shift_count.load(Ordering::Relaxed),
        start_time.elapsed()
    );
    (optimum_loss, matching)
}
199
200#[inline]
202fn is_consecutive_block(matching: &[IndexedBlock], start: usize, len: usize) -> bool {
203 for i in 1..len {
204 if matching[start + i].0 != matching[start + i - 1].0 + 1 {
205 return false;
206 }
207 }
208 true
209}
210
/// Moves the `len`-element run beginning at `src` so that it begins at
/// `dst`, sliding the displaced elements across. Implemented as a rotation
/// of the span between the two positions, so elements outside that span are
/// untouched. `dst == src` is a no-op.
#[inline]
fn apply_shift<T>(slice: &mut [T], src: usize, len: usize, dst: usize) {
    match dst.cmp(&src) {
        // Moving left: rotate right so the run lands at `dst`.
        cmp::Ordering::Less => slice[dst..src + len].rotate_right(len),
        // Moving right (or staying put): rotate left by the run length.
        _ => slice[src..dst + len].rotate_left(len),
    }
}
220
/// Exact inverse of `apply_shift` for the same `(src, len, dst)` arguments:
/// rotates the affected span in the opposite direction, restoring the
/// element order that existed before the shift.
#[inline]
fn undo_shift<T>(slice: &mut [T], src: usize, len: usize, dst: usize) {
    match dst.cmp(&src) {
        // A leftward shift is undone by rotating the span back to the left.
        cmp::Ordering::Less => slice[dst..src + len].rotate_left(len),
        // A rightward shift (or no-op) is undone by rotating right.
        _ => slice[src..dst + len].rotate_right(len),
    }
}
230
/// Scores a candidate `matching` against the precomputed ground scorer.
///
/// The loss is the Levenshtein distance between the ground string and the
/// space-joined block texts (`fuzzy`) plus the reordering penalty
/// (`jump_distance`); `geo_distance` is always 0 here.
///
/// `best_known` enables early exit: if the jump penalty alone already
/// reaches it, or the Levenshtein cutoff is exceeded, a sentinel loss
/// guaranteed to be >= `best_known` is returned without completing the
/// expensive distance computation. `buf` is caller-provided scratch space,
/// reused across calls to avoid per-call allocation.
fn score_matching(
    scorer: &levenshtein::BatchComparator<u8>,
    ground_len: usize,
    matching: &[IndexedBlock],
    buf: &mut Vec<u8>,
    best_known: Option<f64>,
) -> MatchingLoss {
    // Cheap component first: bail out before building the string if the
    // jump penalty alone disqualifies this candidate.
    let jump_distance = calculate_jump_distance(matching);
    if let Some(best) = best_known {
        if jump_distance >= best {
            // f64::MAX acts as a sentinel: abs() will certainly exceed `best`.
            return MatchingLoss {
                fuzzy: f64::MAX,
                jump_distance,
                geo_distance: 0.0,
            };
        }
    }

    // Rebuild the candidate string: block texts joined by single spaces.
    buf.clear();
    for (i, (_, text)) in matching.iter().enumerate() {
        if i > 0 {
            buf.push(b' ');
        }
        buf.extend_from_slice(text.as_bytes());
    }

    let lev_dist = if let Some(best) = best_known {
        // Only a fuzzy distance strictly below `best - jump_distance` can
        // improve the total, so let rapidfuzz abort beyond that cutoff.
        let max_fuzzy = (best - jump_distance).max(0.0) as usize;
        let args = levenshtein::Args::default()
            .score_hint(max_fuzzy)
            .score_cutoff(max_fuzzy);
        match scorer.distance_with_args(buf.iter().copied(), &args) {
            Some(d) => d as f64,
            None => {
                // Cutoff exceeded: return an upper bound on the distance
                // (sum of both lengths) so the caller discards this candidate.
                return MatchingLoss {
                    fuzzy: (ground_len + buf.len()) as f64,
                    jump_distance,
                    geo_distance: 0.0,
                };
            }
        }
    } else {
        // No cutoff: compute the exact distance.
        scorer.distance(buf.iter().copied()) as f64
    };

    MatchingLoss {
        fuzzy: lev_dist,
        jump_distance,
        geo_distance: 0.0,
    }
}
286
287fn calculate_jump_distance(matching: &[IndexedBlock]) -> f64 {
288 let mut j = 0.0;
289 for i in 1..matching.len() {
290 let left = matching[i - 1];
291 let right = matching[i];
292 let expected_right = left.0 + 1;
293 let distance = (expected_right as i32 - right.0 as i32).abs() as usize;
294 if distance == 0 {
295 continue;
296 }
297 j += JUMP_STATIC_PENALTY + (distance as f64) * JUMP_OFFSET_PENALTY;
298 }
299 j
300}
301
/// Collapses all runs of whitespace to single spaces and trims both ends,
/// producing the canonical single-spaced form of `text`.
fn normalize_whitespace(text: &str) -> String {
    let mut normalized = String::with_capacity(text.len());
    for word in text.split_whitespace() {
        if !normalized.is_empty() {
            normalized.push(' ');
        }
        normalized.push_str(word);
    }
    normalized
}
305
#[cfg(test)]
mod tests {
    use super::*;

    // Sanity-check the third-party Levenshtein implementation against
    // hand-computed distances, including empty-string edges.
    #[test]
    fn test_levenshtein_distance() {
        assert_eq!(levenshtein::distance("abc".chars(), "abc".chars()), 0);
        assert_eq!(levenshtein::distance("abc".chars(), "ab".chars()), 1);
        assert_eq!(levenshtein::distance("ab".chars(), "abc".chars()), 1);
        assert_eq!(levenshtein::distance("abc".chars(), "def".chars()), 3);
        assert_eq!(levenshtein::distance("".chars(), "abc".chars()), 3);
        assert_eq!(levenshtein::distance("abc".chars(), "".chars()), 3);
    }

    // Normalization trims the ends and collapses interior whitespace
    // (including newlines and tabs) to single spaces.
    #[test]
    fn test_normalize_whitespace() {
        assert_eq!(normalize_whitespace(" hello world "), "hello world");
        assert_eq!(normalize_whitespace("a\nb\tc"), "a b c");
    }

    // Smoke test: optimize runs to completion on a tiny swapped input and
    // returns one entry per block with a non-negative loss.
    #[test]
    fn test_optimize_simple() {
        let ground = "hello world";
        let blocks = &["world", "hello"];
        let (loss, result) = optimize(ground, blocks);
        assert!(loss.abs() >= 0.0);
        assert_eq!(result.len(), 2);
    }

    // apply_shift and undo_shift must be exact inverses, in both the
    // rightward (src < dst) and leftward (dst < src) directions.
    #[test]
    fn test_apply_undo_shift() {
        let mut v = vec![0, 1, 2, 3, 4, 5, 6, 7];
        apply_shift(&mut v, 2, 2, 5);
        assert_eq!(v, vec![0, 1, 4, 5, 6, 2, 3, 7]);
        undo_shift(&mut v, 2, 2, 5);
        assert_eq!(v, vec![0, 1, 2, 3, 4, 5, 6, 7]);

        apply_shift(&mut v, 5, 2, 1);
        assert_eq!(v, vec![0, 5, 6, 1, 2, 3, 4, 7]);
        undo_shift(&mut v, 5, 2, 1);
        assert_eq!(v, vec![0, 1, 2, 3, 4, 5, 6, 7]);
    }

    // The f64 <-> u64 bit-pattern mapping must round-trip exactly and
    // preserve ordering for the non-negative values it is used with.
    #[test]
    fn test_f64_ord_roundtrip() {
        let vals = [0.0, 1.0, 100.5, 315.6, 1000.0];
        for &v in &vals {
            assert_eq!(v, ord_u64_to_f64(f64_to_ord_u64(v)));
        }
        assert!(f64_to_ord_u64(1.0) < f64_to_ord_u64(2.0));
        assert!(f64_to_ord_u64(100.0) < f64_to_ord_u64(315.6));
    }
}