Skip to main content

scirs2_ndimage/hyperdimensional_computing/
utils.rs

1//! Utility Functions for Hyperdimensional Computing
2//!
3//! This module provides general-purpose utility functions that support
4//! HDC operations including pattern matching, overlap calculations,
5//! and feature analysis helpers.
6
7use scirs2_core::ndarray::ArrayView2;
8use scirs2_core::numeric::{Float, FromPrimitive};
9
10use super::types::PatternMatch;
11use crate::error::{NdimageError, NdimageResult};
12
13/// Non-maximum suppression for pattern matches
14///
15/// Removes overlapping pattern matches, keeping only the ones with highest confidence.
16/// This is commonly used in object detection to eliminate duplicate detections.
17///
18/// # Arguments
19/// * `matches` - Vector of pattern matches to filter
20/// * `overlap_threshold` - Threshold for considering matches as overlapping (0.0 to 1.0)
21///
22/// # Returns
23/// * `NdimageResult<Vec<PatternMatch>>` - Filtered matches with overlaps removed
24#[allow(dead_code)]
25pub fn non_maximum_suppression(
26    mut matches: Vec<PatternMatch>,
27    overlap_threshold: f64,
28) -> NdimageResult<Vec<PatternMatch>> {
29    // Sort matches by confidence in descending order
30    matches.sort_by(|a, b| {
31        b.confidence
32            .partial_cmp(&a.confidence)
33            .expect("Operation failed")
34    });
35
36    let mut kept_matches = Vec::new();
37
38    for current_match in matches {
39        let mut should_keep = true;
40
41        // Check if current match overlaps significantly with any kept match
42        for kept_match in &kept_matches {
43            let overlap = calculate_overlap(&current_match, kept_match);
44            if overlap > overlap_threshold {
45                should_keep = false;
46                break;
47            }
48        }
49
50        if should_keep {
51            kept_matches.push(current_match);
52        }
53    }
54
55    Ok(kept_matches)
56}
57
58/// Calculate overlap between two pattern matches
59///
60/// Computes the Intersection over Union (IoU) between two rectangular regions.
61/// This is a standard metric for measuring overlap in computer vision.
62///
63/// # Arguments
64/// * `match1` - First pattern match
65/// * `match2` - Second pattern match
66///
67/// # Returns
68/// * `f64` - Overlap score between 0.0 (no overlap) and 1.0 (complete overlap)
69#[allow(dead_code)]
70pub fn calculate_overlap(match1: &PatternMatch, match2: &PatternMatch) -> f64 {
71    let (y1, x1) = match1.position;
72    let (h1, w1) = match1.size;
73    let (y2, x2) = match2.position;
74    let (h2, w2) = match2.size;
75
76    // Calculate intersection
77    let overlap_y = ((y1 + h1).min(y2 + h2) as i32 - y1.max(y2) as i32).max(0) as f64;
78    let overlap_x = ((x1 + w1).min(x2 + w2) as i32 - x1.max(x2) as i32).max(0) as f64;
79    let overlap_area = overlap_y * overlap_x;
80
81    // Calculate union
82    let area1 = (h1 * w1) as f64;
83    let area2 = (h2 * w2) as f64;
84    let union_area = area1 + area2 - overlap_area;
85
86    if union_area > 0.0 {
87        overlap_area / union_area
88    } else {
89        0.0
90    }
91}
92
93/// Analyze image patch for specific feature types
94///
95/// Performs basic feature analysis on an image patch to determine the strength
96/// of specific features like edges, corners, or textures. This is a simplified
97/// implementation that can be extended with more sophisticated feature detectors.
98///
99/// # Arguments
100/// * `patch` - Image patch to analyze
101/// * `feature_type` - Type of feature to detect ("edge", "corner", "texture", etc.)
102///
103/// # Returns
104/// * `NdimageResult<f64>` - Feature strength score between 0.0 and 1.0
105#[allow(dead_code)]
106pub fn analyze_patch_for_feature<T>(
107    _patch: &ArrayView2<T>,
108    feature_type: &str,
109) -> NdimageResult<f64>
110where
111    T: Float + FromPrimitive + Copy,
112{
113    // Simplified feature analysis - in practice would implement
114    // specific feature detection algorithms like:
115    // - Sobel/Canny edge detection for "edge"
116    // - Harris corner detection for "corner"
117    // - Local Binary Patterns for "texture"
118    // - Gradient magnitude for "gradient"
119
120    match feature_type {
121        "edge" => Ok(0.8),      // Dummy edge strength
122        "corner" => Ok(0.6),    // Dummy corner strength
123        "texture" => Ok(0.7),   // Dummy texture strength
124        "gradient" => Ok(0.75), // Dummy gradient strength
125        "blob" => Ok(0.65),     // Dummy blob strength
126        "line" => Ok(0.72),     // Dummy line strength
127        _ => Ok(0.5),           // Default feature strength
128    }
129}
130
/// Calculate bounding box intersection area
///
/// Helper function to compute the intersection area between two bounding boxes.
/// This is used internally by overlap calculations.
///
/// # Arguments
/// * `box1` - First bounding box as (y, x, height, width)
/// * `box2` - Second bounding box as (y, x, height, width)
///
/// # Returns
/// * `f64` - Intersection area in pixels (0.0 when the boxes do not overlap)
#[allow(dead_code)]
pub fn calculate_intersection_area(
    box1: (usize, usize, usize, usize),
    box2: (usize, usize, usize, usize),
) -> f64 {
    let (y1, x1, h1, w1) = box1;
    let (y2, x2, h2, w2) = box2;

    // Length of the shared part of the 1-D spans [a_start, a_start + a_len) and
    // [b_start, b_start + b_len); zero when disjoint. Staying in `usize` with
    // saturating ops avoids the wrap/overflow an `as i32` cast suffers for
    // coordinates near or beyond `i32::MAX`.
    let span_overlap = |a_start: usize, a_len: usize, b_start: usize, b_len: usize| {
        let end = a_start
            .saturating_add(a_len)
            .min(b_start.saturating_add(b_len));
        end.saturating_sub(a_start.max(b_start))
    };

    // Multiply in f64 so the product cannot overflow `usize`.
    span_overlap(y1, h1, y2, h2) as f64 * span_overlap(x1, w1, x2, w2) as f64
}
155
156/// Calculate bounding box union area
157///
158/// Helper function to compute the union area between two bounding boxes.
159/// This is used for IoU calculations.
160///
161/// # Arguments
162/// * `box1` - First bounding box as (y, x, height, width)
163/// * `box2` - Second bounding box as (y, x, height, width)
164///
165/// # Returns
166/// * `f64` - Union area in pixels
167#[allow(dead_code)]
168pub fn calculate_union_area(
169    box1: (usize, usize, usize, usize),
170    box2: (usize, usize, usize, usize),
171) -> f64 {
172    let (_, _, h1, w1) = box1;
173    let (_, _, h2, w2) = box2;
174
175    let area1 = (h1 * w1) as f64;
176    let area2 = (h2 * w2) as f64;
177    let intersection = calculate_intersection_area(box1, box2);
178
179    area1 + area2 - intersection
180}
181
182/// Filter pattern matches by confidence threshold
183///
184/// Removes pattern matches below a specified confidence threshold.
185/// This is useful for filtering out low-quality detections.
186///
187/// # Arguments
188/// * `matches` - Vector of pattern matches to filter
189/// * `confidence_threshold` - Minimum confidence score to keep (0.0 to 1.0)
190///
191/// # Returns
192/// * `Vec<PatternMatch>` - Filtered matches above threshold
193#[allow(dead_code)]
194pub fn filter_matches_by_confidence(
195    matches: Vec<PatternMatch>,
196    confidence_threshold: f64,
197) -> Vec<PatternMatch> {
198    matches
199        .into_iter()
200        .filter(|m| m.confidence >= confidence_threshold)
201        .collect()
202}
203
204/// Merge nearby pattern matches
205///
206/// Combines pattern matches that are close to each other into single matches.
207/// This can help reduce noise in detection results.
208///
209/// # Arguments
210/// * `matches` - Vector of pattern matches to merge
211/// * `distance_threshold` - Maximum distance for merging matches
212///
213/// # Returns
214/// * `Vec<PatternMatch>` - Merged pattern matches
215#[allow(dead_code)]
216pub fn merge_nearby_matches(
217    matches: Vec<PatternMatch>,
218    distance_threshold: f64,
219) -> Vec<PatternMatch> {
220    if matches.is_empty() {
221        return matches;
222    }
223
224    let mut merged_matches = Vec::new();
225    let mut used = vec![false; matches.len()];
226
227    for i in 0..matches.len() {
228        if used[i] {
229            continue;
230        }
231
232        let mut cluster = vec![i];
233        used[i] = true;
234
235        // Find nearby matches to merge
236        for j in (i + 1)..matches.len() {
237            if used[j] {
238                continue;
239            }
240
241            let dist = calculate_match_distance(&matches[i], &matches[j]);
242            if dist <= distance_threshold {
243                cluster.push(j);
244                used[j] = true;
245            }
246        }
247
248        // Create merged match from cluster
249        let merged_match = create_merged_match(&matches, &cluster);
250        merged_matches.push(merged_match);
251    }
252
253    merged_matches
254}
255
256/// Calculate distance between two pattern matches
257///
258/// Computes the Euclidean distance between the centers of two pattern matches.
259///
260/// # Arguments
261/// * `match1` - First pattern match
262/// * `match2` - Second pattern match
263///
264/// # Returns
265/// * `f64` - Distance between match centers
266#[allow(dead_code)]
267fn calculate_match_distance(match1: &PatternMatch, match2: &PatternMatch) -> f64 {
268    let center1_y = match1.position.0 as f64 + match1.size.0 as f64 / 2.0;
269    let center1_x = match1.position.1 as f64 + match1.size.1 as f64 / 2.0;
270
271    let center2_y = match2.position.0 as f64 + match2.size.0 as f64 / 2.0;
272    let center2_x = match2.position.1 as f64 + match2.size.1 as f64 / 2.0;
273
274    let dy = center1_y - center2_y;
275    let dx = center1_x - center2_x;
276
277    (dy * dy + dx * dx).sqrt()
278}
279
280/// Create a merged pattern match from a cluster of matches
281///
282/// Combines multiple pattern matches into a single match by averaging
283/// positions and taking the maximum confidence.
284///
285/// # Arguments
286/// * `matches` - All pattern matches
287/// * `cluster` - Indices of matches to merge
288///
289/// # Returns
290/// * `PatternMatch` - Merged pattern match
291#[allow(dead_code)]
292fn create_merged_match(matches: &[PatternMatch], cluster: &[usize]) -> PatternMatch {
293    if cluster.is_empty() {
294        panic!("Cannot create merged match from empty cluster");
295    }
296
297    if cluster.len() == 1 {
298        return matches[cluster[0]].clone();
299    }
300
301    // Find bounding box that contains all matches
302    let mut min_y = usize::MAX;
303    let mut min_x = usize::MAX;
304    let mut max_y = 0;
305    let mut max_x = 0;
306    let mut max_confidence = 0.0;
307    let mut best_label = String::new();
308
309    for &idx in cluster {
310        let m = &matches[idx];
311        let (y, x) = m.position;
312        let (h, w) = m.size;
313
314        min_y = min_y.min(y);
315        min_x = min_x.min(x);
316        max_y = max_y.max(y + h);
317        max_x = max_x.max(x + w);
318
319        if m.confidence > max_confidence {
320            max_confidence = m.confidence;
321            best_label = m.label.clone();
322        }
323    }
324
325    PatternMatch {
326        label: best_label,
327        confidence: max_confidence,
328        position: (min_y, min_x),
329        size: (max_y - min_y, max_x - min_x),
330    }
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Shorthand constructor for `PatternMatch` test fixtures.
    fn pm(
        label: &str,
        confidence: f64,
        position: (usize, usize),
        size: (usize, usize),
    ) -> PatternMatch {
        PatternMatch {
            label: label.to_string(),
            confidence,
            position,
            size,
        }
    }

    #[test]
    fn test_calculate_overlap() {
        let anchor = pm("test1", 0.9, (10, 10), (20, 20));
        let shifted = pm("test2", 0.8, (15, 15), (20, 20));
        let distant = pm("test3", 0.7, (50, 50), (10, 10));

        // Partially overlapping boxes give an IoU strictly inside (0, 1).
        let partial = calculate_overlap(&anchor, &shifted);
        assert!(partial > 0.0);
        assert!(partial < 1.0);

        // Disjoint boxes do not overlap at all.
        assert_eq!(calculate_overlap(&anchor, &distant), 0.0);

        // A box compared with itself overlaps completely.
        assert_eq!(calculate_overlap(&anchor, &anchor), 1.0);
    }

    #[test]
    fn test_non_maximum_suppression() {
        let candidates = vec![
            pm("high_conf", 0.9, (10, 10), (20, 20)),
            pm("low_conf", 0.5, (15, 15), (20, 20)),
            pm("separate", 0.8, (50, 50), (20, 20)),
        ];

        let kept = non_maximum_suppression(candidates, 0.3).expect("Operation failed");

        // "low_conf" is suppressed by the overlapping, more confident "high_conf";
        // survivors come back highest confidence first.
        let labels: Vec<&str> = kept.iter().map(|m| m.label.as_str()).collect();
        assert_eq!(labels, ["high_conf", "separate"]);
    }

    #[test]
    fn test_analyze_patch_for_feature() {
        let patch = Array2::<f64>::zeros((8, 8));
        let expectations = [
            ("edge", 0.8),
            ("corner", 0.6),
            ("texture", 0.7),
            ("unknown", 0.5),
        ];

        for &(feature, expected) in expectations.iter() {
            let strength =
                analyze_patch_for_feature(&patch.view(), feature).expect("Operation failed");
            assert_eq!(strength, expected);
        }
    }

    #[test]
    fn test_calculate_intersection_area() {
        let box1 = (10, 10, 20, 20); // y=10, x=10, h=20, w=20
        let box2 = (15, 15, 20, 20); // y=15, x=15, h=20, w=20
        let box3 = (50, 50, 10, 10); // far away from box1

        assert_eq!(calculate_intersection_area(box1, box2), 15.0 * 15.0); // 15x15 overlap
        assert_eq!(calculate_intersection_area(box1, box3), 0.0);
    }

    #[test]
    fn test_calculate_union_area() {
        let box1 = (10, 10, 20, 20); // Area = 400
        let box2 = (15, 15, 20, 20); // Area = 400

        let expected_union = 400.0 + 400.0 - calculate_intersection_area(box1, box2);
        assert_eq!(calculate_union_area(box1, box2), expected_union);
    }

    #[test]
    fn test_filter_matches_by_confidence() {
        let candidates = vec![
            pm("high", 0.9, (0, 0), (10, 10)),
            pm("medium", 0.7, (20, 20), (10, 10)),
            pm("low", 0.3, (40, 40), (10, 10)),
        ];

        let kept = filter_matches_by_confidence(candidates, 0.6);
        let labels: Vec<&str> = kept.iter().map(|m| m.label.as_str()).collect();
        assert_eq!(labels, ["high", "medium"]);
    }

    #[test]
    fn test_calculate_match_distance() {
        let left = pm("test1", 0.9, (0, 0), (10, 10));
        let right = pm("test2", 0.8, (0, 10), (10, 10));

        // Centers are (5,5) and (5,15), distance = 10
        assert_eq!(calculate_match_distance(&left, &right), 10.0);
    }

    #[test]
    fn test_merge_nearby_matches() {
        let candidates = vec![
            pm("close1", 0.9, (0, 0), (10, 10)),
            pm("close2", 0.8, (0, 5), (10, 10)),
            pm("far", 0.7, (50, 50), (10, 10)),
        ];

        // Two groups: one merged, one separate
        assert_eq!(merge_nearby_matches(candidates, 10.0).len(), 2);
    }

    #[test]
    fn test_create_merged_match() {
        let members = vec![
            pm("test1", 0.9, (0, 0), (10, 10)),
            pm("test2", 0.7, (5, 5), (10, 10)),
        ];

        let merged = create_merged_match(&members, &[0, 1]);

        assert_eq!(merged.label, "test1"); // Higher confidence label
        assert_eq!(merged.confidence, 0.9); // Higher confidence
        assert_eq!(merged.position, (0, 0)); // Bounding box top-left
        assert_eq!(merged.size, (15, 15)); // Bounding box size
    }

    #[test]
    #[should_panic]
    fn test_create_merged_match_empty_cluster() {
        let no_matches: Vec<PatternMatch> = Vec::new();
        let no_indices: Vec<usize> = Vec::new();
        create_merged_match(&no_matches, &no_indices);
    }
}