Skip to main content

ultralytics_inference/
preprocessing.rs

1// Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
3//! Image preprocessing for YOLO inference.
4//!
5//! This module handles all image preprocessing operations needed before
6//! running YOLO model inference, including resizing, padding, and normalization.
7
8#![allow(
9    unsafe_code,
10    clippy::similar_names,
11    clippy::cast_precision_loss,
12    clippy::cast_possible_wrap,
13    clippy::cast_sign_loss,
14    clippy::cast_possible_truncation,
15    clippy::too_many_arguments,
16    clippy::too_many_lines,
17    clippy::wildcard_imports,
18    clippy::ptr_as_ptr,
19    clippy::cast_lossless,
20    clippy::single_match_else,
21    clippy::suboptimal_flops,
22    clippy::manual_div_ceil
23)]
24
25use std::cell::RefCell;
26use std::num::NonZeroUsize;
27
28use half::f16;
29use image::{DynamicImage, GenericImageView, RgbImage};
30use lru::LruCache;
31use ndarray::{Array3, Array4};
32
33// ================================================================================================
34// Constants
35// ================================================================================================
36
/// Default letterbox padding color (gray), as an RGB triple.
pub const LETTERBOX_COLOR: [u8; 3] = [114; 3];

/// Number of fractional bits in the fixed-point bilinear weights.
///
/// Mirrors `OpenCV`'s `INTER_RESIZE_COEF_BITS = 11` for `INTER_LINEAR`; a weight of
/// `1.0` is represented as `SCALE_INT` (2^11 = 2048).
const SCALE_BITS: i32 = 11;
const SCALE_INT: i32 = 1 << SCALE_BITS;

/// Fractional bits of the product of an x weight and a y weight. The bilinear kernel
/// sums all four taps in this domain and shifts right by this amount once.
const SCALE_BITS_2X: i32 = SCALE_BITS + SCALE_BITS;

/// Half of one output unit in the doubled fixed-point domain (2^21).
///
/// Added before the final `>> SCALE_BITS_2X` so the shift rounds to nearest (ties up,
/// since the sums are non-negative) instead of truncating, matching `OpenCV`'s
/// `saturate_cast` rounding.
const ROUND_BIAS: i32 = 1 << (SCALE_BITS_2X - 1);

/// Letterbox padding color normalized to `[0, 1]` (114 / 255 ≈ 0.447).
const LETTERBOX_NORM: f32 = LETTERBOX_COLOR[0] as f32 / 255.0;

/// Reciprocal of 255, used to normalize `u8` samples by multiplication.
const INV_255: f32 = 1.0 / 255.0;

