Skip to main content

scirs2_ndimage/
template_matching.rs

1//! Template Matching and Sliding Window Detection
2//!
3//! This module provides:
4//! - Template matching with multiple similarity metrics (SSD, NCC, coefficient correlation)
5//! - Non-maximum suppression to find discrete match locations
6//! - Multi-scale (image pyramid) template matching
7//!
8//! # References
9//! - Lewis, J.P. (1995). "Fast Template Matching." Vision Interface.
10//! - Briechle, K. & Hanebeck, U.D. (2001). "Template Matching using Fast Normalized
11//!   Cross Correlation." Proc. SPIE 4387.
12
13use crate::error::{NdimageError, NdimageResult};
14use scirs2_core::ndarray::{s, Array2, Array3};
15use std::f64::consts::PI;
16
17// ---------------------------------------------------------------------------
18// Public API types
19// ---------------------------------------------------------------------------
20
/// Similarity measure used by [`template_match`].
///
/// The two SSD-based variants measure *dissimilarity* (a smaller value is a
/// better match); the two correlation-based variants measure *similarity*
/// (a larger value, up to 1, is a better match).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatchMethod {
    /// Plain sum of squared differences between patch and template
    /// (lower is better; 0 is a perfect match).
    SumSquaredDiff,
    /// Sum of squared differences divided by the product of the patch and
    /// template L2 norms (lower is better; 0 is a perfect match).
    NormalizedSumSquaredDiff,
    /// Cross-correlation divided by the product of the patch and template
    /// L2 norms, in [-1, 1] (higher is better).
    NormalizedCrossCorrelation,
    /// Cross-correlation of the mean-subtracted patch and template, divided
    /// by the product of their L2 norms, in [-1, 1] (higher is better).
    CoeffCorrelation,
}
33
34// ---------------------------------------------------------------------------
35// Core template matching
36// ---------------------------------------------------------------------------
37
38/// Compute a template-match response map for the given similarity measure.
39///
40/// The output array has shape `(image_rows - template_rows + 1,
41/// image_cols - template_cols + 1)` — one value per valid placement
42/// of the template inside the image.
43///
44/// For `SumSquaredDiff` / `NormalizedSumSquaredDiff` a **lower** value
45/// indicates a better match.  For `NormalizedCrossCorrelation` /
46/// `CoeffCorrelation` a **higher** value (closer to 1) is better.
47///
48/// # Errors
49/// Returns `NdimageError::InvalidInput` when the template is larger than
50/// the image in either dimension.
51pub fn template_match(
52    image: &Array2<f64>,
53    template: &Array2<f64>,
54    method: MatchMethod,
55) -> NdimageResult<Array2<f64>> {
56    let (ih, iw) = image.dim();
57    let (th, tw) = template.dim();
58
59    if th == 0 || tw == 0 {
60        return Err(NdimageError::InvalidInput(
61            "Template must not be empty".into(),
62        ));
63    }
64    if th > ih || tw > iw {
65        return Err(NdimageError::InvalidInput(
66            "Template must not be larger than the image".into(),
67        ));
68    }
69
70    match method {
71        MatchMethod::SumSquaredDiff => ssd_map(image, template, false),
72        MatchMethod::NormalizedSumSquaredDiff => ssd_map(image, template, true),
73        MatchMethod::NormalizedCrossCorrelation => normalized_cross_correlation(image, template),
74        MatchMethod::CoeffCorrelation => coeff_correlation(image, template),
75    }
76}
77
78// ---------------------------------------------------------------------------
79// SSD response map
80// ---------------------------------------------------------------------------
81
82fn ssd_map(
83    image: &Array2<f64>,
84    template: &Array2<f64>,
85    normalize: bool,
86) -> NdimageResult<Array2<f64>> {
87    let (ih, iw) = image.dim();
88    let (th, tw) = template.dim();
89    let out_h = ih - th + 1;
90    let out_w = iw - tw + 1;
91
92    // Template sum-of-squares for normalization
93    let template_ss: f64 = template.iter().map(|&v| v * v).sum();
94
95    let mut result = Array2::zeros((out_h, out_w));
96
97    for r in 0..out_h {
98        for c in 0..out_w {
99            let patch = image.slice(s![r..r + th, c..c + tw]);
100            let mut ssd = 0.0;
101            for (iv, tv) in patch.iter().zip(template.iter()) {
102                let d = iv - tv;
103                ssd += d * d;
104            }
105
106            if normalize {
107                // Normalized SSD: SSD / (||patch|| * ||template||)
108                let patch_ss: f64 = patch.iter().map(|&v| v * v).sum();
109                let denom = (patch_ss * template_ss).sqrt();
110                result[[r, c]] = if denom > 1e-12 { ssd / denom } else { 0.0 };
111            } else {
112                result[[r, c]] = ssd;
113            }
114        }
115    }
116
117    Ok(result)
118}
119
120// ---------------------------------------------------------------------------
121// Normalized cross-correlation
122// ---------------------------------------------------------------------------
123
124/// Compute the normalized cross-correlation (NCC) response map.
125///
126/// Each output pixel is
127/// ```text
128///   NCC(r,c) = sum(patch * template) / (||patch|| * ||template||)
129/// ```
130/// Values lie in [-1, 1].  A value of 1 means perfect correlation.
131///
132/// # Errors
133/// Returns `NdimageError::InvalidInput` when the template is larger than
134/// the image in either dimension.
135pub fn normalized_cross_correlation(
136    image: &Array2<f64>,
137    template: &Array2<f64>,
138) -> NdimageResult<Array2<f64>> {
139    let (ih, iw) = image.dim();
140    let (th, tw) = template.dim();
141
142    if th == 0 || tw == 0 {
143        return Err(NdimageError::InvalidInput(
144            "Template must not be empty".into(),
145        ));
146    }
147    if th > ih || tw > iw {
148        return Err(NdimageError::InvalidInput(
149            "Template must not be larger than the image".into(),
150        ));
151    }
152
153    let out_h = ih - th + 1;
154    let out_w = iw - tw + 1;
155
156    let template_norm: f64 = template.iter().map(|&v| v * v).sum::<f64>().sqrt();
157
158    let mut result = Array2::zeros((out_h, out_w));
159
160    for r in 0..out_h {
161        for c in 0..out_w {
162            let patch = image.slice(s![r..r + th, c..c + tw]);
163            let cross: f64 = patch.iter().zip(template.iter()).map(|(a, b)| a * b).sum();
164            let patch_norm: f64 = patch.iter().map(|&v| v * v).sum::<f64>().sqrt();
165            let denom = patch_norm * template_norm;
166            result[[r, c]] = if denom > 1e-12 { cross / denom } else { 0.0 };
167        }
168    }
169
170    Ok(result)
171}
172
173// ---------------------------------------------------------------------------
174// Zero-mean normalized cross-correlation (coefficient correlation)
175// ---------------------------------------------------------------------------
176
177fn coeff_correlation(image: &Array2<f64>, template: &Array2<f64>) -> NdimageResult<Array2<f64>> {
178    let (ih, iw) = image.dim();
179    let (th, tw) = template.dim();
180    let out_h = ih - th + 1;
181    let out_w = iw - tw + 1;
182    let n = (th * tw) as f64;
183
184    // Zero-mean template
185    let t_mean: f64 = template.iter().sum::<f64>() / n;
186    let t_centered: Vec<f64> = template.iter().map(|&v| v - t_mean).collect();
187    let t_std: f64 = t_centered.iter().map(|&v| v * v).sum::<f64>().sqrt();
188
189    let mut result = Array2::zeros((out_h, out_w));
190
191    for r in 0..out_h {
192        for c in 0..out_w {
193            let patch = image.slice(s![r..r + th, c..c + tw]);
194            let p_mean: f64 = patch.iter().sum::<f64>() / n;
195            let cross: f64 = patch
196                .iter()
197                .zip(t_centered.iter())
198                .map(|(a, b)| (a - p_mean) * b)
199                .sum();
200            let p_std: f64 = patch
201                .iter()
202                .map(|&v| (v - p_mean).powi(2))
203                .sum::<f64>()
204                .sqrt();
205            let denom = p_std * t_std;
206            result[[r, c]] = if denom > 1e-12 { cross / denom } else { 0.0 };
207        }
208    }
209
210    Ok(result)
211}
212
213// ---------------------------------------------------------------------------
214// Peak / match extraction with non-maximum suppression
215// ---------------------------------------------------------------------------
216
217/// Extract discrete match locations from a response map.
218///
219/// For `NormalizedCrossCorrelation` / `CoeffCorrelation` maps, peaks **above**
220/// `threshold` are returned (score closer to 1 = better).
221/// For `SumSquaredDiff` maps, use negated or inverted scores before calling
222/// this function, or supply negative scores.
223///
224/// # Parameters
225/// - `correlation_map` – 2-D response map produced by `template_match` (or
226///   `normalized_cross_correlation`).
227/// - `threshold` – minimum score to be considered a match.
228/// - `min_distance` – minimum pixel distance between two accepted peaks
229///   (non-maximum suppression radius).
230///
231/// # Returns
232/// A sorted (descending by score) list of `(row, col, score)` tuples.
233pub fn find_matches(
234    correlation_map: &Array2<f64>,
235    threshold: f64,
236    min_distance: usize,
237) -> NdimageResult<Vec<(usize, usize, f64)>> {
238    let (rows, cols) = correlation_map.dim();
239    if rows == 0 || cols == 0 {
240        return Ok(Vec::new());
241    }
242
243    // Collect all above-threshold positions
244    let mut candidates: Vec<(usize, usize, f64)> = correlation_map
245        .indexed_iter()
246        .filter_map(|((r, c), &score)| {
247            if score >= threshold {
248                Some((r, c, score))
249            } else {
250                None
251            }
252        })
253        .collect();
254
255    // Sort descending by score
256    candidates.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
257
258    // Non-maximum suppression: greedily accept peaks that are at least
259    // `min_distance` pixels away from all already-accepted peaks.
260    let mut accepted: Vec<(usize, usize, f64)> = Vec::new();
261    let min_dist_sq = (min_distance as f64) * (min_distance as f64);
262
263    'outer: for (r, c, score) in candidates {
264        for &(ar, ac, _) in &accepted {
265            let dr = r as f64 - ar as f64;
266            let dc = c as f64 - ac as f64;
267            if dr * dr + dc * dc < min_dist_sq {
268                continue 'outer;
269            }
270        }
271        accepted.push((r, c, score));
272    }
273
274    Ok(accepted)
275}
276
277// ---------------------------------------------------------------------------
278// Multi-scale (image pyramid) template matching
279// ---------------------------------------------------------------------------
280
281/// Down-sample a 2-D image by factor 2 using simple 2×2 average pooling.
282fn downsample_2x(image: &Array2<f64>) -> Array2<f64> {
283    let (h, w) = image.dim();
284    let oh = h / 2;
285    let ow = w / 2;
286    if oh == 0 || ow == 0 {
287        return image.clone();
288    }
289    let mut out = Array2::zeros((oh, ow));
290    for r in 0..oh {
291        for c in 0..ow {
292            out[[r, c]] = 0.25
293                * (image[[2 * r, 2 * c]]
294                    + image[[2 * r, 2 * c + 1]]
295                    + image[[2 * r + 1, 2 * c]]
296                    + image[[2 * r + 1, 2 * c + 1]]);
297        }
298    }
299    out
300}
301
302/// Multi-scale template matching using a Gaussian image pyramid.
303///
304/// Builds `n_scales` octave-spaced scales of the image (each half the
305/// previous dimensions) and runs normalized cross-correlation at each
306/// scale.  Detected peaks are mapped back to the original-image
307/// coordinate space and deduplicated with non-maximum suppression.
308///
309/// # Parameters
310/// - `image`    – original grayscale image.
311/// - `template` – template to search for.
312/// - `n_scales` – number of pyramid levels (≥ 1).
313///
314/// # Returns
315/// List of `(row, col, score, scale)` tuples sorted descending by score,
316/// where `scale` is the zoom factor at which the match was found
317/// (1.0 = original size, 0.5 = half size, …).
318///
319/// # Errors
320/// Returns `NdimageError::InvalidInput` for degenerate inputs.
321pub fn pyramid_template_match(
322    image: &Array2<f64>,
323    template: &Array2<f64>,
324    n_scales: usize,
325) -> NdimageResult<Vec<(usize, usize, f64, f64)>> {
326    if n_scales == 0 {
327        return Err(NdimageError::InvalidInput(
328            "n_scales must be at least 1".into(),
329        ));
330    }
331    if template.dim().0 == 0 || template.dim().1 == 0 {
332        return Err(NdimageError::InvalidInput(
333            "Template must not be empty".into(),
334        ));
335    }
336
337    let (th, tw) = template.dim();
338    let mut results: Vec<(usize, usize, f64, f64)> = Vec::new();
339
340    let mut current_image = image.clone();
341    let mut current_template = template.clone();
342    let mut scale = 1.0_f64;
343
344    for _lvl in 0..n_scales {
345        let (ih, iw) = current_image.dim();
346        let (cth, ctw) = current_template.dim();
347
348        // Stop if the template no longer fits the image
349        if cth == 0 || ctw == 0 || cth > ih || ctw > iw {
350            break;
351        }
352
353        let ncc = normalized_cross_correlation(&current_image, &current_template)?;
354
355        // Threshold: accept matches with NCC ≥ 0.5 (reasonable default)
356        let threshold = 0.5;
357        let min_dist = (th.max(tw) / 2).max(1);
358        let local_matches = find_matches(&ncc, threshold, min_dist)?;
359
360        for (r, c, score) in local_matches {
361            // Map coordinates back to original image space
362            let orig_r = (r as f64 / scale).round() as usize;
363            let orig_c = (c as f64 / scale).round() as usize;
364            results.push((orig_r, orig_c, score, scale));
365        }
366
367        // Build next pyramid level
368        current_image = downsample_2x(&current_image);
369        current_template = downsample_2x(&current_template);
370        scale *= 0.5;
371    }
372
373    // Sort by score descending
374    results.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
375
376    // Global NMS across all scales in original coordinates
377    let nms_dist: usize = (th.max(tw) / 2).max(1);
378    let min_dist_sq = (nms_dist as f64).powi(2);
379    let mut accepted: Vec<(usize, usize, f64, f64)> = Vec::new();
380
381    'outer: for (r, c, score, s) in results {
382        for &(ar, ac, _, _) in &accepted {
383            let dr = r as f64 - ar as f64;
384            let dc = c as f64 - ac as f64;
385            if dr * dr + dc * dc < min_dist_sq {
386                continue 'outer;
387            }
388        }
389        accepted.push((r, c, score, s));
390    }
391
392    Ok(accepted)
393}
394
395// ---------------------------------------------------------------------------
396// Tests
397// ---------------------------------------------------------------------------
398
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// `rows × cols` checkerboard: 1.0 where `r + c` is even, 0.0 elsewhere.
    fn checkerboard_image(rows: usize, cols: usize) -> Array2<f64> {
        Array2::from_shape_fn((rows, cols), |(r, c)| {
            if (r + c) % 2 == 0 {
                1.0
            } else {
                0.0
            }
        })
    }

    /// Deterministic, non-constant image with small integer values in 0..5.
    fn modular_ramp_image(rows: usize, cols: usize) -> Array2<f64> {
        Array2::from_shape_fn((rows, cols), |(r, c)| ((r * 7 + c * 3) % 5) as f64)
    }

    #[test]
    fn test_ssd_perfect_match() {
        let image: Array2<f64> = Array2::from_shape_vec(
            (4, 4),
            vec![
                1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0,
            ],
        )
        .expect("shape ok");

        let template: Array2<f64> =
            Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).expect("shape ok");

        let map = template_match(&image, &template, MatchMethod::SumSquaredDiff).expect("ssd ok");
        // Perfect match at (0,0): SSD = 0
        assert!(
            map[[0, 0]] < 1e-12,
            "Expected zero SSD at perfect-match location"
        );
    }

    #[test]
    fn test_normalized_ssd_perfect_match() {
        let img = checkerboard_image(6, 6);
        let tpl = img.slice(s![1..3, 1..3]).to_owned();
        let map = template_match(&img, &tpl, MatchMethod::NormalizedSumSquaredDiff)
            .expect("nssd ok");
        assert!(
            map[[1, 1]] < 1e-12,
            "Expected zero normalized SSD at the template's own location"
        );
    }

    #[test]
    fn test_ncc_perfect_match() {
        let img = checkerboard_image(6, 6);
        let tpl = img.slice(s![1..3, 1..3]).to_owned();
        let ncc = normalized_cross_correlation(&img, &tpl).expect("ncc ok");
        // Find the patch at (1,1) — should give NCC ~ 1
        let score = ncc[[1, 1]];
        assert!(
            score > 0.99,
            "NCC at matching position should be ~1, got {score}"
        );
    }

    #[test]
    fn test_ncc_values_within_unit_interval() {
        let img = modular_ramp_image(8, 8);
        let tpl = img.slice(s![2..5, 3..6]).to_owned();
        let ncc = normalized_cross_correlation(&img, &tpl).expect("ncc ok");
        assert_eq!(ncc.dim(), (6, 6));
        for &v in ncc.iter() {
            assert!(
                (-1.0 - 1e-12..=1.0 + 1e-12).contains(&v),
                "NCC value {v} outside [-1, 1]"
            );
        }
    }

    #[test]
    fn test_coeff_correlation_perfect_match() {
        let img = modular_ramp_image(6, 6);
        let tpl = img.slice(s![2..4, 1..4]).to_owned();
        let map = template_match(&img, &tpl, MatchMethod::CoeffCorrelation).expect("ccoeff ok");
        assert_eq!(map.dim(), (5, 4));
        let score = map[[2, 1]];
        assert!(
            score > 0.99,
            "Zero-mean NCC at matching position should be ~1, got {score}"
        );
    }

    #[test]
    fn test_find_matches_basic() {
        let mut map: Array2<f64> = Array2::zeros((10, 10));
        map[[2, 3]] = 0.9;
        map[[7, 8]] = 0.8;
        map[[2, 4]] = 0.85; // close to (2,3); should be suppressed

        let matches = find_matches(&map, 0.7, 3).expect("matches ok");
        // Highest score first; (2,4) is 1 px from (2,3) and must be gone.
        assert_eq!(matches, vec![(2, 3, 0.9), (7, 8, 0.8)]);
    }

    #[test]
    fn test_find_matches_empty_map() {
        let map: Array2<f64> = Array2::zeros((0, 5));
        let matches = find_matches(&map, 0.0, 1).expect("matches ok");
        assert!(matches.is_empty());
    }

    #[test]
    fn test_find_matches_zero_distance_keeps_all_in_row_major_order() {
        let map: Array2<f64> = Array2::from_elem((3, 3), 1.0);
        let matches = find_matches(&map, 0.5, 0).expect("matches ok");
        assert_eq!(matches.len(), 9);
        assert_eq!(matches[0], (0, 0, 1.0));
        assert_eq!(matches[8], (2, 2, 1.0));
    }

    #[test]
    fn test_pyramid_match_runs() {
        let image = checkerboard_image(32, 32);
        let template: Array2<f64> = image.slice(s![4..8, 4..8]).to_owned();
        let results = pyramid_template_match(&image, &template, 3).expect("pyramid ok");
        // Should produce at least one result
        assert!(!results.is_empty());
    }

    #[test]
    fn test_pyramid_zero_scales_errors() {
        let image = checkerboard_image(8, 8);
        let template: Array2<f64> = image.slice(s![0..2, 0..2]).to_owned();
        assert!(pyramid_template_match(&image, &template, 0).is_err());
    }

    #[test]
    fn test_template_larger_than_image_errors() {
        let small: Array2<f64> = Array2::zeros((3, 3));
        let large: Array2<f64> = Array2::zeros((5, 5));
        let err = template_match(&small, &large, MatchMethod::SumSquaredDiff);
        assert!(err.is_err());
    }

    #[test]
    fn test_empty_template_errors() {
        let image: Array2<f64> = Array2::zeros((3, 3));
        let empty: Array2<f64> = Array2::zeros((0, 2));
        assert!(template_match(&image, &empty, MatchMethod::CoeffCorrelation).is_err());
        assert!(normalized_cross_correlation(&image, &empty).is_err());
    }
}