1use yscv_tensor::Tensor;
8
9use crate::{BoundingBox, Detection, non_max_suppression};
10
/// Configuration for YOLO-style detection post-processing.
#[derive(Debug, Clone)]
pub struct YoloConfig {
    /// Side length (pixels) of the square model input, e.g. 640.
    pub input_size: usize,
    /// Number of object classes the model predicts.
    pub num_classes: usize,
    /// Minimum class score for a prediction to be kept.
    pub conf_threshold: f32,
    /// IoU threshold used by non-max suppression.
    pub iou_threshold: f32,
    /// Human-readable label for each class id, indexed by class id.
    pub class_labels: Vec<String>,
}
25
/// Returns the 80 COCO class names in standard YOLO/COCO index order
/// (index 0 = "person", index 79 = "toothbrush").
#[rustfmt::skip]
pub fn coco_labels() -> Vec<String> {
    const LABELS: [&str; 80] = [
        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
        "truck", "boat", "traffic light", "fire hydrant", "stop sign",
        "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep",
        "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella",
        "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
        "sports ball", "kite", "baseball bat", "baseball glove", "skateboard",
        "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork",
        "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
        "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair",
        "couch", "potted plant", "bed", "dining table", "toilet", "tv",
        "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave",
        "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase",
        "scissors", "teddy bear", "hair drier", "toothbrush",
    ];
    LABELS.iter().map(|label| label.to_string()).collect()
}
48
49pub fn yolov8_coco_config() -> YoloConfig {
51 YoloConfig {
52 input_size: 640,
53 num_classes: 80,
54 conf_threshold: 0.25,
55 iou_threshold: 0.45,
56 class_labels: coco_labels(),
57 }
58}
59
60pub fn decode_yolov8_output(
72 output: &Tensor,
73 config: &YoloConfig,
74 orig_width: usize,
75 orig_height: usize,
76) -> Vec<Detection> {
77 let shape = output.shape();
78 if shape.len() != 3 || shape[0] != 1 {
80 return Vec::new();
81 }
82 let rows = shape[1]; let num_preds = shape[2];
84 if rows < 5 {
85 return Vec::new();
86 }
87 let num_classes = rows - 4;
88
89 let data = output.data();
90
91 let scale = (config.input_size as f32 / orig_width as f32)
93 .min(config.input_size as f32 / orig_height as f32);
94 let new_w = orig_width as f32 * scale;
95 let new_h = orig_height as f32 * scale;
96 let pad_x = (config.input_size as f32 - new_w) / 2.0;
97 let pad_y = (config.input_size as f32 - new_h) / 2.0;
98
99 let mut candidates = Vec::new();
100
101 for i in 0..num_preds {
102 let cx = data[i];
104 let cy = data[num_preds + i];
105 let w = data[2 * num_preds + i];
106 let h = data[3 * num_preds + i];
107
108 let mut best_score = f32::NEG_INFINITY;
110 let mut best_class = 0usize;
111 for c in 0..num_classes {
112 let s = data[(4 + c) * num_preds + i];
113 if s > best_score {
114 best_score = s;
115 best_class = c;
116 }
117 }
118
119 if best_score < config.conf_threshold {
120 continue;
121 }
122
123 let x1 = ((cx - w / 2.0) - pad_x) / scale;
125 let y1 = ((cy - h / 2.0) - pad_y) / scale;
126 let x2 = ((cx + w / 2.0) - pad_x) / scale;
127 let y2 = ((cy + h / 2.0) - pad_y) / scale;
128
129 let x1 = x1.max(0.0).min(orig_width as f32);
131 let y1 = y1.max(0.0).min(orig_height as f32);
132 let x2 = x2.max(0.0).min(orig_width as f32);
133 let y2 = y2.max(0.0).min(orig_height as f32);
134
135 candidates.push(Detection {
136 bbox: BoundingBox { x1, y1, x2, y2 },
137 score: best_score,
138 class_id: best_class,
139 });
140 }
141
142 non_max_suppression(&candidates, config.iou_threshold, candidates.len().max(1))
143}
144
145pub fn letterbox_preprocess(image: &Tensor, target_size: usize) -> (Tensor, f32, f32, f32) {
151 let shape = image.shape();
152 assert!(
153 shape.len() == 3 && shape[2] == 3,
154 "expected [H, W, 3] tensor"
155 );
156 let src_h = shape[0];
157 let src_w = shape[1];
158 let data = image.data();
159
160 let scale = (target_size as f32 / src_w as f32).min(target_size as f32 / src_h as f32);
161 let new_w = (src_w as f32 * scale).round() as usize;
162 let new_h = (src_h as f32 * scale).round() as usize;
163 let pad_x = (target_size - new_w) as f32 / 2.0;
164 let pad_y = (target_size - new_h) as f32 / 2.0;
165 let pad_left = pad_x.floor() as usize;
166 let pad_top = pad_y.floor() as usize;
167
168 let total = target_size * target_size * 3;
170 let mut out = vec![0.5f32; total];
171
172 let scale_x = src_w as f32 / new_w as f32;
174 let scale_y = src_h as f32 / new_h as f32;
175
176 for y in 0..new_h {
177 let src_y = ((y as f32 * scale_y) as usize).min(src_h - 1);
178 for x in 0..new_w {
179 let src_x = ((x as f32 * scale_x) as usize).min(src_w - 1);
180 let dst_idx = ((pad_top + y) * target_size + (pad_left + x)) * 3;
181 let src_idx = (src_y * src_w + src_x) * 3;
182 out[dst_idx] = data[src_idx];
183 out[dst_idx + 1] = data[src_idx + 1];
184 out[dst_idx + 2] = data[src_idx + 2];
185 }
186 }
187
188 let tensor = Tensor::from_vec(vec![target_size, target_size, 3], out)
189 .expect("letterbox output tensor creation");
190 (tensor, scale, pad_x, pad_y)
191}
192
193#[allow(dead_code)]
198fn hwc_to_nchw(hwc: &Tensor) -> Vec<f32> {
199 let shape = hwc.shape();
200 let h = shape[0];
201 let w = shape[1];
202 let data = hwc.data();
203 let mut nchw = vec![0.0f32; 3 * h * w];
204 for y in 0..h {
205 for x in 0..w {
206 let src = (y * w + x) * 3;
207 for c in 0..3 {
208 nchw[c * h * w + y * w + x] = data[src + c];
209 }
210 }
211 }
212 nchw
213}
214
/// Runs a YOLOv8 ONNX model on preprocessed image data and decodes the
/// detections back into `img_width` x `img_height` coordinates.
///
/// `image_data` must be a flat `[1, 3, input_size, input_size]` buffer.
///
/// # Errors
/// Returns an error if input tensor construction or model execution fails,
/// or if the expected output tensor is missing from the model results.
#[cfg(feature = "onnx")]
pub fn detect_yolov8_onnx(
    model: &yscv_onnx::OnnxModel,
    image_data: &[f32],
    img_height: usize,
    img_width: usize,
    config: &YoloConfig,
) -> Result<Vec<Detection>, crate::DetectError> {
    use std::collections::HashMap;

    // Fall back to the conventional YOLOv8 tensor names ("images"/"output0")
    // when the model declares none.
    let input_name = model
        .inputs
        .first()
        .cloned()
        .unwrap_or_else(|| "images".to_string());
    let input_tensor = Tensor::from_vec(
        vec![1, 3, config.input_size, config.input_size],
        image_data.to_vec(),
    )?;

    let mut feed = HashMap::new();
    feed.insert(input_name, input_tensor);
    let outputs = yscv_onnx::run_onnx_model(model, feed)?;

    let output_name = model
        .outputs
        .first()
        .cloned()
        .unwrap_or_else(|| "output0".to_string());
    let raw_output = outputs
        .get(&output_name)
        .ok_or_else(|| yscv_onnx::OnnxError::MissingInput {
            node: "model_output".to_string(),
            input: output_name,
        })?;

    Ok(decode_yolov8_output(raw_output, config, img_width, img_height))
}
267
/// Convenience entry point: letterboxes a raw HWC RGB float image, repacks it
/// to NCHW, and runs YOLOv8 detection on it.
///
/// # Errors
/// Propagates tensor-construction and model-execution failures.
#[cfg(feature = "onnx")]
pub fn detect_yolov8_from_rgb(
    model: &yscv_onnx::OnnxModel,
    rgb_data: &[f32],
    height: usize,
    width: usize,
    config: &YoloConfig,
) -> Result<Vec<Detection>, crate::DetectError> {
    let hwc = Tensor::from_vec(vec![height, width, 3], rgb_data.to_vec())?;
    // Scale/padding are recomputed inside `decode_yolov8_output`, so the
    // values returned by the letterbox step are not needed here.
    let (letterboxed, _, _, _) = letterbox_preprocess(&hwc, config.input_size);
    let nchw = hwc_to_nchw(&letterboxed);
    detect_yolov8_onnx(model, &nchw, height, width, config)
}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// 640x640 / 80-class COCO config with a caller-chosen confidence cutoff.
    fn coco_test_config(conf_threshold: f32) -> YoloConfig {
        YoloConfig {
            input_size: 640,
            num_classes: 80,
            conf_threshold,
            iou_threshold: 0.45,
            class_labels: coco_labels(),
        }
    }

    #[test]
    fn test_coco_labels_count() {
        assert_eq!(coco_labels().len(), 80);
    }

    #[test]
    fn test_yolov8_coco_config_defaults() {
        let cfg = yolov8_coco_config();
        assert_eq!(cfg.input_size, 640);
        assert_eq!(cfg.num_classes, 80);
        assert!((cfg.conf_threshold - 0.25).abs() < 1e-6);
        assert!((cfg.iou_threshold - 0.45).abs() < 1e-6);
        assert_eq!(cfg.class_labels.len(), 80);
    }

    /// One 100x100 box centred at (320, 320), class 5, score 0.9, in a
    /// [1, 84, 8400] layout; every other score stays zero.
    fn make_one_detection_tensor() -> Tensor {
        let rows = 4 + 80;
        let preds = 8400;
        let mut data = vec![0.0f32; rows * preds];
        data[0] = 320.0; // cx
        data[preds] = 320.0; // cy
        data[2 * preds] = 100.0; // w
        data[3 * preds] = 100.0; // h
        data[(4 + 5) * preds] = 0.9; // class-5 score
        Tensor::from_vec(vec![1, rows, preds], data).unwrap()
    }

    #[test]
    fn test_decode_yolov8_output_basic() {
        let tensor = make_one_detection_tensor();
        let dets = decode_yolov8_output(&tensor, &coco_test_config(0.25), 640, 640);
        assert_eq!(dets.len(), 1);
        assert_eq!(dets[0].class_id, 5);
        assert!((dets[0].score - 0.9).abs() < 1e-6);

        // 640x640 input means no letterbox distortion: the box maps straight
        // back to (270, 270)-(370, 370).
        let bbox = &dets[0].bbox;
        assert!((bbox.x1 - 270.0).abs() < 1.0);
        assert!((bbox.y1 - 270.0).abs() < 1.0);
        assert!((bbox.x2 - 370.0).abs() < 1.0);
        assert!((bbox.y2 - 370.0).abs() < 1.0);
    }

    #[test]
    fn test_decode_yolov8_output_confidence_filter() {
        // A threshold above the lone 0.9 score filters out every candidate.
        let tensor = make_one_detection_tensor();
        let dets = decode_yolov8_output(&tensor, &coco_test_config(0.95), 640, 640);
        assert!(dets.is_empty());
    }

    #[test]
    fn test_decode_yolov8_output_nms() {
        let rows = 4 + 80;
        let preds = 8400;
        let mut data = vec![0.0f32; rows * preds];

        // Two heavily overlapping class-0 boxes; NMS should keep only the
        // higher-scoring (0.9) one.
        for &(p, center, score) in &[(0usize, 320.0f32, 0.9f32), (1, 325.0, 0.8)] {
            data[p] = center;
            data[preds + p] = center;
            data[2 * preds + p] = 100.0;
            data[3 * preds + p] = 100.0;
            data[4 * preds + p] = score;
        }

        let tensor = Tensor::from_vec(vec![1, rows, preds], data).unwrap();
        let dets = decode_yolov8_output(&tensor, &coco_test_config(0.25), 640, 640);
        assert_eq!(dets.len(), 1);
        assert!((dets[0].score - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_letterbox_preprocess_square() {
        let img = Tensor::from_vec(vec![100, 100, 3], vec![0.5; 100 * 100 * 3]).unwrap();
        let (out, scale, pad_x, pad_y) = letterbox_preprocess(&img, 640);
        assert_eq!(out.shape(), &[640, 640, 3]);
        assert!((scale - 6.4).abs() < 0.01);
        // A square source fills the square target: no padding on either axis.
        assert!(pad_x.abs() < 1.0);
        assert!(pad_y.abs() < 1.0);
    }

    #[test]
    fn test_hwc_to_nchw_basic() {
        // 2x2 RGB image with distinct per-channel values.
        let img = Tensor::from_vec(
            vec![2, 2, 3],
            vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 0.0, 0.5],
        )
        .unwrap();
        let nchw = hwc_to_nchw(&img);
        assert_eq!(nchw.len(), 12);
        // Plane 0 (R), plane 1 (G), plane 2 (B), each row-major.
        let expected = [0.1, 0.4, 0.7, 1.0, 0.2, 0.5, 0.8, 0.0, 0.3, 0.6, 0.9, 0.5];
        for (got, want) in nchw.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    #[test]
    fn test_letterbox_then_nchw_pipeline() {
        let img = Tensor::from_vec(vec![100, 200, 3], vec![0.4; 100 * 200 * 3]).unwrap();
        let (letterboxed, _scale, _pad_x, _pad_y) = letterbox_preprocess(&img, 640);
        assert_eq!(letterboxed.shape(), &[640, 640, 3]);
        assert_eq!(hwc_to_nchw(&letterboxed).len(), 3 * 640 * 640);
    }

    #[test]
    fn test_letterbox_preprocess_landscape() {
        // 200 wide x 100 tall: width limits the scale, so padding is vertical.
        let img = Tensor::from_vec(vec![100, 200, 3], vec![0.4; 100 * 200 * 3]).unwrap();
        let (out, scale, pad_x, pad_y) = letterbox_preprocess(&img, 640);
        assert_eq!(out.shape(), &[640, 640, 3]);
        assert!((scale - 3.2).abs() < 0.01);
        assert!(pad_x.abs() < 1.0);
        assert!((pad_y - 160.0).abs() < 1.0);

        // The very first pixel lies in the top letterbox bar.
        for &v in &out.data()[0..3] {
            assert!((v - 0.5).abs() < 1e-6, "top padding should be 0.5 grey");
        }
    }
}