// yscv_detect/model_detector.rs
use yscv_tensor::Tensor;

use crate::{DetectError, Detection};

/// Tunable parameters for model-based detection and its post-processing.
///
/// The [`Default`] implementation yields a score threshold of `0.5`, an NMS
/// IoU threshold of `0.45`, at most `100` detections and a `640` x `640` input.
#[derive(Debug, Clone, PartialEq)]
pub struct ModelDetectorConfig {
    /// Minimum score a detection needs to be kept; [`postprocess_detections`]
    /// discards every detection scoring below this value.
    pub score_threshold: f32,
    /// IoU above which a lower-scoring detection is suppressed by a
    /// higher-scoring detection of the same class during non-maximum
    /// suppression.
    pub nms_iou_threshold: f32,
    /// Cap on the number of detections kept by post-processing.
    pub max_detections: usize,
    /// Input height in pixels.
    ///
    /// NOTE(review): presumably the height the model expects; nothing in this
    /// part of the file reads it — confirm against the detector implementations.
    pub input_height: usize,
    /// Input width in pixels.
    ///
    /// NOTE(review): presumably the width the model expects; nothing in this
    /// part of the file reads it — confirm against the detector implementations.
    pub input_width: usize,
}

26impl Default for ModelDetectorConfig {
27 fn default() -> Self {
28 Self {
29 score_threshold: 0.5,
30 nms_iou_threshold: 0.45,
31 max_detections: 100,
32 input_height: 640,
33 input_width: 640,
34 }
35 }
36}
37
38pub trait ModelDetector {
44 fn detect_tensor(&self, input: &Tensor) -> Result<Vec<Detection>, DetectError>;
52
53 fn class_labels(&self) -> &[&str];
55
56 fn input_shape(&self) -> [usize; 3];
58}
59
60pub fn postprocess_detections(raw: &[Detection], config: &ModelDetectorConfig) -> Vec<Detection> {
64 let mut filtered: Vec<Detection> = raw
66 .iter()
67 .copied()
68 .filter(|d| d.score >= config.score_threshold)
69 .collect();
70
71 filtered.sort_by(|a, b| {
73 b.score
74 .partial_cmp(&a.score)
75 .unwrap_or(std::cmp::Ordering::Equal)
76 });
77
78 let mut result = Vec::new();
80 let mut suppressed = vec![false; filtered.len()];
81
82 for i in 0..filtered.len() {
83 if suppressed[i] {
84 continue;
85 }
86 result.push(filtered[i]);
87 if result.len() >= config.max_detections {
88 break;
89 }
90 for j in i + 1..filtered.len() {
91 if suppressed[j] || filtered[j].class_id != filtered[i].class_id {
92 continue;
93 }
94 if crate::iou(filtered[i].bbox, filtered[j].bbox) > config.nms_iou_threshold {
95 suppressed[j] = true;
96 }
97 }
98 }
99
100 result
101}
102
103pub fn preprocess_rgb8_for_model(
108 rgb8: &[u8],
109 width: usize,
110 height: usize,
111 target_h: usize,
112 target_w: usize,
113) -> Result<Tensor, DetectError> {
114 if rgb8.len() < width * height * 3 {
115 return Err(DetectError::InvalidRgb8BufferSize {
116 expected: width * height * 3,
117 got: rgb8.len(),
118 });
119 }
120
121 let mut data = Vec::with_capacity(target_h * target_w * 3);
123 let scale_y = height as f32 / target_h as f32;
124 let scale_x = width as f32 / target_w as f32;
125
126 for row in 0..target_h {
127 let src_y = (row as f32 * scale_y).min((height - 1) as f32);
128 let y0 = src_y as usize;
129 let y1 = (y0 + 1).min(height - 1);
130 let fy = src_y - y0 as f32;
131
132 for col in 0..target_w {
133 let src_x = (col as f32 * scale_x).min((width - 1) as f32);
134 let x0 = src_x as usize;
135 let x1 = (x0 + 1).min(width - 1);
136 let fx = src_x - x0 as f32;
137
138 for ch in 0..3 {
139 let v00 = rgb8[(y0 * width + x0) * 3 + ch] as f32;
140 let v01 = rgb8[(y0 * width + x1) * 3 + ch] as f32;
141 let v10 = rgb8[(y1 * width + x0) * 3 + ch] as f32;
142 let v11 = rgb8[(y1 * width + x1) * 3 + ch] as f32;
143 let v = v00 * (1.0 - fx) * (1.0 - fy)
144 + v01 * fx * (1.0 - fy)
145 + v10 * (1.0 - fx) * fy
146 + v11 * fx * fy;
147 data.push(v / 255.0);
148 }
149 }
150 }
151
152 Tensor::from_vec(vec![1, target_h, target_w, 3], data).map_err(DetectError::Tensor)
153}