/// Capacity of the per-thread LRU cache of horizontal interpolation LUTs.
const LUT_CACHE_SIZE: usize = 8;
60
61// ================================================================================================
62// Type Aliases
63// ================================================================================================
64
65/// X LUT entry: (`x0_byte_offset`, `x1_byte_offset`, 1-fx, fx) using 11-bit fixed-point
66/// weights matching `OpenCV`'s `INTER_LINEAR` coordinate mapping.
67type XLutEntry = (usize, usize, i32, i32);
68type XLutKey = (u32, u32);
69
70// ================================================================================================
71// Thread-Local State
72// ================================================================================================
73
74thread_local! {
75    static X_LUT_CACHE: RefCell<LruCache<XLutKey, Vec<XLutEntry>>> =
76        RefCell::new(LruCache::new(NonZeroUsize::new(LUT_CACHE_SIZE).unwrap()));
77}
78
79// ================================================================================================
80// Types
81// ================================================================================================
82
83/// Result of preprocessing an image, containing the tensor and transform info.
84#[derive(Debug, Clone)]
85pub struct PreprocessResult {
86    /// Preprocessed image tensor in NCHW format, normalized to [0, 1].
87    pub tensor: Array4<f32>,
88    /// Preprocessed FP16 tensor (if requested).
89    pub tensor_f16: Option<Array4<f16>>,
90    /// Original image dimensions (height, width).
91    pub orig_shape: (u32, u32),
92    /// Scale factors applied (`scale_y`, `scale_x`).
93    pub scale: (f32, f32),
94    /// Padding applied (`pad_top`, `pad_left`).
95    pub padding: (f32, f32),
96}
97
98/// Preprocess an image for YOLO inference.
99///
100/// Performs letterbox resizing, BGR to RGB conversion (if needed),
101/// normalization to [0, 1], and conversion to NCHW tensor format.
102///
103/// # Arguments
104///
105/// * `image` - Input image.
106/// * `target_size` - Target size as (height, width).
107/// * `stride` - Model stride for padding alignment (typically 32).
108///
109/// # Returns
110///
111/// Preprocessed tensor and transform information for post-processing.
112#[must_use]
113pub fn preprocess_image(
114    image: &DynamicImage,
115    target_size: (usize, usize),
116    stride: u32,
117) -> PreprocessResult {
118    preprocess_image_with_precision(image, target_size, stride, false)
119}
120
121/// Preprocess an image for YOLO inference with optional FP16 output.
122///
123/// # Arguments
124///
125/// * `image` - Input image.
126/// * `target_size` - Target size as (height, width).
127/// * `stride` - Model stride for padding alignment (typically 32).
128/// * `half` - If true, also generate FP16 tensor for FP16 models.
129///
130/// # Returns
131///
132/// Preprocessed tensor and transform information for post-processing.
133#[must_use]
134pub fn preprocess_image_with_precision(
135    image: &DynamicImage,
136    target_size: (usize, usize),
137    stride: u32,
138    half: bool,
139) -> PreprocessResult {
140    let (orig_width, orig_height) = image.dimensions();
141    let orig_shape = (orig_height, orig_width);
142
143    // Calculate letterbox dimensions
144    let (new_width, new_height, pad_left, pad_top, scale) =
145        calculate_letterbox_params(orig_width, orig_height, target_size, stride);
146
147    // Zero-copy path: avoid to_rgb8() allocation when possible
148    let tensor = match image {
149        // Fast path: already RGB8, use bytes directly without copy
150        DynamicImage::ImageRgb8(rgb) => fused_zerocopy_preprocess(
151            rgb.as_raw(),
152            orig_width,
153            orig_height,
154            target_size,
155            pad_top,
156            pad_left,
157            new_width,
158            new_height,
159        ),
160        // Fallback: convert to RGB8 (allocates)
161        _ => {
162            let src_rgb = image.to_rgb8();
163            fused_zerocopy_preprocess(
164                src_rgb.as_raw(),
165                orig_width,
166                orig_height,
167                target_size,
168                pad_top,
169                pad_left,
170                new_width,
171                new_height,
172            )
173        }
174    };
175
176    let tensor_f16 = if half {
177        Some(tensor_f32_to_f16(&tensor))
178    } else {
179        None
180    };
181
182    PreprocessResult {
183        tensor,
184        tensor_f16,
185        orig_shape,
186        scale,
187        #[allow(clippy::cast_precision_loss)]
188        padding: (pad_top as f32, pad_left as f32),
189    }
190}
191
192// ================================================================================================
193// Public API Functions
194// ================================================================================================
195
196/// Get or compute the X coordinate LUT for bilinear interpolation.
197///
198/// Uses 11-bit fixed-point weights matching `OpenCV`'s `INTER_LINEAR` coordinate mapping:
199/// `src_x = (dst_x + 0.5) * (src_w / dst_w) - 0.5`
200///
201/// Weight computation matches `OpenCV`'s `resize.cpp`:
202/// `cbuf[0] = saturate_cast<short>((1-fx) * 2048); cbuf[1] = 2048 - cbuf[0];`
203fn get_or_compute_x_lut(src_w: u32, dst_w: u32) -> Vec<XLutEntry> {
204    let key = (src_w, dst_w);
205
206    X_LUT_CACHE.with(|cache| {
207        let mut cache = cache.borrow_mut();
208
209        if let Some(lut) = cache.get(&key) {
210            return lut.clone();
211        }
212
213        let scale_x = src_w as f32 / dst_w as f32;
214        let src_w_max = (src_w - 1) as i32;
215
216        let lut: Vec<XLutEntry> = (0..dst_w)
217            .map(|dx| {
218                let sx = ((dx as f32 + 0.5) * scale_x - 0.5).max(0.0);
219                let x0 = sx.floor() as i32;
220                // Match OpenCV: cbuf[0] = saturate_cast<short>((1-fx)*SCALE),
221                //               cbuf[1] = SCALE - cbuf[0]
222                let fx_f = sx - x0 as f32;
223                let fx_inv = ((1.0 - fx_f) * SCALE_INT as f32 + 0.5) as i32;
224                let fx = SCALE_INT - fx_inv;
225                let x0c = x0.clamp(0, src_w_max) as usize * 3;
226                let x1c = (x0 + 1).clamp(0, src_w_max) as usize * 3;
227                (x0c, x1c, fx_inv, fx)
228            })
229            .collect();
230
231        cache.put(key, lut.clone());
232        lut
233    })
234}
235
236/// Zero-copy fused preprocessing for maximum performance.
237///
238/// Combines bilinear resize, letterbox padding, and NCHW normalization
239/// in a single memory pass with parallel row processing.
240fn fused_zerocopy_preprocess(
241    src_raw: &[u8],
242    src_w: u32,
243    src_h: u32,
244    target_size: (usize, usize),
245    pad_top: u32,
246    pad_left: u32,
247    new_width: u32,
248    new_height: u32,
249) -> Array4<f32> {
250    use rayon::prelude::*;
251    use std::mem::MaybeUninit;
252    use std::sync::atomic::{AtomicPtr, Ordering};
253
254    let (dst_h, dst_w) = target_size;
255    let channel_size = dst_h * dst_w;
256    let src_stride = (src_w * 3) as usize;
257
258    // ALLOCATE UNINITIALIZED: Saves ~0.2ms by not zeroing memory
259    let mut tensor: Array4<MaybeUninit<f32>> = Array4::uninit((1, 3, dst_h, dst_w));
260    let out_ptr = tensor.as_mut_ptr() as *mut f32;
261
262    // Use AtomicPtr for thread-safe pointer sharing (each thread writes to disjoint rows)
263    let atomic_ptr = AtomicPtr::new(out_ptr);
264
265    let x_lut = get_or_compute_x_lut(src_w, new_width);
266    let scale_y = src_h as f32 / new_height as f32;
267    let src_h_max = (src_h - 1) as i32;
268
269    let pad_top_usize = pad_top as usize;
270    let pad_left_usize = pad_left as usize;
271    let new_height_usize = new_height as usize;
272    let new_width_usize = new_width as usize;
273
274    // Parallel row processing with raw pointers (no bounds checks)
275    (0..dst_h).into_par_iter().for_each(|dy| {
276        let data_ptr = atomic_ptr.load(Ordering::Relaxed);
277        unsafe {
278            // Calculate row pointers for R, G, B channels
279
280            let r_row = data_ptr.add(dy * dst_w);
281            let g_row = data_ptr.add(channel_size + dy * dst_w);
282            let b_row = data_ptr.add(2 * channel_size + dy * dst_w);
283
284            // Vertical padding (top/bottom rows)
285            if dy < pad_top_usize || dy >= pad_top_usize + new_height_usize {
286                for dx in 0..dst_w {
287                    *r_row.add(dx) = LETTERBOX_NORM;
288                    *g_row.add(dx) = LETTERBOX_NORM;
289                    *b_row.add(dx) = LETTERBOX_NORM;
290                }
291                return;
292            }
293
294            // Image row calculations - 11-bit fixed-point bilinear matching
295            // OpenCV's INTER_LINEAR (INTER_RESIZE_COEF_BITS = 11).
296            let img_dy = dy - pad_top_usize;
297            let sy = ((img_dy as f32 + 0.5) * scale_y - 0.5).max(0.0);
298            let y0 = sy.floor() as i32;
299            let fy_f = sy - y0 as f32;
300            let fy_inv = ((1.0 - fy_f) * SCALE_INT as f32 + 0.5) as i32;
301            let fy = SCALE_INT - fy_inv;
302
303            let y0c = y0.clamp(0, src_h_max) as usize;
304            let y1c = (y0 + 1).clamp(0, src_h_max) as usize;
305            let row0_off = y0c * src_stride;
306            let row1_off = y1c * src_stride;
307
308            // Left padding
309            for dx in 0..pad_left_usize {
310                *r_row.add(dx) = LETTERBOX_NORM;
311                *g_row.add(dx) = LETTERBOX_NORM;
312                *b_row.add(dx) = LETTERBOX_NORM;
313            }
314
315            // Inner image pixels - fixed-point bilinear with rounding.
316            // Uses untruncated weights (w = fx * fy, range [0, 2048^2]) and a
317            // single shift with rounding bias, matching OpenCV's saturate_cast:
318            //   result = (sum + ROUND_BIAS) >> 22
319            // Max intermediate: 255 * 2048^2 + 2^21 ≈ 1.07B < i32::MAX.
320            let mut img_dx = 0usize;
321            let src_ptr = src_raw.as_ptr();
322
323            while img_dx < new_width_usize {
324                let (x0_off, x1_off, fx_inv, fx) = *x_lut.get_unchecked(img_dx);
325                let w00 = fx_inv * fy_inv;
326                let w10 = fx * fy_inv;
327                let w01 = fx_inv * fy;
328                let w11 = fx * fy;
329
330                let p00 = src_ptr.add(row0_off + x0_off);
331                let p10 = src_ptr.add(row0_off + x1_off);
332                let p01 = src_ptr.add(row1_off + x0_off);
333                let p11 = src_ptr.add(row1_off + x1_off);
334
335                let out_x = pad_left_usize + img_dx;
336                *r_row.add(out_x) = ((*p00 as i32 * w00
337                    + *p10 as i32 * w10
338                    + *p01 as i32 * w01
339                    + *p11 as i32 * w11
340                    + ROUND_BIAS)
341                    >> SCALE_BITS_2X) as f32
342                    * INV_255;
343                *g_row.add(out_x) = ((*p00.add(1) as i32 * w00
344                    + *p10.add(1) as i32 * w10
345                    + *p01.add(1) as i32 * w01
346                    + *p11.add(1) as i32 * w11
347                    + ROUND_BIAS)
348                    >> SCALE_BITS_2X) as f32
349                    * INV_255;
350                *b_row.add(out_x) = ((*p00.add(2) as i32 * w00
351                    + *p10.add(2) as i32 * w10
352                    + *p01.add(2) as i32 * w01
353                    + *p11.add(2) as i32 * w11
354                    + ROUND_BIAS)
355                    >> SCALE_BITS_2X) as f32
356                    * INV_255;
357
358                img_dx += 1;
359            }
360
361            // Right padding
362            for dx in (pad_left_usize + new_width_usize)..dst_w {
363                *r_row.add(dx) = LETTERBOX_NORM;
364                *g_row.add(dx) = LETTERBOX_NORM;
365                *b_row.add(dx) = LETTERBOX_NORM;
366            }
367        }
368    });
369
370    // SAFETY: All elements have been initialized
371    unsafe { tensor.assume_init() }
372}
373
374/// Convert f32 tensor to f16 tensor.
375fn tensor_f32_to_f16(tensor: &Array4<f32>) -> Array4<half::f16> {
376    tensor.mapv(half::f16::from_f32)
377}
378
/// Calculate the input size for rectangular (minimal-padding) inference.
///
/// The image is scaled by the largest factor that fits it inside `target_size` while
/// keeping its aspect ratio. Each scaled side is then rounded up to the next multiple
/// of `stride`.
///
/// # Arguments
///
/// * `orig_width` - Original image width.
/// * `orig_height` - Original image height.
/// * `target_size` - Maximum size as (height, width), e.g. (640, 640).
/// * `stride` - Model stride for alignment. A stride of 0 is treated as 1 (no alignment)
///   instead of panicking on division by zero.
///
/// # Returns
///
/// Adjusted size as (height, width).
#[must_use]
pub fn calculate_rect_size(
    orig_width: u32,
    orig_height: u32,
    target_size: (usize, usize),
    stride: u32,
) -> (usize, usize) {
    let (target_h, target_w) = target_size;

    let orig_h = orig_height as f32;
    let orig_w = orig_width as f32;

    // Uniform gain that fits the image inside the target box.
    let scale = (target_h as f32 / orig_h).min(target_w as f32 / orig_w);

    let new_h = (orig_h * scale).round() as usize;
    let new_w = (orig_w * scale).round() as usize;

    // `stride == 0` would underflow `v + stride - 1` for v == 0 and divide by zero.
    let stride = (stride as usize).max(1);
    let align_up = |v: usize| ((v + stride - 1) / stride) * stride;

    (align_up(new_h), align_up(new_w))
}
428
/// Compute the letterbox geometry that fits an image inside `target_size`.
///
/// One gain, the smaller of the two per-axis ratios, is applied to both axes, so the
/// aspect ratio is preserved and the scaled image fits. The leftover space is split in
/// half. When it is odd, the extra pixel ends up on the bottom/right side.
///
/// # Arguments
///
/// * `orig_width` - Original image width.
/// * `orig_height` - Original image height.
/// * `target_size` - Target size as (height, width).
/// * `_stride` - Model stride; not used here, kept for API compatibility.
///
/// # Returns
///
/// `(new_width, new_height, pad_left, pad_top, (scale_y, scale_x))`, where both scale
/// entries are the same uniform gain.
fn calculate_letterbox_params(
    orig_width: u32,
    orig_height: u32,
    target_size: (usize, usize),
    _stride: u32,
) -> (u32, u32, u32, u32, (f32, f32)) {
    let (target_h, target_w) = target_size;

    let gain_h = target_h as f32 / orig_height as f32;
    let gain_w = target_w as f32 / orig_width as f32;
    let gain = gain_h.min(gain_w);

    let new_h = (orig_height as f32 * gain).round() as u32;
    let new_w = (orig_width as f32 * gain).round() as u32;

    let pad_top = (target_h as u32).saturating_sub(new_h) / 2;
    let pad_left = (target_w as u32).saturating_sub(new_w) / 2;

    // Report the uniform gain rather than per-axis ratios (`new_w / orig_w`,
    // `new_h / orig_h`). This matches Ultralytics `scale_boxes()`, which applies a single
    // `gain = min(target_h / orig_h, target_w / orig_w)` to both axes. Per-axis ratios
    // drift after rounding `new_w`/`new_h`, shifting boxes slightly and changing NMS.
    (new_w, new_h, pad_left, pad_top, (gain, gain))
}
484
485/// Convert an RGB image to a normalized NCHW tensor (FP32).
486///
487/// # Arguments
488///
489/// * `image` - RGB image to convert.
490///
491/// # Returns
492///
493/// Array4 with shape (1, 3, H, W) and values in [0, 1].
494fn image_to_tensor(image: &RgbImage) -> Array4<f32> {
495    let (width, height) = image.dimensions();
496    let (w, h) = (width as usize, height as usize);
497    let pixels = image.as_raw();
498
499    let mut tensor = Array4::zeros((1, 3, h, w));
500
501    // Get mutable slices for each channel for faster access
502    let (r_slice, rest) = tensor.as_slice_mut().unwrap().split_at_mut(h * w);
503    let (g_slice, b_slice) = rest.split_at_mut(h * w);
504
505    for (i, chunk) in pixels.chunks_exact(3).enumerate() {
506        r_slice[i] = f32::from(chunk[0]) / 255.0;
507        g_slice[i] = f32::from(chunk[1]) / 255.0;
508        b_slice[i] = f32::from(chunk[2]) / 255.0;
509    }
510
511    tensor
512}
513
514/// Convert an RGB image to a normalized NCHW tensor (FP16).
515///
516/// Converts directly from u8 to f16, avoiding intermediate f32 conversion.
517///
518/// # Arguments
519///
520/// * `image` - RGB image to convert.
521///
522/// # Returns
523///
524/// Array4 with shape (1, 3, H, W) and f16 values in [0, 1].
525fn image_to_tensor_f16(image: &RgbImage) -> Array4<f16> {
526    let (width, height) = image.dimensions();
527    let (w, h) = (width as usize, height as usize);
528    let pixels = image.as_raw();
529
530    let mut tensor = Array4::from_elem((1, 3, h, w), f16::ZERO);
531
532    let (r_slice, rest) = tensor.as_slice_mut().unwrap().split_at_mut(h * w);
533    let (g_slice, b_slice) = rest.split_at_mut(h * w);
534
535    // Precompute 1/255 as f16 for direct conversion
536    let scale = f16::from_f32(1.0 / 255.0);
537
538    for (i, chunk) in pixels.chunks_exact(3).enumerate() {
539        r_slice[i] = f16::from_f32(f32::from(chunk[0])) * scale;
540        g_slice[i] = f16::from_f32(f32::from(chunk[1])) * scale;
541        b_slice[i] = f16::from_f32(f32::from(chunk[2])) * scale;
542    }
543
544    tensor
545}
546
547/// Convert a raw HWC u8 array to a normalized NCHW tensor.
548///
549/// # Arguments
550///
551/// * `image` - HWC array with shape (H, W, C) and u8 values.
552///
553/// # Returns
554///
555/// Array4 with shape (1, C, H, W) and values in [0, 1].
556#[must_use]
557pub fn array_to_tensor(image: &Array3<u8>) -> Array4<f32> {
558    let shape = image.shape();
559    let (height, width, channels) = (shape[0], shape[1], shape[2]);
560
561    let mut tensor = Array4::zeros((1, channels, height, width));
562
563    for y in 0..height {
564        for x in 0..width {
565            for c in 0..channels {
566                tensor[[0, c, y, x]] = f32::from(image[[y, x, c]]) / 255.0;
567            }
568        }
569    }
570
571    tensor
572}
573
574/// Convert a `DynamicImage` to an HWC ndarray.
575///
576/// # Panics
577///
578/// Panics if the array cannot be created from the image pixels (e.g. dimension mismatch).
579#[must_use]
580pub fn image_to_array(image: &DynamicImage) -> Array3<u8> {
581    let rgb = image.to_rgb8();
582    let (width, height) = rgb.dimensions();
583    let pixels = rgb.into_raw();
584
585    Array3::from_shape_vec((height as usize, width as usize, 3), pixels)
586        .expect("Failed to create array from image pixels")
587}
588
/// Map a box from letterboxed model-input space back to original-image space.
///
/// Subtracts the letterbox padding and divides by the scale factor on each axis.
///
/// # Arguments
///
/// * `coords` - Box as `[x1, y1, x2, y2]` in model-input pixels.
/// * `scale` - (`scale_y`, `scale_x`) from preprocessing.
/// * `padding` - (`pad_top`, `pad_left`) from preprocessing.
///
/// # Returns
///
/// The box as `[x1, y1, x2, y2]` in original-image pixels (not clipped).
#[must_use]
pub fn scale_coords(coords: &[f32; 4], scale: (f32, f32), padding: (f32, f32)) -> [f32; 4] {
    let unmap_x = |x: f32| (x - padding.1) / scale.1;
    let unmap_y = |y: f32| (y - padding.0) / scale.0;

    let [x1, y1, x2, y2] = *coords;
    [unmap_x(x1), unmap_y(y1), unmap_x(x2), unmap_y(y2)]
}
612
/// Clamp a box to the image bounds.
///
/// # Arguments
///
/// * `coords` - Box as `[x1, y1, x2, y2]`.
/// * `shape` - Image shape as (height, width).
///
/// # Returns
///
/// The box with x values clamped to `[0, width]` and y values to `[0, height]`.
#[must_use]
pub const fn clip_coords(coords: &[f32; 4], shape: (u32, u32)) -> [f32; 4] {
    let (height, width) = shape;
    let max_y = height as f32;
    let max_x = width as f32;

    let [x1, y1, x2, y2] = *coords;
    [
        x1.clamp(0.0, max_x),
        y1.clamp(0.0, max_y),
        x2.clamp(0.0, max_x),
        y2.clamp(0.0, max_y),
    ]
}
634
635/// Preprocess an image for YOLO classification (Center Crop).
636///
637/// Resizes the image so the shortest side matches `target_size`, then center crops.
638///
639/// # Arguments
640///
641/// * `image` - Input image.
642/// * `target_size` - Target size as (height, width).
643/// * `half` - If true, also generate FP16 tensor.
644///
645/// # Returns
646///
647/// Preprocessed tensor and transform information.
648#[must_use]
649pub fn preprocess_image_center_crop(
650    image: &DynamicImage,
651    target_size: (usize, usize),
652    half: bool,
653) -> PreprocessResult {
654    let (orig_width, orig_height) = image.dimensions();
655    let orig_shape = (orig_height, orig_width);
656
657    // Perform center crop resize
658    let (cropped, scale) = center_crop_image(image, target_size);
659
660    // Convert to normalized NCHW tensor
661    let tensor = image_to_tensor(&cropped);
662
663    // Optionally compute FP16 tensor
664    let tensor_f16 = if half {
665        Some(image_to_tensor_f16(&cropped))
666    } else {
667        None
668    };
669
670    // For classification, we don't need complex coordinate mapping back to original
671    // But we provide approximate scale/padding to satisfy strict types if needed.
672    // In classification, we rarely map bounding boxes back, so these are less critical.
673    let padding = (0.0, 0.0);
674
675    PreprocessResult {
676        tensor,
677        tensor_f16,
678        orig_shape,
679        scale,
680        padding,
681    }
682}
683
684/// Resize and center crop image.
685///
686/// Resizes the image such that the shortest side equals the target dimension,
687/// maintaining aspect ratio, then crops the center `target_size`.
688///
689/// # Arguments
690///
691/// * `image` - Source dynamic image.
692/// * `target_size` - Desired output dimensions (height, width).
693///
694/// # Returns
695///
696/// Tuple containing:
697/// 1. `cropped`: The processed `RgbImage`.
698/// 2. `scale`: Scale factors applied (same for x and y).
699#[allow(clippy::similar_names)]
700fn center_crop_image(image: &DynamicImage, target_size: (usize, usize)) -> (RgbImage, (f32, f32)) {
701    use fast_image_resize::{PixelType, ResizeAlg, ResizeOptions, Resizer, images::Image};
702
703    let (src_w, src_h) = image.dimensions();
704    #[allow(clippy::cast_possible_truncation)]
705    let (target_h, target_w) = (target_size.0 as u32, target_size.1 as u32);
706
707    // Calculate scale to "cover" the target area
708    // scale = max(target_w / src_w, target_h / src_h)
709    #[allow(clippy::cast_precision_loss)]
710    let scale_x = target_w as f32 / src_w as f32;
711    #[allow(clippy::cast_precision_loss)]
712    let scale_y = target_h as f32 / src_h as f32;
713    let scale = scale_x.max(scale_y);
714
715    let (new_w, new_h) = if scale_x >= scale_y {
716        #[allow(
717            clippy::cast_possible_truncation,
718            clippy::cast_sign_loss,
719            clippy::cast_precision_loss
720        )]
721        (target_w, (src_h as f32 * scale_x) as u32)
722    } else {
723        #[allow(
724            clippy::cast_possible_truncation,
725            clippy::cast_sign_loss,
726            clippy::cast_precision_loss
727        )]
728        ((src_w as f32 * scale_y) as u32, target_h)
729    };
730
731    // Resize first
732    let src_rgb = image.to_rgb8();
733    let src_image = Image::from_vec_u8(src_w, src_h, src_rgb.into_raw(), PixelType::U8x3)
734        .expect("Failed to create source image");
735
736    // Valid dimensions check
737    let safe_new_w = new_w.max(1);
738    let safe_new_h = new_h.max(1);
739
740    let mut dst_image = Image::new(safe_new_w, safe_new_h, PixelType::U8x3);
741
742    let mut resizer = Resizer::new();
743    let options = ResizeOptions::new().resize_alg(ResizeAlg::Convolution(
744        fast_image_resize::FilterType::Bilinear,
745    ));
746    resizer
747        .resize(&src_image, &mut dst_image, Some(&options))
748        .expect("Failed to resize image");
749
750    // Convert back to RgbImage to crop
751    let resized_buffer = dst_image.into_vec();
752    let resized_rgb = RgbImage::from_raw(safe_new_w, safe_new_h, resized_buffer)
753        .expect("Failed to create resized buffer");
754
755    // Calculate crop offsets using Banker's Rounding (to match Python round())
756    #[allow(clippy::cast_precision_loss)]
757    let crop_x_float = (new_w.saturating_sub(target_w)) as f32 / 2.0;
758    #[allow(clippy::cast_precision_loss)]
759    let crop_y_float = (new_h.saturating_sub(target_h)) as f32 / 2.0;
760
761    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
762    let crop_x = bankers_round(crop_x_float) as u32;
763    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
764    let crop_y = bankers_round(crop_y_float) as u32;
765
766    let cropped =
767        image::imageops::crop_imm(&resized_rgb, crop_x, crop_y, target_w, target_h).to_image();
768
769    (cropped, (scale, scale))
770}
771
/// Round to the nearest integer, resolving exact halves to the even neighbor
/// (banker's rounding), matching Python's built-in `round()`.
///
/// Only values that are exactly halfway are treated as ties. The earlier hand-rolled
/// version used a `1e-6` tolerance, so inputs such as `3.4999995` were misclassified as
/// ties and rounded to 4 instead of 3.
fn bankers_round(v: f32) -> f32 {
    v.round_ties_even()
}
783
#[allow(clippy::similar_names)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_letterbox_params_square() {
        let (new_w, new_h, pad_left, pad_top, _scale) =
            calculate_letterbox_params(640, 640, (640, 640), 32);

        assert_eq!(new_w, 640);
        assert_eq!(new_h, 640);
        assert_eq!(pad_left, 0);
        assert_eq!(pad_top, 0);
    }

    #[test]
    fn test_letterbox_params_wide() {
        let (new_w, new_h, pad_left, pad_top, scale) =
            calculate_letterbox_params(1280, 720, (640, 640), 32);

        // Scaled by 0.5 to 640x360; the 280 spare rows are split evenly.
        assert_eq!((new_w, new_h), (640, 360));
        assert_eq!((pad_left, pad_top), (0, 140));
        assert!((scale.0 - 0.5).abs() < 1e-6);
        assert!((scale.1 - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_letterbox_params_tall() {
        let (new_w, new_h, pad_left, pad_top, _scale) =
            calculate_letterbox_params(720, 1280, (640, 640), 32);

        assert_eq!((new_w, new_h), (360, 640));
        assert_eq!((pad_left, pad_top), (140, 0));
    }

    #[test]
    fn test_rect_size_aligns_to_stride() {
        assert_eq!(calculate_rect_size(1280, 720, (640, 640), 32), (384, 640));
        assert_eq!(calculate_rect_size(640, 640, (640, 640), 32), (640, 640));
    }

    #[test]
    fn test_preprocess_letterbox_layout() {
        let img = DynamicImage::ImageRgb8(RgbImage::from_pixel(40, 20, image::Rgb([255, 0, 0])));
        let result = preprocess_image(&img, (32, 32), 32);

        assert_eq!(result.tensor.shape(), &[1, 3, 32, 32]);
        assert_eq!(result.orig_shape, (20, 40));
        assert!(result.tensor_f16.is_none());
        // 40x20 -> 32x16, centered vertically with 8 padding rows on top.
        assert!((result.padding.0 - 8.0).abs() < 1e-6);
        assert!(result.padding.1.abs() < 1e-6);
        // Padding row carries the normalized letterbox gray.
        assert!((result.tensor[[0, 0, 0, 0]] - LETTERBOX_NORM).abs() < 1e-6);
        // Image row: pure red survives the bilinear resize exactly.
        assert!((result.tensor[[0, 0, 16, 16]] - 1.0).abs() < 1e-6);
        assert!(result.tensor[[0, 1, 16, 16]].abs() < 1e-6);
    }

    #[test]
    fn test_center_crop_output_shape() {
        let img = DynamicImage::ImageRgb8(RgbImage::from_pixel(40, 20, image::Rgb([10, 20, 30])));
        let result = preprocess_image_center_crop(&img, (16, 16), true);

        assert_eq!(result.tensor.shape(), &[1, 3, 16, 16]);
        assert_eq!(
            result.tensor_f16.as_ref().map(|t| t.shape().to_vec()),
            Some(vec![1, 3, 16, 16])
        );
    }

    #[test]
    fn test_bankers_round_ties_to_even() {
        assert_eq!(bankers_round(0.5), 0.0);
        assert_eq!(bankers_round(1.5), 2.0);
        assert_eq!(bankers_round(2.5), 2.0);
        assert_eq!(bankers_round(-2.5), -2.0);
        assert_eq!(bankers_round(2.6), 3.0);
    }

    #[test]
    fn test_scale_coords() {
        let coords = [100.0, 100.0, 200.0, 200.0];
        let scale = (1.0, 1.0);
        let padding = (10.0, 10.0);

        let scaled = scale_coords(&coords, scale, padding);

        assert!((scaled[0] - 90.0).abs() < 1e-6);
        assert!((scaled[1] - 90.0).abs() < 1e-6);
        assert!((scaled[2] - 190.0).abs() < 1e-6);
        assert!((scaled[3] - 190.0).abs() < 1e-6);
    }

    #[test]
    fn test_clip_coords() {
        let coords = [-10.0, -20.0, 700.0, 500.0];
        let clipped = clip_coords(&coords, (480, 640));

        assert!((clipped[0] - 0.0).abs() < 1e-6);
        assert!((clipped[1] - 0.0).abs() < 1e-6);
        assert!((clipped[2] - 640.0).abs() < 1e-6);
        assert!((clipped[3] - 480.0).abs() < 1e-6);
    }
}