
ultralytics_inference/postprocessing.rs

// Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

//! Post-processing for YOLO model outputs.
//!
//! This module handles task-specific post-processing of raw model outputs,
//! including NMS, coordinate transformation, and result construction.

#![allow(
    unsafe_code,
    clippy::doc_markdown,
    clippy::too_many_lines,
    clippy::if_not_else,
    clippy::ptr_as_ptr,
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss
)]

use std::collections::HashMap;

use wide::{CmpGt, f32x8};

use fast_image_resize::images::Image;
use fast_image_resize::{FilterType, PixelType, ResizeAlg, ResizeOptions, Resizer};
use ndarray::{Array2, Array3, ArrayView1, ArrayViewMut2, Zip, s};

use crate::inference::InferenceConfig;
use crate::preprocessing::{PreprocessResult, clip_coords, scale_coords};
use crate::results::{Boxes, Keypoints, Masks, Obb, Probs, Results, Speed};
use crate::task::Task;
use crate::utils::{nms_per_class, nms_rotated_per_class};

/// Post-process raw model output based on task type.
///
/// # Arguments
///
/// * `outputs` - Vector of raw model outputs (data, shape).
/// * `task` - The task type (detect, segment, pose, etc.).
/// * `preprocess` - Preprocessing result containing scale/padding info.
/// * `config` - Inference configuration.
/// * `names` - Class ID to name mapping.
/// * `orig_img` - Original image as HWC array.
/// * `path` - Source image path.
/// * `speed` - Timing information.
/// * `inference_shape` - The inference tensor shape (height, width) after letterboxing.
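/// * `end2end` - Whether the model is an end-to-end (NMS-free) export; forces the end2end decode path.
/// * `kpt_shape` - Optional `(num_keypoints, kpt_dim)` from model metadata, used for pose decoding.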
///
/// # Returns
///
/// Processed Results object.
#[must_use]
#[allow(
    clippy::too_many_arguments,
    clippy::similar_names,
    clippy::implicit_hasher
)]
pub fn postprocess(
    outputs: Vec<(&[f32], Vec<usize>)>,
    task: Task,
    preprocess: &PreprocessResult,
    config: &InferenceConfig,
    names: &HashMap<usize, String>,
    orig_img: Array3<u8>,
    path: String,
    speed: Speed,
    inference_shape: (u32, u32),
    end2end: bool,
    kpt_shape: Option<(usize, usize)>,
) -> Results {
    match task {
        Task::Detect => {
            let (output, shape) = &outputs[0];
            if end2end || is_end2end_detect_shape(shape) {
                postprocess_detect_end2end(
                    output,
                    shape,
                    preprocess,
                    config,
                    names,
                    orig_img,
                    path,
                    speed,
                    inference_shape,
                )
            } else {
                postprocess_detect(
                    output,
                    shape,
                    preprocess,
                    config,
                    names,
                    orig_img,
                    path,
                    speed,
                    inference_shape,
                )
            }
        }
        Task::Segment => {
            // Derive proto channel count from the second output tensor (shape `[1, nm, mh, mw]`),
            // so exports with non-default prototype counts are still classified correctly.
            let proto_channels = outputs
                .get(1)
                .and_then(|(_, s)| if s.len() == 4 { Some(s[1]) } else { None });
            if end2end || is_end2end_segment_shape(&outputs[0].1, proto_channels) {
                postprocess_segment_end2end(
                    outputs,
                    preprocess,
                    config,
                    names,
                    orig_img,
                    path,
                    speed,
                    inference_shape,
                )
            } else {
                postprocess_segment(
                    outputs,
                    preprocess,
                    config,
                    names,
                    orig_img,
                    path,
                    speed,
                    inference_shape,
                )
            }
        }
        Task::Pose => {
            let (output, shape) = &outputs[0];
            // Prefer metadata-provided `kpt_shape`; otherwise infer the keypoint
            // layout from the tensor shape so non-COCO pose models still work.
            let resolved_kpt = kpt_shape.or_else(|| infer_end2end_kpt_shape(shape));
            let is_end2end = end2end
                || resolved_kpt.is_some_and(|(nk, kd)| is_end2end_pose_shape(shape, nk, kd));
            if is_end2end {
                let (nk, kpt_dim) = resolved_kpt.unwrap_or((17, 3));
                postprocess_pose_end2end(
                    output,
                    shape,
                    preprocess,
                    config,
                    names,
                    orig_img,
                    path,
                    speed,
                    inference_shape,
                    nk,
                    kpt_dim,
                )
            } else {
                postprocess_pose(
                    output,
                    shape,
                    preprocess,
                    config,
                    names,
                    orig_img,
                    path,
                    speed,
                    inference_shape,
                )
            }
        }
        Task::Classify => {
            let (output, _) = &outputs[0];
            postprocess_classify(output, names, orig_img, path, speed, inference_shape)
        }
        Task::Obb => {
            let (output, shape) = &outputs[0];
            if end2end || is_end2end_obb_shape(shape) {
                postprocess_obb_end2end(
                    output,
                    shape,
                    preprocess,
                    config,
                    names,
                    orig_img,
                    path,
                    speed,
                    inference_shape,
                )
            } else {
                postprocess_obb(
                    output,
                    shape,
                    preprocess,
                    config,
                    names,
                    orig_img,
                    path,
                    speed,
                    inference_shape,
                )
            }
        }
    }
}

/// Detect if a 3D shape matches YOLO26 end2end detect layout `[1, max_det, 6]`.
fn is_end2end_detect_shape(shape: &[usize]) -> bool {
    shape.len() == 3 && shape[2] == 6 && shape[1] <= 4096
}
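
// Worked example (illustrative): a YOLO26 detect export with `max_det = 300`
// reports `[1, 300, 6]`, which `is_end2end_detect_shape` accepts, while a
// classic NMS-based head reporting `[1, 84, 8400]` falls through to
// `postprocess_detect` (its last dimension is 8400, not 6).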

/// Detect if a 3D shape matches YOLO26 end2end segment layout `[1, max_det, 6 + nm]`.
///
/// When the proto tensor's channel count is known (passed via `proto_channels`),
/// it is used as the authoritative `nm`; otherwise the shape is rejected since
/// hardcoding `nm=32` would misclassify exports with non-default prototype counts.
fn is_end2end_segment_shape(shape: &[usize], proto_channels: Option<usize>) -> bool {
    proto_channels.is_some_and(|nm| shape.len() == 3 && shape[2] == 6 + nm && shape[1] <= 4096)
}

/// Detect if a 3D shape matches YOLO26 end2end pose layout `[1, max_det, 6 + nk*dim]`.
fn is_end2end_pose_shape(shape: &[usize], nk: usize, kpt_dim: usize) -> bool {
    shape.len() == 3 && shape[2] == 6 + nk * kpt_dim && shape[1] <= 4096
}

/// Infer `(nk, kpt_dim)` from a pose tensor shape assumed to be end-to-end
/// (`[1, max_det, 6 + nk*kpt_dim]`). Used only when `kpt_shape` metadata is absent.
///
/// Returns `None` for shapes that don't match the end-to-end layout, and for
/// shapes where `kpt_feats` is divisible by both 2 and 3 (e.g. 36, 42) because
/// the layout is ambiguous: `(12, 3)` and `(18, 2)` both decode the same
/// tensor. Ambiguous cases fall through to the legacy pose path rather than
/// silently guessing the wrong dimension.
fn infer_end2end_kpt_shape(shape: &[usize]) -> Option<(usize, usize)> {
    if shape.len() != 3 || shape[1] == 0 || shape[1] > 4096 || shape[2] <= 6 {
        return None;
    }
    let kpt_feats = shape[2] - 6;
    let div3 = kpt_feats.is_multiple_of(3);
    let div2 = kpt_feats.is_multiple_of(2);
    match (div3, div2) {
        (true, false) => Some((kpt_feats / 3, 3)),
        (false, true) => Some((kpt_feats / 2, 2)),
        _ => None, // not divisible by either, or ambiguous (divisible by 6)
    }
}
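
// Worked example (illustrative): a COCO-style pose export reports
// `[1, max_det, 57]`; `kpt_feats = 51` is divisible by 3 but not by 2, so this
// returns `Some((17, 3))`. A shape `[1, max_det, 42]` gives `kpt_feats = 36`,
// divisible by both 2 and 3, so it returns `None` and the legacy path is used.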

/// Detect if a 3D shape matches YOLO26 end2end OBB layout `[1, max_det, 7]`.
fn is_end2end_obb_shape(shape: &[usize]) -> bool {
    shape.len() == 3 && shape[2] == 7 && shape[1] <= 4096
}

/// Post-process detection model output.
///
/// Zero-copy implementation using stride-based indexing to avoid memory allocations.
#[allow(
    clippy::too_many_arguments,
    clippy::similar_names,
    clippy::cast_precision_loss
)]
fn postprocess_detect(
    output: &[f32],
    output_shape: &[usize],
    preprocess: &PreprocessResult,
    config: &InferenceConfig,
    names: &HashMap<usize, String>,
    orig_img: Array3<u8>,
    path: String,
    speed: Speed,
    inference_shape: (u32, u32),
) -> Results {
    let mut results = Results::new(orig_img, path, names.clone(), speed, inference_shape);

    // Parse output shape - handle both [1, 84, 8400] and [1, 8400, 84] formats
    let (num_classes, num_predictions, is_transposed) =
        parse_detect_shape(output_shape, names.len());

    if output.is_empty() || num_predictions == 0 {
        return results;
    }

    // Zero-copy extraction with stride-based indexing
    let boxes_data = extract_detect_boxes(
        output,
        num_classes,
        num_predictions,
        is_transposed,
        preprocess,
        config,
    );

    if !boxes_data.is_empty() {
        results.boxes = Some(Boxes::new(boxes_data, preprocess.orig_shape));
    }

    results
}

/// Parse detection output shape to determine format.
///
/// Derives class count from output shape when metadata is missing (`expected_classes` == 0).
/// YOLO outputs are either [1, `num_features`, `num_preds`] or [1, `num_preds`, `num_features`]
/// where `num_features` = 4 (bbox) + `num_classes`.
fn parse_detect_shape(shape: &[usize], expected_classes: usize) -> (usize, usize, bool) {
    match shape.len() {
        2 => {
            // [num_preds, num_features] or [num_features, num_preds]
            let (a, b) = (shape[0], shape[1]);
            // Handle edge case where both dimensions are less than 4
            if a < 4 && b < 4 {
                return (expected_classes.max(1), 0, false);
            }
            // When metadata is missing, infer from shape:
            // the smaller dimension is likely num_features, the larger num_preds
            if expected_classes == 0 {
                // No metadata - infer from shape
                let (num_features, num_preds, transposed) =
                    if a < b { (a, b, false) } else { (b, a, true) };
                let inferred_classes = num_features.saturating_sub(4);
                return (inferred_classes.max(1), num_preds, transposed);
            }
            if a == 4 + expected_classes || (a >= 4 && a > b) {
                // [num_features, num_preds]
                (a.saturating_sub(4), b, false)
            } else {
                // [num_preds, num_features]
                (b.saturating_sub(4), a, true)
            }
        }
        3 => {
            // [batch, ...] - ignore batch dimension
            let (a, b) = (shape[1], shape[2]);
            // Handle edge case where num_predictions is 0 or very small
            if b == 0 || a < 4 {
                return (expected_classes.max(1), 0, false);
            }
            // When metadata is missing, infer from shape
            if expected_classes == 0 {
                // No metadata - infer from shape
                // Typically num_features < num_preds (e.g., 84 < 8400)
                let (num_features, num_preds, transposed) =
                    if a < b { (a, b, false) } else { (b, a, true) };
                let inferred_classes = num_features.saturating_sub(4);
                return (inferred_classes.max(1), num_preds, transposed);
            }
            if a == 4 + expected_classes || (expected_classes > 0 && a < b) {
                // [1, num_features, num_preds]
                (a.saturating_sub(4), b, false)
            } else {
                // [1, num_preds, num_features]
                (b.saturating_sub(4), a, true)
            }
        }
        _ => (expected_classes.max(1), 0, false),
    }
}
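
// Illustrative parses (example shapes assumed): a COCO export `[1, 84, 8400]`
// with `expected_classes = 80` returns `(80, 8400, false)`, while the
// transposed layout `[1, 8400, 84]` returns `(80, 8400, true)`.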

#[derive(Clone, Copy)]
struct Candidate {
    bbox: [f32; 4],
    score: f32,
    class: usize,
}

/// Optimized detection extraction with SIMD acceleration.
///
/// Key optimizations:
/// - SIMD-accelerated candidate extraction (f32x8)
/// - Parallel bitmask NMS (IoU of one kept box against eight candidates at a time)
/// - Struct-of-Arrays (SoA) layout for NMS cache locality
/// - Direct unsafe indexing for performance
#[allow(clippy::cast_precision_loss, clippy::too_many_arguments)]
fn extract_detect_boxes(
    output: &[f32],
    num_classes: usize,
    num_predictions: usize,
    is_transposed: bool,
    preprocess: &PreprocessResult,
    config: &InferenceConfig,
) -> Array2<f32> {
    let feat_count = 4 + num_classes;
    let (scale_y, scale_x) = preprocess.scale;
    let (pad_top, pad_left) = preprocess.padding;
    let orig_shape = preprocess.orig_shape;
    let (max_w, max_h) = (orig_shape.1 as f32, orig_shape.0 as f32);
    let conf_thresh = config.confidence_threshold;
    let max_det = config.max_det;
    let iou_thresh = config.iou_threshold;
    let conf_v = f32x8::splat(conf_thresh);

    let mut candidates: Vec<Candidate> = Vec::with_capacity(256);

    // Candidate Extraction
    if !is_transposed {
        // Layout [feat, pred] - Cache-friendly linear scan
        let mut max_scores = vec![conf_thresh; num_predictions];
        let mut max_classes = vec![0usize; num_predictions];

        for c in 0..num_classes {
            let offset = (4 + c) * num_predictions;
            let class_scores = &output[offset..offset + num_predictions];
            for (idx, &score) in class_scores.iter().enumerate() {
                if score > max_scores[idx] {
                    max_scores[idx] = score;
                    max_classes[idx] = c;
                }
            }
        }

        for (idx, &score) in max_scores.iter().enumerate() {
            if score > conf_thresh {
                let best_class = max_classes[idx];

                // Filter by class if specified
                if !config.keep_class(best_class) {
                    continue;
                }

                let cx = unsafe { *output.get_unchecked(idx) };
                let cy = unsafe { *output.get_unchecked(num_predictions + idx) };
                let w = unsafe { *output.get_unchecked(2 * num_predictions + idx) };
                let h = unsafe { *output.get_unchecked(3 * num_predictions + idx) };

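                // Letterbox inversion, worked example (values assumed): with
                // scale (0.5, 0.5) and padding (top 0, left 16), cx=116 and
                // w=100 give x1 = (116 - 50 - 16) / 0.5 = 100 in original
                // image coordinates.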
                let x1 = (cx - w * 0.5 - pad_left) / scale_x;
                let y1 = (cy - h * 0.5 - pad_top) / scale_y;
                let x2 = (cx + w * 0.5 - pad_left) / scale_x;
                let y2 = (cy + h * 0.5 - pad_top) / scale_y;

                candidates.push(Candidate {
                    bbox: [x1, y1, x2, y2],
                    score,
                    class: best_class,
                });
            }
        }
    } else {
        // Layout [pred, feat] - Process 8 classes at once
        for idx in 0..num_predictions {
            let base = idx * feat_count;
            let row_ptr = unsafe { output.as_ptr().add(base + 4) };
            let mut best_score = conf_thresh;
            let mut best_class = 0;

            for c_idx in (0..num_classes).step_by(8) {
                if num_classes - c_idx >= 8 {
                    let scores: f32x8 =
                        unsafe { (row_ptr.add(c_idx) as *const f32x8).read_unaligned() };
                    if scores.simd_gt(conf_v).any() {
                        for i in 0..8 {
                            let s = unsafe { *row_ptr.add(c_idx + i) };
                            if s > best_score {
                                best_score = s;
                                best_class = c_idx + i;
                            }
                        }
                    }
                } else {
                    for i in c_idx..num_classes {
                        let s = unsafe { *row_ptr.add(i) };
                        if s > best_score {
                            best_score = s;
                            best_class = i;
                        }
                    }
                }
            }

            if best_score > conf_thresh {
                // Filter by class if specified
                if !config.keep_class(best_class) {
                    continue;
                }

                let cx = unsafe { *output.get_unchecked(base) };
                let cy = unsafe { *output.get_unchecked(base + 1) };
                let w = unsafe { *output.get_unchecked(base + 2) };
                let h = unsafe { *output.get_unchecked(base + 3) };

                let x1 = (cx - w * 0.5 - pad_left) / scale_x;
                let y1 = (cy - h * 0.5 - pad_top) / scale_y;
                let x2 = (cx + w * 0.5 - pad_left) / scale_x;
                let y2 = (cy + h * 0.5 - pad_top) / scale_y;

                candidates.push(Candidate {
                    bbox: [x1, y1, x2, y2],
                    score: best_score,
                    class: best_class,
                });
            }
        }
    }

    if candidates.is_empty() {
        return Array2::zeros((0, 6));
    }

    // Top-K Selection & Sort
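    // `select_nth_unstable_by` partitions in O(n) so the `nms_limit` best
    // scores land at the front (unordered); only that prefix is then sorted,
    // which is cheaper than sorting every raw candidate.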
    let nms_limit = (max_det * 10).min(candidates.len());
    if candidates.len() > nms_limit {
        candidates.select_nth_unstable_by(nms_limit, |a, b| b.score.partial_cmp(&a.score).unwrap());
        candidates.truncate(nms_limit);
    }
    candidates.sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap());

    // Populate SoA buffers for NMS (a small copy, much faster to scan)
    let n = candidates.len();
    let mut x1 = Vec::with_capacity(n);
    let mut y1 = Vec::with_capacity(n);
    let mut x2 = Vec::with_capacity(n);
    let mut y2 = Vec::with_capacity(n);
    let mut areas = Vec::with_capacity(n);

    for c in &candidates {
        x1.push(c.bbox[0]);
        y1.push(c.bbox[1]);
        x2.push(c.bbox[2]);
        y2.push(c.bbox[3]);
        areas.push((c.bbox[2] - c.bbox[0]) * (c.bbox[3] - c.bbox[1]));
    }

    let mut suppressed = vec![false; n];
    let mut keep = Vec::with_capacity(max_det);
    let iou_v = f32x8::splat(iou_thresh);
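    // Greedy NMS over the score-sorted candidates: keep the best unsuppressed
    // box, then suppress later same-class boxes whose IoU with it exceeds
    // `iou_thresh`. The inner loop tests eight boxes per SIMD iteration and
    // falls back to a scalar pass for the tail.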
    for i in 0..n {
        if suppressed[i] {
            continue;
        }
        keep.push(i);
        if keep.len() >= max_det {
            break;
        }

        let ax1 = f32x8::splat(x1[i]);
        let ay1 = f32x8::splat(y1[i]);
        let ax2 = f32x8::splat(x2[i]);
        let ay2 = f32x8::splat(y2[i]);
        let aa = f32x8::splat(areas[i]);
        let ac = candidates[i].class;

        let mut j = i + 1;
        while j < n {
            if n - j >= 8 {
                if (0..8).any(|k| candidates[j + k].class == ac && !suppressed[j + k]) {
                    let bx1 = unsafe { (x1.as_ptr().add(j) as *const f32x8).read_unaligned() };
                    let by1 = unsafe { (y1.as_ptr().add(j) as *const f32x8).read_unaligned() };
                    let bx2 = unsafe { (x2.as_ptr().add(j) as *const f32x8).read_unaligned() };
                    let by2 = unsafe { (y2.as_ptr().add(j) as *const f32x8).read_unaligned() };
                    let ba = unsafe { (areas.as_ptr().add(j) as *const f32x8).read_unaligned() };

                    let ix1 = ax1.max(bx1);
                    let iy1 = ay1.max(by1);
                    let ix2 = ax2.min(bx2);
                    let iy2 = ay2.min(by2);

                    let iw = (ix2 - ix1).max(f32x8::ZERO);
                    let ih = (iy2 - iy1).max(f32x8::ZERO);
                    let ia = iw * ih;
                    let iou = ia / (aa + ba - ia);

                    let mask = iou.simd_gt(iou_v).to_bitmask() as u8;
                    if mask != 0 {
                        for k in 0..8 {
                            if (mask & (1 << k)) != 0 && candidates[j + k].class == ac {
                                suppressed[j + k] = true;
                            }
                        }
                    }
                }
                j += 8;
            } else {
                for k in j..n {
                    if !suppressed[k] && candidates[k].class == ac {
                        let ix1 = x1[i].max(x1[k]);
                        let iy1 = y1[i].max(y1[k]);
                        let ix2 = x2[i].min(x2[k]);
                        let iy2 = y2[i].min(y2[k]);
                        let iw = (ix2 - ix1).max(0.0);
                        let ih = (iy2 - iy1).max(0.0);
                        let ia = iw * ih;
                        let iou = ia / (areas[i] + areas[k] - ia);
                        if iou > iou_thresh {
                            suppressed[k] = true;
                        }
                    }
                }
                break;
            }
        }
    }
    // Result Construction
    let num_kept = keep.len();
    let mut result = Array2::zeros((num_kept, 6));
    for (out_idx, &idx) in keep.iter().enumerate() {
        let c = &candidates[idx];
        result[[out_idx, 0]] = c.bbox[0].clamp(0.0, max_w);
        result[[out_idx, 1]] = c.bbox[1].clamp(0.0, max_h);
        result[[out_idx, 2]] = c.bbox[2].clamp(0.0, max_w);
        result[[out_idx, 3]] = c.bbox[3].clamp(0.0, max_h);
        result[[out_idx, 4]] = c.score;
        result[[out_idx, 5]] = c.class as f32;
    }

    result
}

/// Post-process segmentation model output.
///
/// Generates bounding boxes and segmentation masks from the model output.
///
/// # Arguments
///
/// * `outputs` - Vector of model outputs (detection features and mask prototypes).
/// * `preprocess` - Preprocessing metadata.
/// * `config` - Inference configuration.
/// * `names` - Class mapping.
/// * `orig_img` - Original image.
/// * `path` - Source path.
/// * `speed` - Timing metrics.
/// * `inference_shape` - Inference input dimensions.
///
/// # Returns
///
/// `Results` struct containing boxes and masks.
#[allow(
    clippy::too_many_arguments,
    clippy::similar_names,
    clippy::cast_precision_loss,
    clippy::too_many_lines,
    clippy::needless_pass_by_value,
    clippy::manual_let_else,
    clippy::cast_possible_truncation
)]
fn postprocess_segment(
    outputs: Vec<(&[f32], Vec<usize>)>,
    preprocess: &PreprocessResult,
    config: &InferenceConfig,
    names: &HashMap<usize, String>,
    orig_img: Array3<u8>,
    path: String,
    speed: Speed,
    inference_shape: (u32, u32),
) -> Results {
    let mut results = Results::new(orig_img, path, names.clone(), speed, inference_shape);

    if outputs.len() < 2 {
        // Protos output missing - log warning for user visibility
        eprintln!(
            "WARNING ⚠️ Segmentation model missing protos output (expected 2 outputs, got {}). Returning empty masks.",
            outputs.len()
        );
        return results;
    }

    let (output0, shape0) = &outputs[0];
    let (output1, shape1) = &outputs[1];

    // output0: [1, 4 + nc + 32, 8400]
    // output1: [1, 32, 160, 160] (protos)

    // 1. Process Detections
    // Standard segmentation models use 32 mask prototypes
    let num_masks = 32;
    let expected_features = 4 + names.len() + num_masks;

    // Determine layout by matching the expected feature count
    let (num_preds, is_transposed) = if shape0.len() == 3 {
        let (a, b) = (shape0[1], shape0[2]);
        if a == expected_features {
            (b, false) // [1, features, preds]
        } else if b == expected_features {
            (a, true) // [1, preds, features]
        } else {
            // Assume format [1, 116, 8400] if ambiguous
            if a < b { (b, false) } else { (a, true) }
        }
    } else {
        (0, false)
    };

    if output0.is_empty() || num_preds == 0 {
        return results;
    }

    // Convert to 2D [preds, features]
    let output_2d = if is_transposed {
        Array2::from_shape_vec((num_preds, expected_features), output0.to_vec())
            .unwrap_or_else(|_| Array2::zeros((0, 0)))
    } else {
        let arr = Array2::from_shape_vec((expected_features, num_preds), output0.to_vec())
            .unwrap_or_else(|_| Array2::zeros((0, 0)));
        arr.t().to_owned()
    };

    // Filter and NMS
    let mut candidates = Vec::new(); // (bbox, score, class, original_index)

    for i in 0..num_preds {
        let scores = output_2d.slice(s![i, 4..4 + names.len()]);
        let (best_class, best_score) = scores
            .iter()
            .enumerate()
            .max_by(|&(_, a), &(_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map_or((0, 0.0), |(idx, &score)| (idx, score));

        if best_score < config.confidence_threshold {
            continue;
        }

        // Box
        let cx = output_2d[[i, 0]];
        let cy = output_2d[[i, 1]];
        let w = output_2d[[i, 2]];
        let h = output_2d[[i, 3]];
        let x1 = cx - w / 2.0;
        let y1 = cy - h / 2.0;
        let x2 = cx + w / 2.0;
        let y2 = cy + h / 2.0;

        let scaled = scale_coords(&[x1, y1, x2, y2], preprocess.scale, preprocess.padding);
        let clipped = clip_coords(&scaled, preprocess.orig_shape);

        // Filter by class if specified
        if !config.keep_class(best_class) {
            continue;
        }

        candidates.push((
            [clipped[0], clipped[1], clipped[2], clipped[3]],
            best_score,
            best_class,
            i, // Keep index to get coefficients
        ));
    }

    if candidates.is_empty() {
        return results;
    }

    // Prepare candidates for NMS (bbox, score, class)
    let nms_candidates: Vec<_> = candidates
        .iter()
        .map(|(bbox, score, class, _)| (*bbox, *score, *class))
        .collect();

    let keep_indices = nms_per_class(&nms_candidates, config.iou_threshold);
    let num_kept = keep_indices.len().min(config.max_det);

    // 2. Extract Box Results
    let mut boxes_data = Array2::zeros((num_kept, 6));
    let mut mask_coeffs = Array2::zeros((num_kept, num_masks));

    for (out_idx, &keep_idx) in keep_indices.iter().take(num_kept).enumerate() {
        let (bbox, score, class, orig_idx) = &candidates[keep_idx];
        boxes_data[[out_idx, 0]] = bbox[0];
        boxes_data[[out_idx, 1]] = bbox[1];
        boxes_data[[out_idx, 2]] = bbox[2];
        boxes_data[[out_idx, 3]] = bbox[3];
        boxes_data[[out_idx, 4]] = *score;
        boxes_data[[out_idx, 5]] = *class as f32;

        // Extract coefficients: [orig_idx, 4+nc..]
        let start = 4 + names.len();
        let coeffs = output_2d.slice(s![*orig_idx, start..start + num_masks]);
        for m in 0..num_masks {
            mask_coeffs[[out_idx, m]] = coeffs[m];
        }
    }

    results.boxes = Some(Boxes::new(boxes_data.clone(), preprocess.orig_shape));

    // 3. Process Masks
    // Protos: [1, 32, 160, 160] -> [32, 25600]
    // Validate protos shape before indexing to prevent panic
    if shape1.len() < 4 {
        eprintln!(
            "WARNING ⚠️ Protos output has unexpected shape (expected 4 dims, got {}). Skipping mask generation.",
            shape1.len()
        );
        return results;
    }
    let mh = shape1[2];
    let mw = shape1[3];

    // Validate expected mask dimensions match
    if shape1[1] != num_masks {
        eprintln!(
            "WARNING ⚠️ Protos output has {} mask channels, expected {}. Mask quality may be affected.",
            shape1[1], num_masks
        );
    }

    let protos = match Array2::from_shape_vec((num_masks, mh * mw), output1.to_vec()) {
        Ok(arr) => arr,
        Err(e) => {
            eprintln!("WARNING ⚠️ Failed to create protos array: {e}. Skipping mask generation.");
            return results;
        }
    };

    // Matrix Mul: [N, 32] x [32, 25600] -> [N, 25600]
    let masks_flat = mask_coeffs.dot(&protos);
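    // Each row of `masks_flat` is that detection's linear combination of the
    // 32 prototype maps; the sigmoid applied below turns these logits into a
    // soft mask in [0, 1] at prototype resolution.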

    // Resize and crop to original image size
    let (oh, ow) = preprocess.orig_shape;
    let (th, tw) = inference_shape;
    let (pad_top, pad_left) = preprocess.padding;

    // Pre-calculate crop parameters (same for all masks)
    let scale_w = mw as f32 / tw as f32;
    let scale_h = mh as f32 / th as f32;
    let crop_x = pad_left * scale_w;
    let crop_y = pad_top * scale_h;
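    // mul_add computes mw - 2*crop_x (and mh - 2*crop_y): the letterbox
    // padding is assumed symmetric, so it is trimmed from both sides.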
    let crop_w = 2.0f32.mul_add(-crop_x, mw as f32);
    let crop_h = 2.0f32.mul_add(-crop_y, mh as f32);

    // Initialize output array
    let mut masks_data = Array3::zeros((num_kept, oh as usize, ow as usize));

    // Process each mask in parallel using Rayon and ndarray::Zip.
    // Each thread handles resizing and cropping for one mask.
    //
    // Inputs:
    // - mask_out: Mutable view into output masks array
    // - mask_flat: Mask coefficients for the detection
    // - box_data: Bounding box for the detection
    Zip::from(masks_data.outer_iter_mut())
        .and(masks_flat.outer_iter())
        .and(boxes_data.outer_iter())
        .par_for_each(
            |mut mask_out: ArrayViewMut2<f32>,
             mask_flat: ArrayView1<f32>,
             box_data: ArrayView1<f32>| {
                // Create a local resizer for each task (Resizer is not Sync)
                let mut resizer = Resizer::new();
                let resize_alg = ResizeAlg::Convolution(FilterType::Bilinear);

                // Sigmoid into a Vec<f32>
                let f32_data: Vec<f32> = mask_flat
                    .iter()
                    .map(|&val| 1.0 / (1.0 + (-val).exp()))
                    .collect();

                // Use bytemuck for efficient f32->bytes conversion
                let src_bytes: &[u8] = bytemuck::cast_slice(&f32_data);

                // Create source image (mw x mh, typically 160x160)
                let src_image = match Image::from_vec_u8(
                    mw as u32,
                    mh as u32,
                    src_bytes.to_vec(),
                    PixelType::F32,
                ) {
                    Ok(img) => img,
                    Err(_) => return, // Skip if creation fails
                };

                // Create dest image (orig_w x orig_h)
                let mut dst_image = Image::new(ow, oh, PixelType::F32);

                // Configure resize with crop
                let safe_crop_x = f64::from(crop_x.max(0.0));
                let safe_crop_y = f64::from(crop_y.max(0.0));
                let safe_crop_w = f64::from(crop_w.max(1.0).min(mw as f32));
                let safe_crop_h = f64::from(crop_h.max(1.0).min(mh as f32));

                let options = ResizeOptions::new().resize_alg(resize_alg).crop(
                    safe_crop_x,
                    safe_crop_y,
                    safe_crop_w,
                    safe_crop_h,
                );

                // Handle resize errors gracefully
                if resizer
                    .resize(&src_image, &mut dst_image, &options)
                    .is_err()
                {
                    return;
                }

                // Get resized data as f32 slice
                let dst_bytes = dst_image.buffer();
                let dst_slice: &[f32] = bytemuck::cast_slice(dst_bytes);

                // Apply bbox cropping and store directly to output array
                let x1 = box_data[0].max(0.0).min(ow as f32);
                let y1 = box_data[1].max(0.0).min(oh as f32);
                let x2 = box_data[2].max(0.0).min(ow as f32);
                let y2 = box_data[3].max(0.0).min(oh as f32);

                for y in 0..oh as usize {
                    for x in 0..ow as usize {
                        let val = dst_slice[y * ow as usize + x];
                        let x_f = x as f32;
                        let y_f = y as f32;
                        // Apply bounding box mask: only pixels inside the box are written; everything outside stays zero.
                        if x_f >= x1 && x_f <= x2 && y_f >= y1 && y_f <= y2 {
                            mask_out[[y, x]] = val;
                        }
                    }
                }
            },
        );

    results.masks = Some(Masks::new(masks_data, preprocess.orig_shape));

    results
}

/// Post-process pose estimation model output.
///
/// Extracts bounding boxes and keypoints (skeleton) from the model output.
///
/// # Arguments
///
/// * `output` - Flat vector of model output.
/// * `output_shape` - Output tensor dimensions.
/// * `preprocess` - Preprocessing metadata.
/// * `config` - Inference configuration.
/// * `names` - Class name mapping.
/// * `orig_img` - Original image.
/// * `path` - Source image path.
/// * `speed` - Timing data.
/// * `inference_shape` - Inference input dimensions.
///
/// # Returns
///
/// `Results` struct containing boxes and keypoints.
#[allow(
    clippy::too_many_arguments,
    clippy::too_many_lines,
    clippy::similar_names,
    clippy::type_complexity,
    clippy::cast_precision_loss,
    clippy::doc_lazy_continuation
)]
fn postprocess_pose(
    output: &[f32],
    output_shape: &[usize],
    preprocess: &PreprocessResult,
    config: &InferenceConfig,
    names: &HashMap<usize, String>,
    orig_img: Array3<u8>,
    path: String,
    speed: Speed,
    inference_shape: (u32, u32),
) -> Results {
    let mut results = Results::new(orig_img, path, names.clone(), speed, inference_shape);

    // Standard COCO pose has 17 keypoints, each with (x, y, conf)
    let num_keypoints = 17;
    let kpt_dim = 3; // x, y, visibility/confidence
    let kpt_features = num_keypoints * kpt_dim; // 51

    // Pose typically has 1 class (person), so features = 4 + 1 + 51 = 56
    let num_classes = names.len().max(1);
    let expected_features = 4 + num_classes + kpt_features;

    // Parse output shape
    let (num_preds, is_transposed) = if output_shape.len() == 3 {
        let (a, b) = (output_shape[1], output_shape[2]);
        if a == expected_features || (a < b && a >= 4 + kpt_features) {
            (b, false) // [1, features, preds]
        } else {
            (a, true) // [1, preds, features]
        }
    } else if output_shape.len() == 2 {
        let (a, b) = (output_shape[0], output_shape[1]);
        if a < b { (b, false) } else { (a, true) }
    } else {
        (0, false)
    };

    if output.is_empty() || num_preds == 0 {
        return results;
    }

    // Infer actual feature count from data
    let actual_features = output.len() / num_preds;
    if actual_features < 4 + kpt_features {
        eprintln!(
            "WARNING ⚠️ Pose model has insufficient features ({actual_features}), expected at least {}",
            4 + kpt_features
        );
        return results;
    }

    // Convert to 2D [preds, features]
    let output_2d = if is_transposed {
        Array2::from_shape_vec((num_preds, actual_features), output.to_vec())
            .unwrap_or_else(|_| Array2::zeros((0, 0)))
    } else {
        let arr = Array2::from_shape_vec((actual_features, num_preds), output.to_vec())
            .unwrap_or_else(|_| Array2::zeros((0, 0)));
        arr.t().to_owned()
    };

    if output_2d.is_empty() {
        return results;
    }

    // Derive number of classes from actual features
    let derived_classes = actual_features.saturating_sub(4 + kpt_features);
    let num_classes = derived_classes.max(1);

    // Filter and NMS - store candidates with keypoints
    let mut candidates: Vec<([f32; 4], f32, usize, Vec<[f32; 3]>)> = Vec::new();

    for i in 0..num_preds {
        // Get class score(s) - for pose, typically just "person" class
        let class_scores = output_2d.slice(s![i, 4..4 + num_classes]);
        let (best_class, best_score) = class_scores
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Less))
            .map_or((0, 0.0), |(idx, &score)| {
                (idx, if score.is_nan() { 0.0 } else { score })
            });

        if best_score < config.confidence_threshold {
            continue;
        }

        // Extract box coordinates (xywh format)
        let cx = output_2d[[i, 0]];
        let cy = output_2d[[i, 1]];
        let w = output_2d[[i, 2]];
        let h = output_2d[[i, 3]];

        // Convert to xyxy
        let x1 = cx - w / 2.0;
        let y1 = cy - h / 2.0;
        let x2 = cx + w / 2.0;
        let y2 = cy + h / 2.0;

        // Scale box to original image space
        let scaled = scale_coords(&[x1, y1, x2, y2], preprocess.scale, preprocess.padding);
        let clipped = clip_coords(&scaled, preprocess.orig_shape);

        // Extract keypoints (after class scores)
        let kpt_start = 4 + num_classes;
        let mut keypoints = Vec::with_capacity(num_keypoints);
        for k in 0..num_keypoints {
            let kpt_offset = kpt_start + k * kpt_dim;
            let kpt_x = output_2d[[i, kpt_offset]];
            let kpt_y = output_2d[[i, kpt_offset + 1]];
            let kpt_conf = output_2d[[i, kpt_offset + 2]];

            // Scale keypoint coordinates to original image space
            let scaled_kpt = scale_coords(
                &[kpt_x, kpt_y, kpt_x, kpt_y],
                preprocess.scale,
                preprocess.padding,
            );
            let (oh, ow) = preprocess.orig_shape;
            #[allow(clippy::cast_precision_loss)]
            let scaled_x = scaled_kpt[0].max(0.0).min(ow as f32);
            #[allow(clippy::cast_precision_loss)]
            let scaled_y = scaled_kpt[1].max(0.0).min(oh as f32);

            keypoints.push([scaled_x, scaled_y, kpt_conf]);
        }

        // Filter by class if specified
        if !config.keep_class(best_class) {
            continue;
        }

        candidates.push((
            [clipped[0], clipped[1], clipped[2], clipped[3]],
            best_score,
            best_class,
            keypoints,
        ));
    }

    if candidates.is_empty() {
        results.keypoints = Some(Keypoints::new(
            Array3::zeros((0, num_keypoints, kpt_dim)),
            preprocess.orig_shape,
        ));
        return results;
    }

    // Apply NMS
    let nms_candidates: Vec<_> = candidates
        .iter()
        .map(|(bbox, score, class, _)| (*bbox, *score, *class))
        .collect();
    let keep_indices = nms_per_class(&nms_candidates, config.iou_threshold);
    let num_kept = keep_indices.len().min(config.max_det);

    // Build output arrays
    let mut boxes_data = Array2::zeros((num_kept, 6));
    let mut keypoints_data = Array3::zeros((num_kept, num_keypoints, kpt_dim));

    for (out_idx, &keep_idx) in keep_indices.iter().take(num_kept).enumerate() {
        let (bbox, score, class, kpts) = &candidates[keep_idx];

        // Store box data
        boxes_data[[out_idx, 0]] = bbox[0];
        boxes_data[[out_idx, 1]] = bbox[1];
        boxes_data[[out_idx, 2]] = bbox[2];
        boxes_data[[out_idx, 3]] = bbox[3];
        boxes_data[[out_idx, 4]] = *score;
        #[allow(clippy::cast_precision_loss)]
        let class_f32 = *class as f32;
        boxes_data[[out_idx, 5]] = class_f32;

        // Store keypoints
        for (k, kpt) in kpts.iter().enumerate() {
            keypoints_data[[out_idx, k, 0]] = kpt[0]; // x
            keypoints_data[[out_idx, k, 1]] = kpt[1]; // y
            keypoints_data[[out_idx, k, 2]] = kpt[2]; // confidence
        }
    }

    results.boxes = Some(Boxes::new(boxes_data, preprocess.orig_shape));
    results.keypoints = Some(Keypoints::new(keypoints_data, preprocess.orig_shape));

    results
}

/// Post-process classification model output.
///
/// Computes best class predictions and probabilities.
///
/// # Arguments
///
/// * `output` - Raw model output vector.
/// * `names` - Class name mapping.
/// * `orig_img` - Original image.
/// * `path` - Source path.
/// * `speed` - Timing metrics.
/// * `inference_shape` - Inference dimensions.
///
/// # Returns
///
/// `Results` struct containing classification probabilities.
fn postprocess_classify(
    output: &[f32],
    names: &HashMap<usize, String>,
    orig_img: Array3<u8>,
    path: String,
    speed: Speed,
    inference_shape: (u32, u32),
) -> Results {
    let mut results = Results::new(orig_img, path, names.clone(), speed, inference_shape);

    if output.is_empty() {
        return results;
    }

    // Probs::new expects an Array1, which we can create from the slice
    let mut probs_vec = output.to_vec();

    // Check if softmax is already applied (sum ≈ 1.0)
    let sum: f32 = probs_vec.iter().sum();
    if (sum - 1.0).abs() > 0.1 && sum > 0.0 {
        // Apply softmax normalization
        let max_val = probs_vec.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exp_vals: Vec<f32> = probs_vec.iter().map(|&v| (v - max_val).exp()).collect();
        let exp_sum: f32 = exp_vals.iter().sum();
        if exp_sum > 0.0 {
            probs_vec = exp_vals.iter().map(|&v| v / exp_sum).collect();
        }
    }
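    // Worked example (illustrative): logits [1.0, 2.0, 3.0] sum to 6.0, so the
    // softmax branch runs; exp(v - 3.0) gives [0.135, 0.368, 1.0], which
    // normalizes to roughly [0.090, 0.245, 0.665].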

    let probs = ndarray::Array1::from_vec(probs_vec);
    results.probs = Some(Probs::new(probs));

    results
}

/// Post-process OBB (oriented bounding box) model output.
///
/// Extracts oriented bounding boxes with rotation angle.
///
/// # Arguments
///
/// * `output` - Model output data.
/// * `output_shape` - Output tensor shape.
/// * `preprocess` - Preprocessing metadata.
/// * `config` - Inference configuration.
/// * `names` - Class name mapping.
/// * `orig_img` - Original image.
/// * `path` - Source path.
/// * `speed` - Timing metrics.
/// * `inference_shape` - Inference dimensions.
///
/// # Returns
///
/// `Results` struct containing oriented bounding boxes.
#[allow(
    clippy::too_many_arguments,
    clippy::too_many_lines,
    clippy::similar_names
)]
fn postprocess_obb(
    output: &[f32],
    output_shape: &[usize],
    preprocess: &PreprocessResult,
    config: &InferenceConfig,
    names: &HashMap<usize, String>,
    orig_img: Array3<u8>,
    path: String,
    speed: Speed,
    inference_shape: (u32, u32),
) -> Results {
    let mut results = Results::new(orig_img, path, names.clone(), speed, inference_shape);

    // OBB format: [xywh, class_scores..., rotation_angle]
    // features = 4 (bbox) + num_classes + 1 (angle)
    let num_classes = names.len().max(1);
    let expected_features = 4 + num_classes + 1;

    // Parse output shape
    let (num_preds, is_transposed) = if output_shape.len() == 3 {
        let (a, b) = (output_shape[1], output_shape[2]);
        if a == expected_features || (a < b && a >= 6) {
            (b, false) // [1, features, preds]
        } else {
            (a, true) // [1, preds, features]
        }
    } else if output_shape.len() == 2 {
        let (a, b) = (output_shape[0], output_shape[1]);
        if a < b { (b, false) } else { (a, true) }
    } else {
        (0, false)
    };
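    // Illustrative parse (example shapes assumed): a DOTAv1-style export with
    // 15 classes at 1024x1024 input reports [1, 20, 21504] (20 = 4 + 15 + 1),
    // matching expected_features, so it parses as 21504 predictions in
    // [1, features, preds] layout.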

    if output.is_empty() || num_preds == 0 {
        return results;
    }

    // Infer actual feature count from data
    let actual_features = output.len() / num_preds;
    if actual_features < 6 {
        eprintln!(
            "WARNING ⚠️ OBB model has insufficient features ({actual_features}), expected at least 6"
        );
        return results;
    }

    // Convert to 2D [preds, features]
    let output_2d = if is_transposed {
        Array2::from_shape_vec((num_preds, actual_features), output.to_vec())
            .unwrap_or_else(|_| Array2::zeros((0, 0)))
    } else {
        let arr = Array2::from_shape_vec((actual_features, num_preds), output.to_vec())
            .unwrap_or_else(|_| Array2::zeros((0, 0)));
        arr.t().to_owned()
    };

    if output_2d.is_empty() {
        return results;
    }

    // Derive number of classes from features: features = 4 + nc + 1
    let derived_classes = actual_features.saturating_sub(5); // 4 bbox + 1 angle
    let num_classes = derived_classes.max(1);

    // Filter and NMS - store candidates with angle
    let mut candidates: Vec<([f32; 5], f32, usize)> = Vec::new(); // [cx, cy, w, h, angle], conf, class

    for i in 0..num_preds {
        // Get class scores
        let class_scores = output_2d.slice(s![i, 4..4 + num_classes]);
        let (best_class, best_score) = class_scores
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Less))
            .map_or((0, 0.0), |(idx, &score)| {
                (idx, if score.is_nan() { 0.0 } else { score })
            });

        if best_score < config.confidence_threshold {
            continue;
        }

        // Extract OBB: xywh + rotation
        let cx = output_2d[[i, 0]];
        let cy = output_2d[[i, 1]];
        let w = output_2d[[i, 2]];
        let h = output_2d[[i, 3]];
        let angle = output_2d[[i, 4 + num_classes]]; // Last value is rotation angle (radians)

        // Scale center coordinates to original image space
        let scaled = scale_coords(&[cx, cy, cx, cy], preprocess.scale, preprocess.padding);
        let scaled_cx = scaled[0];
        let scaled_cy = scaled[1];

        // Scale width and height (note: don't apply padding, just scale)
        let scaled_w = w / preprocess.scale.1;
        let scaled_h = h / preprocess.scale.0;

        // Clip center to image bounds
        let (oh, ow) = preprocess.orig_shape;
        #[allow(clippy::cast_precision_loss)]
        let clipped_cx = scaled_cx.max(0.0).min(ow as f32);
        #[allow(clippy::cast_precision_loss)]
        let clipped_cy = scaled_cy.max(0.0).min(oh as f32);

        // Filter by class if specified
        if !config.keep_class(best_class) {
            continue;
        }

        candidates.push((
            [clipped_cx, clipped_cy, scaled_w, scaled_h, angle],
            best_score,
            best_class,
        ));
    }

    if candidates.is_empty() {
        results.obb = Some(Obb::new(Array2::zeros((0, 7)), preprocess.orig_shape));
        return results;
    }

    // Apply Rotated NMS for precise suppression using ProbIoU (Hellinger distance).
    // This ensures that overlapping rotated boxes are correctly filtered based on their
    // actual geometric overlap, which standard axis-aligned IoU cannot handle.
    let keep_indices = nms_rotated_per_class(&candidates, config.iou_threshold);
    let num_kept = keep_indices.len().min(config.max_det);

    // Build output array: [cx, cy, w, h, rotation, conf, cls]
    let mut obb_data = Array2::zeros((num_kept, 7));

    for (out_idx, &keep_idx) in keep_indices.iter().take(num_kept).enumerate() {
        let (xywhr, score, class) = &candidates[keep_idx];
        obb_data[[out_idx, 0]] = xywhr[0]; // cx
        obb_data[[out_idx, 1]] = xywhr[1]; // cy
        obb_data[[out_idx, 2]] = xywhr[2]; // w
        obb_data[[out_idx, 3]] = xywhr[3]; // h
        obb_data[[out_idx, 4]] = xywhr[4]; // rotation (radians)
        obb_data[[out_idx, 5]] = *score;
        #[allow(clippy::cast_precision_loss)]
        let class_f32 = *class as f32;
        obb_data[[out_idx, 6]] = class_f32;
    }

    results.obb = Some(Obb::new(obb_data, preprocess.orig_shape));

    results
}

// -------- YOLO26 end-to-end (NMS-free) postprocessing --------
//
// End-to-end exports bake NMS into the graph and produce a tensor of shape
// `[B, max_det, 6 + extra]`, where the first 6 columns are always
// `[x1, y1, x2, y2, conf, cls]` (OBB is the exception; see `postprocess_obb_end2end`).
// Coordinates are in the letterboxed model-input space, so we still scale/pad-correct
// them back to original-image space. No NMS, no class-score matrix.
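//
// Illustrative row (values assumed): with `max_det = 300` a detect export
// yields 300 rows of `[x1, y1, x2, y2, conf, cls]` in letterboxed input
// pixels, sorted by descending `conf`, so the decoders below stop at the
// first row whose confidence drops below the threshold.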
1346
1347/// Helper: scale a model-space xyxy box to original image coordinates.
1348#[inline]
1349fn scale_xyxy(
1350    x1: f32,
1351    y1: f32,
1352    x2: f32,
1353    y2: f32,
1354    preprocess: &PreprocessResult,
1355) -> (f32, f32, f32, f32) {
1356    let (scale_y, scale_x) = preprocess.scale;
1357    let (pad_top, pad_left) = preprocess.padding;
1358    (
1359        (x1 - pad_left) / scale_x,
1360        (y1 - pad_top) / scale_y,
1361        (x2 - pad_left) / scale_x,
1362        (y2 - pad_top) / scale_y,
1363    )
1364}
1365
1366/// Post-process YOLO26 end-to-end detection output `[1, max_det, 6]`.
1367#[allow(clippy::too_many_arguments, clippy::cast_precision_loss)]
1368fn postprocess_detect_end2end(
1369    output: &[f32],
1370    output_shape: &[usize],
1371    preprocess: &PreprocessResult,
1372    config: &InferenceConfig,
1373    names: &HashMap<usize, String>,
1374    orig_img: Array3<u8>,
1375    path: String,
1376    speed: Speed,
1377    inference_shape: (u32, u32),
1378) -> Results {
1379    let mut results = Results::new(orig_img, path, names.clone(), speed, inference_shape);
1380
1381    if output_shape.len() != 3 || output.is_empty() {
1382        return results;
1383    }
1384    let max_det = output_shape[1];
1385    let feats = output_shape[2];
1386    if feats < 6 || max_det == 0 {
1387        return results;
1388    }
1389
1390    let (oh, ow) = preprocess.orig_shape;
1391    let (max_w, max_h) = (ow as f32, oh as f32);
1392    let user_cap = config.max_det.min(max_det);
1393
1394    let mut flat: Vec<f32> = Vec::with_capacity(user_cap * 6);
1395    for i in 0..max_det {
1396        let base = i * feats;
1397        let conf = output[base + 4];
1398        // End2end outputs are sorted by confidence descending; stop on first below threshold.
1399        if conf < config.confidence_threshold {
1400            break;
1401        }
1402        let cls = output[base + 5] as usize;
1403        if !config.keep_class(cls) {
1404            continue;
1405        }
1406        let (x1, y1, x2, y2) = scale_xyxy(
1407            output[base],
1408            output[base + 1],
1409            output[base + 2],
1410            output[base + 3],
1411            preprocess,
1412        );
1413        flat.extend_from_slice(&[
1414            x1.clamp(0.0, max_w),
1415            y1.clamp(0.0, max_h),
1416            x2.clamp(0.0, max_w),
1417            y2.clamp(0.0, max_h),
1418            conf,
1419            cls as f32,
1420        ]);
1421        if flat.len() >= user_cap * 6 {
1422            break;
1423        }
1424    }
1425
1426    let n = flat.len() / 6;
1427    if n > 0 {
1428        let boxes_data = Array2::from_shape_vec((n, 6), flat).expect("flat length matches (n, 6)");
1429        results.boxes = Some(Boxes::new(boxes_data, preprocess.orig_shape));
1430    }
1431    results
1432}
1433
1434/// Post-process YOLO26 end-to-end segmentation output.
1435///
1436/// `output0`: `[1, max_det, 6 + nm]`, `output1`: `[1, nm, mh, mw]` (protos).
1437#[allow(
1438    clippy::too_many_arguments,
1439    clippy::cast_precision_loss,
1440    clippy::too_many_lines,
1441    clippy::needless_pass_by_value,
1442    clippy::similar_names,
1443    clippy::manual_let_else
1444)]
1445fn postprocess_segment_end2end(
1446    outputs: Vec<(&[f32], Vec<usize>)>,
1447    preprocess: &PreprocessResult,
1448    config: &InferenceConfig,
1449    names: &HashMap<usize, String>,
1450    orig_img: Array3<u8>,
1451    path: String,
1452    speed: Speed,
1453    inference_shape: (u32, u32),
1454) -> Results {
1455    let mut results = Results::new(orig_img, path, names.clone(), speed, inference_shape);
1456    if outputs.len() < 2 {
1457        eprintln!(
1458            "WARNING ⚠️ End2end segmentation missing protos output (got {} outputs).",
1459            outputs.len()
1460        );
1461        return results;
1462    }
1463    let (output0, shape0) = &outputs[0];
1464    let (output1, shape1) = &outputs[1];
1465
1466    if shape0.len() != 3 || shape1.len() != 4 {
1467        return results;
1468    }
1469    let max_det = shape0[1];
1470    let feats = shape0[2];
1471    let num_masks = shape1[1];
1472    if feats < 6 + num_masks {
1473        eprintln!("WARNING ⚠️ End2end segment features ({feats}) < 6 + num_masks ({num_masks}).");
1474        return results;
1475    }
1476
1477    let (oh, ow) = preprocess.orig_shape;
1478    let (max_w, max_h) = (ow as f32, oh as f32);
1479    let user_cap = config.max_det.min(max_det);
1480
1481    let mut flat_boxes: Vec<f32> = Vec::with_capacity(user_cap * 6);
1482    let mut flat_coeffs: Vec<f32> = Vec::with_capacity(user_cap * num_masks);
1483
1484    for i in 0..max_det {
1485        let base = i * feats;
1486        let conf = output0[base + 4];
1487        if conf < config.confidence_threshold {
1488            break;
1489        }
1490        let cls = output0[base + 5] as usize;
1491        if !config.keep_class(cls) {
1492            continue;
1493        }
1494        let (x1, y1, x2, y2) = scale_xyxy(
1495            output0[base],
1496            output0[base + 1],
1497            output0[base + 2],
1498            output0[base + 3],
1499            preprocess,
1500        );
1501        flat_boxes.extend_from_slice(&[
1502            x1.clamp(0.0, max_w),
1503            y1.clamp(0.0, max_h),
1504            x2.clamp(0.0, max_w),
1505            y2.clamp(0.0, max_h),
1506            conf,
1507            cls as f32,
1508        ]);
1509        let coeff_start = base + 6;
1510        flat_coeffs.extend_from_slice(&output0[coeff_start..coeff_start + num_masks]);
1511        if flat_boxes.len() >= user_cap * 6 {
1512            break;
1513        }
1514    }
1515
1516    let num_kept = flat_boxes.len() / 6;
1517    if num_kept == 0 {
1518        return results;
1519    }
1520
1521    let boxes_data =
1522        Array2::from_shape_vec((num_kept, 6), flat_boxes).expect("flat length matches (n, 6)");
1523    let mask_coeffs = Array2::from_shape_vec((num_kept, num_masks), flat_coeffs)
1524        .expect("flat length matches (n, num_masks)");
1525
1526    // Protos -> masks (mirrors standard segment path).
1527    let mh = shape1[2];
1528    let mw = shape1[3];
1529    let protos = match Array2::from_shape_vec((num_masks, mh * mw), output1.to_vec()) {
1530        Ok(a) => a,
1531        Err(e) => {
1532            eprintln!("WARNING ⚠️ Failed to build protos array: {e}. Skipping masks.");
1533            return results;
1534        }
1535    };
1536    let masks_flat = mask_coeffs.dot(&protos);
1537
1538    let (th, tw) = inference_shape;
1539    let (pad_top, pad_left) = preprocess.padding;
1540    let scale_w = mw as f32 / tw as f32;
1541    let scale_h = mh as f32 / th as f32;
1542    let crop_x = pad_left * scale_w;
1543    let crop_y = pad_top * scale_h;
1544    let crop_w = 2.0f32.mul_add(-crop_x, mw as f32);
1545    let crop_h = 2.0f32.mul_add(-crop_y, mh as f32);
1546
1547    let mut masks_data = Array3::zeros((num_kept, oh as usize, ow as usize));
1548    Zip::from(masks_data.outer_iter_mut())
1549        .and(masks_flat.outer_iter())
1550        .and(boxes_data.outer_iter())
1551        .par_for_each(
1552            |mut mask_out: ArrayViewMut2<f32>,
1553             mask_flat: ArrayView1<f32>,
1554             box_data: ArrayView1<f32>| {
1555                let mut resizer = Resizer::new();
1556                let resize_alg = ResizeAlg::Convolution(FilterType::Bilinear);
1557                let f32_data: Vec<f32> = mask_flat
1558                    .iter()
1559                    .map(|&v| 1.0 / (1.0 + (-v).exp())) // sigmoid on raw mask logits
1560                    .collect();
1561                let src_bytes: &[u8] = bytemuck::cast_slice(&f32_data);
1562                let src_image = match Image::from_vec_u8(
1563                    mw as u32,
1564                    mh as u32,
1565                    src_bytes.to_vec(),
1566                    PixelType::F32,
1567                ) {
1568                    Ok(i) => i,
1569                    Err(_) => return, // skip this mask if the buffer is rejected
1570                };
1571                let mut dst_image = Image::new(ow, oh, PixelType::F32);
1572                let options = ResizeOptions::new().resize_alg(resize_alg).crop(
1573                    f64::from(crop_x.max(0.0)),
1574                    f64::from(crop_y.max(0.0)),
1575                    f64::from(crop_w.max(1.0).min(mw as f32)),
1576                    f64::from(crop_h.max(1.0).min(mh as f32)),
1577                ); // crop removes letterbox padding before upsampling to original size
1578                if resizer
1579                    .resize(&src_image, &mut dst_image, &options)
1580                    .is_err()
1581                {
1582                    return;
1583                }
1584                let dst_bytes = dst_image.buffer();
1585                let dst_slice: &[f32] = bytemuck::cast_slice(dst_bytes);
1586                let x1 = box_data[0].max(0.0).min(ow as f32);
1587                let y1 = box_data[1].max(0.0).min(oh as f32);
1588                let x2 = box_data[2].max(0.0).min(ow as f32);
1589                let y2 = box_data[3].max(0.0).min(oh as f32);
1590                for y in 0..oh as usize {
1591                    for x in 0..ow as usize {
1592                        let val = dst_slice[y * ow as usize + x];
1593                        let xf = x as f32;
1594                        let yf = y as f32;
1595                        if xf >= x1 && xf <= x2 && yf >= y1 && yf <= y2 { // keep mask only inside its box
1596                            mask_out[[y, x]] = val;
1597                        }
1598                    }
1599                }
1600            },
1601        );
1602
1603    results.boxes = Some(Boxes::new(boxes_data, preprocess.orig_shape));
1604    results.masks = Some(Masks::new(masks_data, preprocess.orig_shape));
1605    results
1606}
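
// Illustrative only: the two core pieces of arithmetic in the segment decode above,
// shown standalone with made-up values. First, the mask-coefficient x prototype
// matmul, (n, nm) · (nm, mh*mw) -> (n, mh*mw), followed by a sigmoid. Second, the
// crop rectangle that removes letterbox padding from the prototype grid before
// upsampling to the original image size.
#[cfg(test)]
mod segment_end2end_example {
    use ndarray::Array2;

    #[test]
    fn coeffs_dot_protos_yields_one_flat_mask_per_detection() {
        let coeffs: Array2<f32> = Array2::from_elem((2, 4), 0.5); // n=2, nm=4
        let protos: Array2<f32> = Array2::from_elem((4, 16), 0.25); // nm=4, mh*mw=16
        let logits = coeffs.dot(&protos); // (2, 16): one flat mask per detection
        assert_eq!(logits.shape(), &[2, 16]);
        let probs = logits.mapv(|v| 1.0 / (1.0 + (-v).exp())); // sigmoid
        assert!(probs.iter().all(|&p| p > 0.0 && p < 1.0));
    }

    #[test]
    fn crop_rectangle_removes_letterbox_padding() {
        // 640x640 inference tensor with 16 px of left/right padding, 160x160 protos.
        let (tw, mw) = (640.0f32, 160.0f32);
        let pad_left = 16.0f32;
        let scale_w = mw / tw; // 0.25: tensor px -> proto px
        let crop_x = pad_left * scale_w; // 4 proto px of padding per side
        let crop_w = 2.0f32.mul_add(-crop_x, mw); // mw - 2*crop_x = 152
        assert!((crop_x - 4.0).abs() < f32::EPSILON);
        assert!((crop_w - 152.0).abs() < f32::EPSILON);
    }
}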
1607
1608/// Post-process YOLO26 end-to-end pose output `[1, max_det, 6 + nk*kpt_dim]`.
1609#[allow(
1610    clippy::too_many_arguments,
1611    clippy::cast_precision_loss,
1612    clippy::similar_names
1613)]
1614fn postprocess_pose_end2end(
1615    output: &[f32],
1616    output_shape: &[usize],
1617    preprocess: &PreprocessResult,
1618    config: &InferenceConfig,
1619    names: &HashMap<usize, String>,
1620    orig_img: Array3<u8>,
1621    path: String,
1622    speed: Speed,
1623    inference_shape: (u32, u32),
1624    nk: usize,
1625    kpt_dim: usize,
1626) -> Results {
1627    let mut results = Results::new(orig_img, path, names.clone(), speed, inference_shape);
1628    if output_shape.len() != 3 || output.is_empty() || nk == 0 || kpt_dim < 2 {
1629        return results;
1630    }
1631    let max_det = output_shape[1];
1632    let feats = output_shape[2];
1633    if feats < 6 + nk * kpt_dim || max_det == 0 {
1634        return results;
1635    }
1636
1637    let (oh, ow) = preprocess.orig_shape;
1638    let (max_w, max_h) = (ow as f32, oh as f32);
1639    let (scale_y, scale_x) = preprocess.scale;
1640    let (pad_top, pad_left) = preprocess.padding;
1641    let user_cap = config.max_det.min(max_det);
1642
1643    let mut flat_boxes: Vec<f32> = Vec::with_capacity(user_cap * 6);
1644    let mut flat_kpts: Vec<f32> = Vec::with_capacity(user_cap * nk * 3);
1645
1646    for i in 0..max_det {
1647        let base = i * feats;
1648        let conf = output[base + 4];
1649        if conf < config.confidence_threshold {
1650            break;
1651        }
1652        let cls = output[base + 5] as usize;
1653        if !config.keep_class(cls) {
1654            continue;
1655        }
1656        let (x1, y1, x2, y2) = scale_xyxy(
1657            output[base],
1658            output[base + 1],
1659            output[base + 2],
1660            output[base + 3],
1661            preprocess,
1662        );
1663        flat_boxes.extend_from_slice(&[
1664            x1.clamp(0.0, max_w),
1665            y1.clamp(0.0, max_h),
1666            x2.clamp(0.0, max_w),
1667            y2.clamp(0.0, max_h),
1668            conf,
1669            cls as f32,
1670        ]);
1671        let kstart = base + 6;
1672        for k in 0..nk {
1673            let off = kstart + k * kpt_dim;
1674            let sx = (output[off] - pad_left) / scale_x;
1675            let sy = (output[off + 1] - pad_top) / scale_y;
1676            let kconf = if kpt_dim >= 3 { output[off + 2] } else { 1.0 };
1677            flat_kpts.extend_from_slice(&[sx.clamp(0.0, max_w), sy.clamp(0.0, max_h), kconf]);
1678        }
1679        if flat_boxes.len() >= user_cap * 6 {
1680            break;
1681        }
1682    }
1683
1684    let n = flat_boxes.len() / 6;
1685    // Always emit a keypoints tensor (even empty) to match the non-end2end pose path.
1686    let kdata =
1687        Array3::from_shape_vec((n, nk, 3), flat_kpts).expect("flat length matches (n, nk, 3)");
1688    results.keypoints = Some(Keypoints::new(kdata, preprocess.orig_shape));
1689    if n > 0 {
1690        let boxes_data =
1691            Array2::from_shape_vec((n, 6), flat_boxes).expect("flat length matches (n, 6)");
1692        results.boxes = Some(Boxes::new(boxes_data, preprocess.orig_shape));
1693    }
1694    results
1695}
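
// Illustrative only: the inverse-letterbox mapping used for keypoints above,
// (k - pad) / scale, with concrete numbers. A 1280x720 frame letterboxed into
// 640x640 gets scale 0.5 and 140 px of top padding, so the tensor-space point
// (320, 320) maps back to the centre of the original frame.
#[cfg(test)]
mod pose_end2end_example {
    #[test]
    #[allow(clippy::float_cmp)]
    fn keypoint_maps_back_to_original_frame() {
        let (scale_x, scale_y) = (0.5f32, 0.5f32); // 640 / 1280
        let (pad_left, pad_top) = (0.0f32, 140.0f32); // (640 - 720 * 0.5) / 2
        let (kx, ky) = (320.0f32, 320.0f32); // tensor-space keypoint
        let sx = (kx - pad_left) / scale_x;
        let sy = (ky - pad_top) / scale_y;
        assert_eq!(sx, 640.0); // centre of the 1280 px width
        assert_eq!(sy, 360.0); // centre of the 720 px height
    }
}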
1696
1697/// Post-process YOLO26 end-to-end OBB output `[1, max_det, 7]`
1698/// with layout `[cx, cy, w, h, conf, cls, angle]`.
1699#[allow(clippy::too_many_arguments, clippy::cast_precision_loss)]
1700fn postprocess_obb_end2end(
1701    output: &[f32],
1702    output_shape: &[usize],
1703    preprocess: &PreprocessResult,
1704    config: &InferenceConfig,
1705    names: &HashMap<usize, String>,
1706    orig_img: Array3<u8>,
1707    path: String,
1708    speed: Speed,
1709    inference_shape: (u32, u32),
1710) -> Results {
1711    let mut results = Results::new(orig_img, path, names.clone(), speed, inference_shape);
1712    let mut flat: Vec<f32> = Vec::new();
1713
1714    if output_shape.len() == 3 && !output.is_empty() {
1715        let max_det = output_shape[1];
1716        let feats = output_shape[2];
1717        if feats >= 7 && max_det > 0 {
1718            let (oh, ow) = preprocess.orig_shape;
1719            let (max_w, max_h) = (ow as f32, oh as f32);
1720            let (scale_y, scale_x) = preprocess.scale;
1721            let (pad_top, pad_left) = preprocess.padding;
1722            let user_cap = config.max_det.min(max_det);
1723            flat.reserve(user_cap * 7);
1724
1725            for i in 0..max_det {
1726                let base = i * feats;
1727                let conf = output[base + 4];
1728                if conf < config.confidence_threshold {
1729                    break;
1730                }
1731                let cls = output[base + 5] as usize;
1732                if !config.keep_class(cls) {
1733                    continue;
1734                }
1735                let cx = (output[base] - pad_left) / scale_x;
1736                let cy = (output[base + 1] - pad_top) / scale_y;
1737                flat.extend_from_slice(&[
1738                    cx.clamp(0.0, max_w),
1739                    cy.clamp(0.0, max_h),
1740                    output[base + 2] / scale_x,
1741                    output[base + 3] / scale_y,
1742                    output[base + 6],
1743                    conf,
1744                    cls as f32,
1745                ]);
1746                if flat.len() >= user_cap * 7 {
1747                    break;
1748                }
1749            }
1750        }
1751    }
1752
1753    let n = flat.len() / 7;
1754    let obb_data = Array2::from_shape_vec((n, 7), flat).expect("flat length matches (n, 7)");
1755    results.obb = Some(Obb::new(obb_data, preprocess.orig_shape));
1756    results
1757}
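
// Illustrative only: a direct call to the OBB decoder above with a single fabricated
// row, assuming the default `InferenceConfig` confidence threshold is below 0.9.
// Note the reordering: the raw row is [cx, cy, w, h, conf, cls, angle], while the
// stored layout is [cx, cy, w, h, angle, conf, cls].
#[cfg(test)]
mod obb_end2end_example {
    use super::*;

    #[test]
    fn decodes_one_row_and_reorders_angle() {
        let output = vec![100.0, 100.0, 50.0, 20.0, 0.9, 0.0, std::f32::consts::FRAC_PI_4];
        let preprocess = PreprocessResult {
            tensor: ndarray::Array4::zeros((1, 3, 640, 640)),
            tensor_f16: None,
            orig_shape: (640, 640),
            scale: (1.0, 1.0),
            padding: (0.0, 0.0),
        };
        let results = postprocess_obb_end2end(
            &output,
            &[1, 1, 7],
            &preprocess,
            &InferenceConfig::default(),
            &HashMap::new(),
            ndarray::Array3::zeros((640, 640, 3)),
            String::new(),
            Speed::default(),
            (640, 640),
        );
        let obb = results.obb.expect("obb is always set, even when empty");
        assert_eq!(obb.len(), 1);
        #[allow(clippy::float_cmp)]
        {
            let row = obb.data.row(0);
            assert_eq!(row[4], std::f32::consts::FRAC_PI_4); // angle now at index 4
            assert_eq!(row[5], 0.9); // confidence follows the angle
        }
    }
}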
1758
1759#[cfg(test)]
1760mod tests {
1761    use super::*;
1762
1763    #[test]
1764    fn test_parse_detect_shape() {
1765        // Standard YOLO output [1, 84, 8400]
1766        let (nc, np, transposed) = parse_detect_shape(&[1, 84, 8400], 80);
1767        assert_eq!(nc, 80);
1768        assert_eq!(np, 8400);
1769        assert!(!transposed);
1770
1771        // Transposed format [1, 8400, 84]
1772        let (nc, np, transposed) = parse_detect_shape(&[1, 8400, 84], 80);
1773        assert_eq!(nc, 80);
1774        assert_eq!(np, 8400);
1775        assert!(transposed);
1776    }
1777
1778    #[test]
1779    fn test_infer_end2end_kpt_shape() {
1780        // COCO pose: 17 kpts × 3 dims -> kpt_feats=51 (divisible by 3 only) -> (17, 3)
1781        assert_eq!(infer_end2end_kpt_shape(&[1, 300, 6 + 51]), Some((17, 3)));
1782        // Pure 2D pose: 17 kpts × 2 dims -> kpt_feats=34 (divisible by 2 only) -> (17, 2)
1783        assert_eq!(infer_end2end_kpt_shape(&[1, 300, 6 + 34]), Some((17, 2)));
1784        // Ambiguous (divisible by 6): 12 × 3 vs 18 × 2 both decode to 36 -> None
1785        assert_eq!(infer_end2end_kpt_shape(&[1, 300, 6 + 36]), None);
1786        // Shape that isn't the end-to-end layout -> None
1787        assert_eq!(infer_end2end_kpt_shape(&[1, 56, 8400]), None);
1788        // shape[2] <= 6 -> None (no keypoint features)
1789        assert_eq!(infer_end2end_kpt_shape(&[1, 300, 6]), None);
1790    }
1791
1792    #[test]
1793    fn test_parse_detect_shape_no_metadata() {
1794        // When metadata is missing (expected_classes == 0), infer from shape
1795        // Standard YOLO output [1, 84, 8400] with no metadata
1796        let (nc, np, transposed) = parse_detect_shape(&[1, 84, 8400], 0);
1797        assert_eq!(nc, 80); // Inferred: 84 - 4 = 80 classes
1798        assert_eq!(np, 8400);
1799        assert!(!transposed);
1800
1801        // Transposed format [1, 8400, 84] with no metadata
1802        let (nc, np, transposed) = parse_detect_shape(&[1, 8400, 84], 0);
1803        assert_eq!(nc, 80); // Inferred: 84 - 4 = 80 classes
1804        assert_eq!(np, 8400);
1805        assert!(transposed);
1806    }
1807
1808    #[test]
1809    fn test_empty_output() {
1810        let output: Vec<f32> = vec![];
1811        let preprocess = PreprocessResult {
1812            tensor: ndarray::Array4::zeros((1, 3, 640, 640)),
1813            tensor_f16: None,
1814            orig_shape: (480, 640),
1815            scale: (1.0, 1.0),
1816            padding: (0.0, 0.0),
1817        };
1818        let config = InferenceConfig::default();
1819        let names = HashMap::new();
1820        let orig_img = ndarray::Array3::zeros((480, 640, 3));
1821
1822        let results = postprocess_detect(
1823            &output,
1824            &[1, 84, 0],
1825            &preprocess,
1826            &config,
1827            &names,
1828            orig_img,
1829            String::new(),
1830            Speed::default(),
1831            (640, 640),
1832        );
1833
1834        assert!(results.is_empty());
1835    }
1836
1837    #[test]
1838    fn test_nan_scores_handled() {
1839        // Test that NaN scores don't cause a panic
1840        let mut output: Vec<f32> = vec![0.0; 84]; // One prediction
1841        // Set box coords
1842        output[0] = 100.0; // cx
1843        output[1] = 100.0; // cy
1844        output[2] = 50.0; // w
1845        output[3] = 50.0; // h
1846        // Set class scores with NaN
1847        output[4] = f32::NAN;
1848        output[5] = 0.9; // This should be selected even with NaN present
1849
1850        let preprocess = PreprocessResult {
1851            tensor: ndarray::Array4::zeros((1, 3, 640, 640)),
1852            tensor_f16: None,
1853            orig_shape: (640, 640),
1854            scale: (1.0, 1.0),
1855            padding: (0.0, 0.0),
1856        };
1857        let config = InferenceConfig::default();
1858        let mut names = HashMap::new();
1859        names.insert(0, "class0".to_string());
1860        names.insert(1, "class1".to_string());
1861        let orig_img = ndarray::Array3::zeros((640, 640, 3));
1862
1863        // This should not panic
1864        let results = postprocess_detect(
1865            &output,
1866            &[1, 84, 1],
1867            &preprocess,
1868            &config,
1869            &names,
1870            orig_img,
1871            String::new(),
1872            Speed::default(),
1873            (640, 640),
1874        );
1875
1876        // Reaching this point without a panic means the NaN score was handled
1877        // gracefully. Whether the detection survives depends on how NaN compares
1878        // in max_by, so we only assert that decoding completed without crashing.
1879        let _ = results;
1880    }
1881
1882    #[test]
1883    fn test_malformed_shape_fallback() {
1884        // Test that malformed shapes return empty results instead of panicking
1885        let output: Vec<f32> = vec![0.0; 100]; // Some data
1886
1887        let preprocess = PreprocessResult {
1888            tensor: ndarray::Array4::zeros((1, 3, 640, 640)),
1889            tensor_f16: None,
1890            orig_shape: (640, 640),
1891            scale: (1.0, 1.0),
1892            padding: (0.0, 0.0),
1893        };
1894        let config = InferenceConfig::default();
1895        let names = HashMap::new();
1896        let orig_img = ndarray::Array3::zeros((640, 640, 3));
1897
1898        // Empty shape should not panic
1899        let results = postprocess_detect(
1900            &output,
1901            &[],
1902            &preprocess,
1903            &config,
1904            &names,
1905            orig_img.clone(),
1906            String::new(),
1907            Speed::default(),
1908            (640, 640),
1909        );
1910        assert!(results.is_empty());
1911
1912        // Single dimension shape should not panic
1913        let results = postprocess_detect(
1914            &output,
1915            &[100],
1916            &preprocess,
1917            &config,
1918            &names,
1919            orig_img,
1920            String::new(),
1921            Speed::default(),
1922            (640, 640),
1923        );
1924        assert!(results.is_empty());
1925    }
1926
1927    #[test]
1928    fn test_postprocess_pose_logic() {
1929        // Mock output for pose: [1, 56, 100]
1930        // 56 features = 4 bbox + 1 class + 51 keypoints (17*3)
1931        let num_preds = 100;
1932        let num_features = 56;
1933        let mut output = vec![0.0; num_preds * num_features];
1934
1935        // Fill one prediction
1936        let idx = 0;
1937        // BBox: cx, cy, w, h
1938        output[idx] = 100.0;
1939        output[idx + num_preds] = 100.0;
1940        output[idx + num_preds * 2] = 50.0;
1941        output[idx + num_preds * 3] = 50.0;
1942        // Class score
1943        output[idx + num_preds * 4] = 0.9;
1944        // Keypoints: 17 * 3
1945        for k in 0..17 {
1946            let offset = 5 + k * 3;
1947            output[idx + num_preds * offset] = 100.0; // x
1948            output[idx + num_preds * (offset + 1)] = 100.0; // y
1949            output[idx + num_preds * (offset + 2)] = 0.8; // conf
1950        }
1951
1952        let preprocess = PreprocessResult {
1953            tensor: ndarray::Array4::zeros((1, 3, 640, 640)),
1954            tensor_f16: None,
1955            orig_shape: (640, 640),
1956            scale: (1.0, 1.0),
1957            padding: (0.0, 0.0),
1958        };
1959        let config = InferenceConfig::default();
1960        let mut names = HashMap::new();
1961        names.insert(0, "person".to_string());
1962
1963        // Shape [1, 56, 100]
1964        let results = postprocess_pose(
1965            &output,
1966            &[1, num_features, num_preds],
1967            &preprocess,
1968            &config,
1969            &names,
1970            ndarray::Array3::zeros((640, 640, 3)),
1971            "test.jpg".to_string(),
1972            Speed::default(),
1973            (640, 640),
1974        );
1975
1976        assert!(results.keypoints.is_some());
1977        let kpts = results.keypoints.unwrap();
1978        assert_eq!(kpts.data.shape()[0], 1); // 1 detection
1979        assert_eq!(kpts.data.shape()[1], 17); // 17 keypoints
1980        assert_eq!(kpts.data.shape()[2], 3); // x, y, conf
1981
1982        #[allow(clippy::float_cmp)]
1983        {
1984            // Verify values
1985            assert_eq!(kpts.data[[0, 0, 0]], 100.0);
1986            assert_eq!(kpts.data[[0, 0, 2]], 0.8);
1987        }
1988    }
1989
1990    #[test]
1991    fn test_postprocess_obb_logic() {
1992        // Mock output for OBB: [1, 6, 100]
1993        // 6 features = 4 bbox + 1 class + 1 angle
1994        let num_preds = 100;
1995        let num_features = 6;
1996        let mut output = vec![0.0; num_preds * num_features];
1997
1998        // Fill one prediction
1999        let idx = 0;
2000        // BBox: cx, cy, w, h
2001        output[idx] = 100.0;
2002        output[idx + num_preds] = 100.0;
2003        output[idx + num_preds * 2] = 50.0;
2004        output[idx + num_preds * 3] = 20.0;
2005        // Class score
2006        output[idx + num_preds * 4] = 0.95;
2007        // Angle
2008        output[idx + num_preds * 5] = std::f32::consts::FRAC_PI_4; // 45 degrees
2009
2010        let preprocess = PreprocessResult {
2011            tensor: ndarray::Array4::zeros((1, 3, 640, 640)),
2012            tensor_f16: None,
2013            orig_shape: (640, 640),
2014            scale: (1.0, 1.0),
2015            padding: (0.0, 0.0),
2016        };
2017        let config = InferenceConfig::default();
2018        let mut names = HashMap::new();
2019        names.insert(0, "object".to_string());
2020
2021        // Shape [1, 6, 100]
2022        let results = postprocess_obb(
2023            &output,
2024            &[1, num_features, num_preds],
2025            &preprocess,
2026            &config,
2027            &names,
2028            ndarray::Array3::zeros((640, 640, 3)),
2029            "test.jpg".to_string(),
2030            Speed::default(),
2031            (640, 640),
2032        );
2033
2034        assert!(results.obb.is_some());
2035        let obb = results.obb.unwrap();
2036        assert_eq!(obb.len(), 1);
2037
2038        // Verify values
2039        let data = obb.data.row(0);
2040        #[allow(clippy::float_cmp)]
2041        {
2042            assert_eq!(data[0], 100.0); // cx
2043            assert_eq!(data[4], std::f32::consts::FRAC_PI_4); // angle
2044            assert_eq!(data[5], 0.95); // conf
2045        }
2046    }
2047}