// yscv_detect/model_detector.rs
//! Trait-based model detector interface.
//!
//! Provides an abstract `ModelDetector` trait so that model-backed detectors
//! (YOLO, SSD, FCOS, etc.) can be plugged in alongside the existing heatmap
//! pipeline.

7use yscv_tensor::Tensor;
8
9use crate::{DetectError, Detection};
10
/// Configuration for a model-based detector.
///
/// Bundles the post-processing knobs (score filtering, NMS, detection cap)
/// together with the spatial input size the model expects. Consumed by
/// `postprocess_detections` and, for the input dimensions, by callers of
/// `preprocess_rgb8_for_model`.
#[derive(Debug, Clone)]
pub struct ModelDetectorConfig {
    /// Minimum score to keep a detection.
    pub score_threshold: f32,
    /// IoU threshold for NMS post-processing.
    pub nms_iou_threshold: f32,
    /// Maximum number of detections to return.
    pub max_detections: usize,
    /// Expected input height for the model.
    pub input_height: usize,
    /// Expected input width for the model.
    pub input_width: usize,
}
25
26impl Default for ModelDetectorConfig {
27    fn default() -> Self {
28        Self {
29            score_threshold: 0.5,
30            nms_iou_threshold: 0.45,
31            max_detections: 100,
32            input_height: 640,
33            input_width: 640,
34        }
35    }
36}
37
/// Abstract interface for model-backed object detectors.
///
/// Implementors wrap a concrete model (YOLO, SSD, FCOS, ...) and provide
/// `detect_tensor`, which takes a preprocessed input tensor and returns
/// detections. The framework handles NMS and thresholding via the config
/// (see `postprocess_detections`).
pub trait ModelDetector {
    /// Run detection on a preprocessed input tensor.
    ///
    /// The tensor format depends on the model. Typical NHWC shape:
    /// `[1, H, W, C]` where `H` and `W` match the config dimensions.
    ///
    /// Returns raw detections (before NMS). The caller may apply
    /// `non_max_suppression` from this crate.
    fn detect_tensor(&self, input: &Tensor) -> Result<Vec<Detection>, DetectError>;

    /// Returns the class labels this detector can produce.
    ///
    /// Indexable by a detection's class id — TODO confirm against `Detection`.
    fn class_labels(&self) -> &[&str];

    /// Returns the expected input shape `[H, W, C]`.
    fn input_shape(&self) -> [usize; 3];
}
59
60/// Post-processes raw model output into final detections.
61///
62/// Applies score thresholding and NMS.
63pub fn postprocess_detections(raw: &[Detection], config: &ModelDetectorConfig) -> Vec<Detection> {
64    // Score filter
65    let mut filtered: Vec<Detection> = raw
66        .iter()
67        .copied()
68        .filter(|d| d.score >= config.score_threshold)
69        .collect();
70
71    // Sort by score descending
72    filtered.sort_by(|a, b| {
73        b.score
74            .partial_cmp(&a.score)
75            .unwrap_or(std::cmp::Ordering::Equal)
76    });
77
78    // NMS per class
79    let mut result = Vec::new();
80    let mut suppressed = vec![false; filtered.len()];
81
82    for i in 0..filtered.len() {
83        if suppressed[i] {
84            continue;
85        }
86        result.push(filtered[i]);
87        if result.len() >= config.max_detections {
88            break;
89        }
90        for j in i + 1..filtered.len() {
91            if suppressed[j] || filtered[j].class_id != filtered[i].class_id {
92                continue;
93            }
94            if crate::iou(filtered[i].bbox, filtered[j].bbox) > config.nms_iou_threshold {
95                suppressed[j] = true;
96            }
97        }
98    }
99
100    result
101}
102
103/// Preprocess an RGB8 image for model input.
104///
105/// Resizes to `(target_h, target_w)`, normalizes to `[0, 1]`, and returns
106/// an NHWC tensor `[1, target_h, target_w, 3]`.
107pub fn preprocess_rgb8_for_model(
108    rgb8: &[u8],
109    width: usize,
110    height: usize,
111    target_h: usize,
112    target_w: usize,
113) -> Result<Tensor, DetectError> {
114    if rgb8.len() < width * height * 3 {
115        return Err(DetectError::InvalidRgb8BufferSize {
116            expected: width * height * 3,
117            got: rgb8.len(),
118        });
119    }
120
121    // Simple bilinear resize + normalize
122    let mut data = Vec::with_capacity(target_h * target_w * 3);
123    let scale_y = height as f32 / target_h as f32;
124    let scale_x = width as f32 / target_w as f32;
125
126    for row in 0..target_h {
127        let src_y = (row as f32 * scale_y).min((height - 1) as f32);
128        let y0 = src_y as usize;
129        let y1 = (y0 + 1).min(height - 1);
130        let fy = src_y - y0 as f32;
131
132        for col in 0..target_w {
133            let src_x = (col as f32 * scale_x).min((width - 1) as f32);
134            let x0 = src_x as usize;
135            let x1 = (x0 + 1).min(width - 1);
136            let fx = src_x - x0 as f32;
137
138            for ch in 0..3 {
139                let v00 = rgb8[(y0 * width + x0) * 3 + ch] as f32;
140                let v01 = rgb8[(y0 * width + x1) * 3 + ch] as f32;
141                let v10 = rgb8[(y1 * width + x0) * 3 + ch] as f32;
142                let v11 = rgb8[(y1 * width + x1) * 3 + ch] as f32;
143                let v = v00 * (1.0 - fx) * (1.0 - fy)
144                    + v01 * fx * (1.0 - fy)
145                    + v10 * (1.0 - fx) * fy
146                    + v11 * fx * fy;
147                data.push(v / 255.0);
148            }
149        }
150    }
151
152    Tensor::from_vec(vec![1, target_h, target_w, 3], data).map_err(DetectError::Tensor)
153}