1use std::cmp;
2use std::sync::atomic::{AtomicU64, Ordering};
3use std::time::Instant;
4
5use rapidfuzz::distance::indel;
6use rayon::prelude::*;
7use tracing::info;
8
/// Fixed penalty added for every adjacent pair of blocks whose original indices are not
/// consecutive (applied in `calculate_jump_distance`).
const JUMP_STATIC_PENALTY: f64 = 2.0;
/// Additional penalty per unit of distance between the index that was expected to follow a
/// block (`left + 1`) and the index that actually follows it.
const JUMP_OFFSET_PENALTY: f64 = 0.2;

/// Widths (number of consecutive blocks) of the runs `optimize` tries to move.
/// `optimize` iterates them largest-first.
const BLOCK_SIZES: &[usize] = &[1, 2, 3, 4, 6, 9, 12, 15];
/// Base number of positions a run may be moved away from its current start.
const BASE_BLOCK_SHIFT: usize = 30;
/// Extra move range granted per block of run width:
/// `max_shift = BASE_BLOCK_SHIFT + block_size * SHIFT_PER_BLOCK_SIZE`.
const SHIFT_PER_BLOCK_SIZE: usize = 8;

/// A block of text paired with its index in the `blocks` slice originally passed to `optimize`.
type IndexedBlock<'a> = (usize, &'a str);
22
/// Breakdown of the loss assigned to one ordering of blocks; lower is better.
#[derive(Debug, Clone)]
pub struct MatchingLoss {
    /// Text-distance component: the indel distance between the joined blocks and the ground
    /// text, or a large sentinel when scoring was cut off early.
    pub fuzzy: f64,
    /// Penalty for adjacent blocks that are out of their original order.
    pub jump_distance: f64,
    /// Geometric component; `score_matching` in this module always sets it to `0.0`.
    pub geo_distance: f64,
}

impl MatchingLoss {
    /// Total loss, i.e. the sum of all components.
    ///
    /// Despite the name this is not an absolute value; the components are added as-is.
    pub fn abs(&self) -> f64 {
        let Self {
            fuzzy,
            jump_distance,
            geo_distance,
        } = *self;
        fuzzy + jump_distance + geo_distance
    }
}
35
/// Maps an `f64` to a `u64` whose unsigned order matches the float's total order.
///
/// This lets an `AtomicU64` with `fetch_min` act as an atomic float minimum. A plain bit-cast
/// is only monotone for non-negative values. `-0.0` (which satisfies `v >= 0.0`) and every
/// negative value would sort above all positives. Setting the sign bit of non-negatives and
/// inverting all bits of negatives gives a monotone encoding for every input.
#[inline]
fn f64_to_ord_u64(v: f64) -> u64 {
    const SIGN_BIT: u64 = 1 << 63;
    let bits = v.to_bits();
    if bits & SIGN_BIT == 0 {
        bits | SIGN_BIT
    } else {
        !bits
    }
}

