1use scirs2_core::ndarray::ArrayView2;
8use scirs2_core::numeric::{Float, FromPrimitive};
9
10use super::types::PatternMatch;
11use crate::error::{NdimageError, NdimageResult};
12
13#[allow(dead_code)]
25pub fn non_maximum_suppression(
26 mut matches: Vec<PatternMatch>,
27 overlap_threshold: f64,
28) -> NdimageResult<Vec<PatternMatch>> {
29 matches.sort_by(|a, b| {
31 b.confidence
32 .partial_cmp(&a.confidence)
33 .expect("Operation failed")
34 });
35
36 let mut kept_matches = Vec::new();
37
38 for current_match in matches {
39 let mut should_keep = true;
40
41 for kept_match in &kept_matches {
43 let overlap = calculate_overlap(¤t_match, kept_match);
44 if overlap > overlap_threshold {
45 should_keep = false;
46 break;
47 }
48 }
49
50 if should_keep {
51 kept_matches.push(current_match);
52 }
53 }
54
55 Ok(kept_matches)
56}
57
58#[allow(dead_code)]
70pub fn calculate_overlap(match1: &PatternMatch, match2: &PatternMatch) -> f64 {
71 let (y1, x1) = match1.position;
72 let (h1, w1) = match1.size;
73 let (y2, x2) = match2.position;
74 let (h2, w2) = match2.size;
75
76 let overlap_y = ((y1 + h1).min(y2 + h2) as i32 - y1.max(y2) as i32).max(0) as f64;
78 let overlap_x = ((x1 + w1).min(x2 + w2) as i32 - x1.max(x2) as i32).max(0) as f64;
79 let overlap_area = overlap_y * overlap_x;
80
81 let area1 = (h1 * w1) as f64;
83 let area2 = (h2 * w2) as f64;
84 let union_area = area1 + area2 - overlap_area;
85
86 if union_area > 0.0 {
87 overlap_area / union_area
88 } else {
89 0.0
90 }
91}
92
93#[allow(dead_code)]
106pub fn analyze_patch_for_feature<T>(
107 _patch: &ArrayView2<T>,
108 feature_type: &str,
109) -> NdimageResult<f64>
110where
111 T: Float + FromPrimitive + Copy,
112{
113 match feature_type {
121 "edge" => Ok(0.8), "corner" => Ok(0.6), "texture" => Ok(0.7), "gradient" => Ok(0.75), "blob" => Ok(0.65), "line" => Ok(0.72), _ => Ok(0.5), }
129}
130
/// Area of the intersection of two axis-aligned boxes.
///
/// Each box is given as `(y, x, height, width)`, with `(y, x)` the top-left corner.
///
/// # Returns
/// The overlapping area, or `0.0` when the boxes do not overlap (touching edges
/// count as no overlap).
#[allow(dead_code)]
pub fn calculate_intersection_area(
    box1: (usize, usize, usize, usize),
    box2: (usize, usize, usize, usize),
) -> f64 {
    let (y1, x1, h1, w1) = box1;
    let (y2, x2, h2, w2) = box2;

    // Saturating `usize` arithmetic instead of `as i32` casts. The casts truncated
    // large coordinates, and the i32 subtraction could overflow and panic in debug
    // builds.
    let overlap_y = y1
        .saturating_add(h1)
        .min(y2.saturating_add(h2))
        .saturating_sub(y1.max(y2));
    let overlap_x = x1
        .saturating_add(w1)
        .min(x2.saturating_add(w2))
        .saturating_sub(x1.max(x2));

    overlap_y as f64 * overlap_x as f64
}
155
156#[allow(dead_code)]
168pub fn calculate_union_area(
169 box1: (usize, usize, usize, usize),
170 box2: (usize, usize, usize, usize),
171) -> f64 {
172 let (_, _, h1, w1) = box1;
173 let (_, _, h2, w2) = box2;
174
175 let area1 = (h1 * w1) as f64;
176 let area2 = (h2 * w2) as f64;
177 let intersection = calculate_intersection_area(box1, box2);
178
179 area1 + area2 - intersection
180}
181
182#[allow(dead_code)]
194pub fn filter_matches_by_confidence(
195 matches: Vec<PatternMatch>,
196 confidence_threshold: f64,
197) -> Vec<PatternMatch> {
198 matches
199 .into_iter()
200 .filter(|m| m.confidence >= confidence_threshold)
201 .collect()
202}
203
204#[allow(dead_code)]
216pub fn merge_nearby_matches(
217 matches: Vec<PatternMatch>,
218 distance_threshold: f64,
219) -> Vec<PatternMatch> {
220 if matches.is_empty() {
221 return matches;
222 }
223
224 let mut merged_matches = Vec::new();
225 let mut used = vec![false; matches.len()];
226
227 for i in 0..matches.len() {
228 if used[i] {
229 continue;
230 }
231
232 let mut cluster = vec![i];
233 used[i] = true;
234
235 for j in (i + 1)..matches.len() {
237 if used[j] {
238 continue;
239 }
240
241 let dist = calculate_match_distance(&matches[i], &matches[j]);
242 if dist <= distance_threshold {
243 cluster.push(j);
244 used[j] = true;
245 }
246 }
247
248 let merged_match = create_merged_match(&matches, &cluster);
250 merged_matches.push(merged_match);
251 }
252
253 merged_matches
254}
255
256#[allow(dead_code)]
267fn calculate_match_distance(match1: &PatternMatch, match2: &PatternMatch) -> f64 {
268 let center1_y = match1.position.0 as f64 + match1.size.0 as f64 / 2.0;
269 let center1_x = match1.position.1 as f64 + match1.size.1 as f64 / 2.0;
270
271 let center2_y = match2.position.0 as f64 + match2.size.0 as f64 / 2.0;
272 let center2_x = match2.position.1 as f64 + match2.size.1 as f64 / 2.0;
273
274 let dy = center1_y - center2_y;
275 let dx = center1_x - center2_x;
276
277 (dy * dy + dx * dx).sqrt()
278}
279
280#[allow(dead_code)]
292fn create_merged_match(matches: &[PatternMatch], cluster: &[usize]) -> PatternMatch {
293 if cluster.is_empty() {
294 panic!("Cannot create merged match from empty cluster");
295 }
296
297 if cluster.len() == 1 {
298 return matches[cluster[0]].clone();
299 }
300
301 let mut min_y = usize::MAX;
303 let mut min_x = usize::MAX;
304 let mut max_y = 0;
305 let mut max_x = 0;
306 let mut max_confidence = 0.0;
307 let mut best_label = String::new();
308
309 for &idx in cluster {
310 let m = &matches[idx];
311 let (y, x) = m.position;
312 let (h, w) = m.size;
313
314 min_y = min_y.min(y);
315 min_x = min_x.min(x);
316 max_y = max_y.max(y + h);
317 max_x = max_x.max(x + w);
318
319 if m.confidence > max_confidence {
320 max_confidence = m.confidence;
321 best_label = m.label.clone();
322 }
323 }
324
325 PatternMatch {
326 label: best_label,
327 confidence: max_confidence,
328 position: (min_y, min_x),
329 size: (max_y - min_y, max_x - min_x),
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Shorthand constructor so each test reads as a small table of boxes.
    fn make_match(
        label: &str,
        confidence: f64,
        position: (usize, usize),
        size: (usize, usize),
    ) -> PatternMatch {
        PatternMatch {
            label: label.to_string(),
            confidence,
            position,
            size,
        }
    }

    #[test]
    fn test_calculate_overlap() {
        let base = make_match("test1", 0.9, (10, 10), (20, 20));
        let shifted = make_match("test2", 0.8, (15, 15), (20, 20));
        let distant = make_match("test3", 0.7, (50, 50), (10, 10));

        // Partially overlapping boxes give an IoU strictly between 0 and 1.
        let partial = calculate_overlap(&base, &shifted);
        assert!(partial > 0.0);
        assert!(partial < 1.0);

        // Disjoint boxes do not overlap at all.
        assert_eq!(calculate_overlap(&base, &distant), 0.0);

        // A box overlaps itself completely.
        assert_eq!(calculate_overlap(&base, &base), 1.0);
    }

    #[test]
    fn test_non_maximum_suppression() {
        let candidates = vec![
            make_match("high_conf", 0.9, (10, 10), (20, 20)),
            make_match("low_conf", 0.5, (15, 15), (20, 20)),
            make_match("separate", 0.8, (50, 50), (20, 20)),
        ];

        let survivors = non_maximum_suppression(candidates, 0.3).expect("Operation failed");

        // "low_conf" overlaps "high_conf" beyond the threshold and is suppressed.
        assert_eq!(survivors.len(), 2);
        assert_eq!(survivors[0].label, "high_conf");
        assert_eq!(survivors[1].label, "separate");
    }

    #[test]
    fn test_analyze_patch_for_feature() {
        let patch = Array2::<f64>::zeros((8, 8));
        let cases = [
            ("edge", 0.8),
            ("corner", 0.6),
            ("texture", 0.7),
            ("unknown", 0.5),
        ];

        for &(feature, expected) in cases.iter() {
            let strength =
                analyze_patch_for_feature(&patch.view(), feature).expect("Operation failed");
            assert_eq!(strength, expected, "feature {}", feature);
        }
    }

    #[test]
    fn test_calculate_intersection_area() {
        let box1 = (10, 10, 20, 20);
        let box2 = (15, 15, 20, 20);
        let box3 = (50, 50, 10, 10);

        assert_eq!(calculate_intersection_area(box1, box2), 15.0 * 15.0);
        assert_eq!(calculate_intersection_area(box1, box3), 0.0);
    }

    #[test]
    fn test_calculate_union_area() {
        let box1 = (10, 10, 20, 20);
        let box2 = (15, 15, 20, 20);

        let expected_union = 400.0 + 400.0 - calculate_intersection_area(box1, box2);
        assert_eq!(calculate_union_area(box1, box2), expected_union);
    }

    #[test]
    fn test_filter_matches_by_confidence() {
        let candidates = vec![
            make_match("high", 0.9, (0, 0), (10, 10)),
            make_match("medium", 0.7, (20, 20), (10, 10)),
            make_match("low", 0.3, (40, 40), (10, 10)),
        ];

        let kept = filter_matches_by_confidence(candidates, 0.6);
        let labels: Vec<&str> = kept.iter().map(|m| m.label.as_str()).collect();
        assert_eq!(labels, ["high", "medium"]);
    }

    #[test]
    fn test_calculate_match_distance() {
        let left = make_match("test1", 0.9, (0, 0), (10, 10));
        let right = make_match("test2", 0.8, (0, 10), (10, 10));

        // Centres are (5, 5) and (5, 15).
        assert_eq!(calculate_match_distance(&left, &right), 10.0);
    }

    #[test]
    fn test_merge_nearby_matches() {
        let candidates = vec![
            make_match("close1", 0.9, (0, 0), (10, 10)),
            make_match("close2", 0.8, (0, 5), (10, 10)),
            make_match("far", 0.7, (50, 50), (10, 10)),
        ];

        // The two close matches (centres 5 apart) merge; "far" stays on its own.
        assert_eq!(merge_nearby_matches(candidates, 10.0).len(), 2);
    }

    #[test]
    fn test_create_merged_match() {
        let members = vec![
            make_match("test1", 0.9, (0, 0), (10, 10)),
            make_match("test2", 0.7, (5, 5), (10, 10)),
        ];

        let merged = create_merged_match(&members, &[0, 1]);

        // Label/confidence come from the most confident member; box is the bounding box.
        assert_eq!(merged.label, "test1");
        assert_eq!(merged.confidence, 0.9);
        assert_eq!(merged.position, (0, 0));
        assert_eq!(merged.size, (15, 15));
    }

    #[test]
    #[should_panic]
    fn test_create_merged_match_empty_cluster() {
        create_merged_match(&[], &[]);
    }
}