/// Exact inverse of [`f64_to_ord_u64`]; the bit pattern, including the sign of zero, is restored.
#[inline]
fn ord_u64_to_f64(v: u64) -> f64 {
    const SIGN_BIT: u64 = 1 << 63;
    let bits = if v & SIGN_BIT != 0 { v & !SIGN_BIT } else { !v };
    f64::from_bits(bits)
}
48
49pub fn optimize<'a>(ground: &str, blocks: &[&'a str]) -> (MatchingLoss, Vec<IndexedBlock<'a>>) {
60 let ground_normalized = normalize_whitespace(ground);
61 let ground_words: Vec<&str> = ground_normalized.split_whitespace().collect();
62 let ground_concat = ground_words.join(" ");
63
64 let scorer = indel::BatchComparator::new(ground_concat.bytes());
66 let ground_len = ground_concat.len();
67
68 let total_block_bytes: usize =
70 blocks.iter().map(|b| b.len()).sum::<usize>() + blocks.len().saturating_sub(1);
71 let mut main_buf = Vec::<u8>::with_capacity(total_block_bytes);
72
73 let mut matching: Vec<IndexedBlock<'a>> =
75 blocks.iter().enumerate().map(|(a, b)| (a, *b)).collect();
76
77 let mut optimum_loss = score_matching(&scorer, ground_len, &matching, &mut main_buf, None);
78 info!("Start fuzzy matching with loss of {:?}", optimum_loss);
79
80 let shift_count = AtomicU64::new(0);
81 let start_time = Instant::now();
82 let n = matching.len();
83
84 for &block_size in BLOCK_SIZES.iter().rev() {
85 let level_start = Instant::now();
86 let mut level_improvements = 0u64;
87 let max_shift = BASE_BLOCK_SHIFT + block_size * SHIFT_PER_BLOCK_SIZE;
88 info!(
89 "Optimizing on block level {} (max_shift={})",
90 block_size, max_shift
91 );
92
93 loop {
94 let shared_cutoff = AtomicU64::new(f64_to_ord_u64(optimum_loss.abs()));
97 let snapshot = matching.clone();
98
99 let best = (0..n)
100 .into_par_iter()
101 .map_init(
102 || {
103 (
104 snapshot.clone(),
105 Vec::<u8>::with_capacity(total_block_bytes),
106 )
107 },
108 |(local_matching, buf), start_pos| {
109 let end_pos = cmp::min(start_pos + block_size, n);
110 let actual_bs = end_pos - start_pos;
111 let lo = start_pos.saturating_sub(max_shift);
112 let hi = cmp::min(n - actual_bs + 1, start_pos + max_shift);
113
114 let mut best_local: Option<(f64, usize)> = None;
115 let mut local_count = 0u64;
116
117 for target in lo..hi {
118 if target == start_pos {
119 continue;
120 }
121 local_count += 1;
122
123 let cutoff = ord_u64_to_f64(shared_cutoff.load(Ordering::Relaxed));
125
126 apply_shift(local_matching, start_pos, actual_bs, target);
127
128 let loss = score_matching(
129 &scorer,
130 ground_len,
131 local_matching,
132 buf,
133 Some(cutoff),
134 );
135
136 let loss_val = loss.abs();
137 if loss_val < cutoff {
138 best_local = Some((loss_val, target));
139 shared_cutoff
141 .fetch_min(f64_to_ord_u64(loss_val), Ordering::Relaxed);
142 }
143
144 undo_shift(local_matching, start_pos, actual_bs, target);
145 }
146
147 shift_count.fetch_add(local_count, Ordering::Relaxed);
148 best_local.map(|(loss, target)| (loss, start_pos, target))
149 },
150 )
151 .flatten()
152 .min_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
153
154 if let Some((_best_loss, best_start, best_target)) = best {
155 let end_pos = cmp::min(best_start + block_size, n);
156 let actual_bs = end_pos - best_start;
157 apply_shift(&mut matching, best_start, actual_bs, best_target);
158 let old = optimum_loss.abs();
159 optimum_loss = score_matching(&scorer, ground_len, &matching, &mut main_buf, None);
160 level_improvements += 1;
161 info!(
162 " bs={:02} loss={:.1} -> {:.1} | {} trials, {:.2?}",
163 block_size,
164 old,
165 optimum_loss.abs(),
166 shift_count.load(Ordering::Relaxed),
167 start_time.elapsed()
168 );
169 } else {
170 break;
171 }
172
173 if shift_count.load(Ordering::Relaxed) > 4_000_000 {
174 break;
175 }
176 }
177 info!(
178 " level {} done: {} improvements, {:.2?} elapsed (total {:.2?})",
179 block_size,
180 level_improvements,
181 level_start.elapsed(),
182 start_time.elapsed()
183 );
184 }
185
186 info!(
187 "Returning optimum with {:?} after {} trials ({:.2?})",
188 optimum_loss,
189 shift_count.load(Ordering::Relaxed),
190 start_time.elapsed()
191 );
192 (optimum_loss, matching)
193}
194
/// Moves the run `slice[src..src + len]` so that it starts at index `dst`.
///
/// The elements between the old and new positions slide over by `len` to make room.
/// Panics if the affected range extends past the end of `slice`.
#[inline]
fn apply_shift<T>(slice: &mut [T], src: usize, len: usize, dst: usize) {
    match dst.cmp(&src) {
        // Moving left: the run sits at the back of `dst..src + len` and rotates to its front.
        std::cmp::Ordering::Less => slice[dst..src + len].rotate_right(len),
        // Moving right (or staying put): the run sits at the front of `src..dst + len`
        // and rotates to its back.
        _ => slice[src..dst + len].rotate_left(len),
    }
}
204
/// Reverses `apply_shift(slice, src, len, dst)`: moves the run that now starts at `dst` back to `src`.
///
/// Uses the same window as `apply_shift`, rotated the opposite way.
/// Panics if the affected range extends past the end of `slice`.
#[inline]
fn undo_shift<T>(slice: &mut [T], src: usize, len: usize, dst: usize) {
    match dst.cmp(&src) {
        // The run was moved left, to the front of `dst..src + len`; rotate it back to the end.
        std::cmp::Ordering::Less => slice[dst..src + len].rotate_left(len),
        // The run was moved right, to the back of `src..dst + len`; rotate it back to the front.
        _ => slice[src..dst + len].rotate_right(len),
    }
}
214
215fn score_matching(
217 scorer: &indel::BatchComparator<u8>,
218 ground_len: usize,
219 matching: &[IndexedBlock],
220 buf: &mut Vec<u8>,
221 best_known: Option<f64>,
222) -> MatchingLoss {
223 let jump_distance = calculate_jump_distance(matching);
225 if let Some(best) = best_known {
226 if jump_distance >= best {
227 return MatchingLoss {
228 fuzzy: f64::MAX,
229 jump_distance,
230 geo_distance: 0.0,
231 };
232 }
233 }
234
235 buf.clear();
237 for (i, (_, text)) in matching.iter().enumerate() {
238 if i > 0 {
239 buf.push(b' ');
240 }
241 buf.extend_from_slice(text.as_bytes());
242 }
243
244 let indel_dist = if let Some(best) = best_known {
246 let max_fuzzy = (best - jump_distance).max(0.0) as usize;
247 let args = indel::Args::default()
248 .score_hint(max_fuzzy)
249 .score_cutoff(max_fuzzy);
250 match scorer.distance_with_args(buf.iter().copied(), &args) {
251 Some(d) => d as f64,
252 None => {
253 return MatchingLoss {
254 fuzzy: (ground_len + buf.len()) as f64,
255 jump_distance,
256 geo_distance: 0.0,
257 };
258 }
259 }
260 } else {
261 scorer.distance(buf.iter().copied()) as f64
262 };
263
264 MatchingLoss {
265 fuzzy: indel_dist,
266 jump_distance,
267 geo_distance: 0.0,
268 }
269}
270
271fn calculate_jump_distance(matching: &[IndexedBlock]) -> f64 {
272 let mut j = 0.0;
273 for i in 1..matching.len() {
274 let left = matching[i - 1];
275 let right = matching[i];
276 let expected_right = left.0 + 1;
277 let distance = (expected_right as i32 - right.0 as i32).abs() as usize;
278 if distance == 0 {
279 continue;
280 }
281 j += JUMP_STATIC_PENALTY + (distance as f64) * JUMP_OFFSET_PENALTY;
282 }
283 j
284}
285
/// Collapses every run of whitespace in `text` to a single ASCII space and trims both ends.
fn normalize_whitespace(text: &str) -> String {
    let mut normalized = String::with_capacity(text.len());
    // `split_whitespace` never yields empty words, so a non-empty buffer means a word was
    // already written and a separator is needed.
    for word in text.split_whitespace() {
        if !normalized.is_empty() {
            normalized.push(' ');
        }
        normalized.push_str(word);
    }
    normalized
}
289
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_indel_distance() {
        assert_eq!(indel::distance("abc".chars(), "abc".chars()), 0);
        assert_eq!(indel::distance("abc".chars(), "ab".chars()), 1);
        assert_eq!(indel::distance("ab".chars(), "abc".chars()), 1);
        assert_eq!(indel::distance("abc".chars(), "def".chars()), 6);
        assert_eq!(indel::distance("".chars(), "abc".chars()), 3);
        assert_eq!(indel::distance("abc".chars(), "".chars()), 3);
    }

    #[test]
    fn test_normalize_whitespace() {
        assert_eq!(normalize_whitespace("  hello   world  "), "hello world");
        assert_eq!(normalize_whitespace("a\nb\tc"), "a b c");
    }

    #[test]
    fn test_optimize_simple() {
        let ground = "hello world";
        let blocks = &["world", "hello"];
        let (loss, result) = optimize(ground, blocks);
        // Swapping the two blocks reproduces the ground text exactly. The only remaining cost
        // is the one out-of-order pair: index 1 is followed by 0 instead of 2.
        assert_eq!(result, vec![(1, "hello"), (0, "world")]);
        assert_eq!(loss.fuzzy, 0.0);
        let expected_jump = JUMP_STATIC_PENALTY + 2.0 * JUMP_OFFSET_PENALTY;
        assert!((loss.jump_distance - expected_jump).abs() < 1e-9);
        assert_eq!(loss.geo_distance, 0.0);
    }

    #[test]
    fn test_calculate_jump_distance() {
        let in_order: Vec<IndexedBlock> = vec![(0, "a"), (1, "b"), (2, "c")];
        assert_eq!(calculate_jump_distance(&in_order), 0.0);
        assert_eq!(calculate_jump_distance(&[]), 0.0);

        let jumped: Vec<IndexedBlock> = vec![(2, "c"), (0, "a")];
        let expected = JUMP_STATIC_PENALTY + 3.0 * JUMP_OFFSET_PENALTY;
        assert!((calculate_jump_distance(&jumped) - expected).abs() < 1e-9);
    }

    #[test]
    fn test_score_matching_cutoff() {
        let scorer = indel::BatchComparator::new("hello world".bytes());
        let matching: Vec<IndexedBlock> = vec![(1, "hello"), (0, "world")];
        let mut buf = Vec::new();

        let exact = score_matching(&scorer, 11, &matching, &mut buf, None);
        assert_eq!(exact.fuzzy, 0.0);

        // The jump penalty alone exceeds the bound, so the result must not beat it.
        let pruned = score_matching(&scorer, 11, &matching, &mut buf, Some(1.0));
        assert!(pruned.abs() >= 1.0);
    }

    #[test]
    fn test_apply_undo_shift() {
        let mut v = vec![0, 1, 2, 3, 4, 5, 6, 7];
        apply_shift(&mut v, 2, 2, 5);
        assert_eq!(v, vec![0, 1, 4, 5, 6, 2, 3, 7]);
        undo_shift(&mut v, 2, 2, 5);
        assert_eq!(v, vec![0, 1, 2, 3, 4, 5, 6, 7]);

        apply_shift(&mut v, 5, 2, 1);
        assert_eq!(v, vec![0, 5, 6, 1, 2, 3, 4, 7]);
        undo_shift(&mut v, 5, 2, 1);
        assert_eq!(v, vec![0, 1, 2, 3, 4, 5, 6, 7]);
    }

    #[test]
    fn test_f64_ord_roundtrip() {
        let vals = [0.0, 1.0, 100.5, 315.6, 1000.0];
        for &v in &vals {
            assert_eq!(v, ord_u64_to_f64(f64_to_ord_u64(v)));
        }
        assert!(f64_to_ord_u64(1.0) < f64_to_ord_u64(2.0));
        assert!(f64_to_ord_u64(100.0) < f64_to_ord_u64(315.6));
    }
}