Skip to main content

yscv_imgproc/ops/
preprocess.rs

1use yscv_tensor::Tensor;
2
3use super::super::ImgProcError;
4use super::super::shape::hwc_shape;
5use super::resize::resize_bilinear;
6
7/// Convert an image from `[H, W, C]` (HWC) layout to `[C, H, W]` (CHW) layout.
8pub fn hwc_to_chw(image: &Tensor) -> Result<Tensor, ImgProcError> {
9    let (h, w, c) = hwc_shape(image)?;
10    let src = image.data();
11    let mut out = vec![0.0f32; c * h * w];
12    for y in 0..h {
13        for x in 0..w {
14            let hwc_base = (y * w + x) * c;
15            for ch in 0..c {
16                out[ch * h * w + y * w + x] = src[hwc_base + ch];
17            }
18        }
19    }
20    Tensor::from_vec(vec![c, h, w], out).map_err(Into::into)
21}
22
23/// Convert an image from `[C, H, W]` (CHW) layout to `[H, W, C]` (HWC) layout.
24pub fn chw_to_hwc(image: &Tensor) -> Result<Tensor, ImgProcError> {
25    let shape = image.shape();
26    if shape.len() != 3 {
27        return Err(ImgProcError::InvalidImageShape {
28            expected_rank: 3,
29            got: shape.to_vec(),
30        });
31    }
32    let (c, h, w) = (shape[0], shape[1], shape[2]);
33    let src = image.data();
34    let mut out = vec![0.0f32; h * w * c];
35    for y in 0..h {
36        for x in 0..w {
37            let hwc_base = (y * w + x) * c;
38            for ch in 0..c {
39                out[hwc_base + ch] = src[ch * h * w + y * w + x];
40            }
41        }
42    }
43    Tensor::from_vec(vec![h, w, c], out).map_err(Into::into)
44}
45
46/// Per-channel normalization: `(pixel - mean[c]) / std[c]` for an HWC image.
47pub fn normalize_image(image: &Tensor, mean: &[f32], std: &[f32]) -> Result<Tensor, ImgProcError> {
48    // Delegate to the existing normalize function.
49    super::normalize::normalize(image, mean, std)
50}
51
52/// Center-crop an `[H, W, C]` image to `size × size`.
53///
54/// If the image is smaller than `size` in either dimension the original is returned.
55pub fn center_crop(image: &Tensor, size: usize) -> Result<Tensor, ImgProcError> {
56    let (h, w, c) = hwc_shape(image)?;
57    if h <= size && w <= size {
58        return Ok(image.clone());
59    }
60    let crop_h = size.min(h);
61    let crop_w = size.min(w);
62    let y_off = (h - crop_h) / 2;
63    let x_off = (w - crop_w) / 2;
64
65    let src = image.data();
66    let mut out = vec![0.0f32; crop_h * crop_w * c];
67    for y in 0..crop_h {
68        let src_row = ((y_off + y) * w + x_off) * c;
69        let dst_row = y * crop_w * c;
70        out[dst_row..dst_row + crop_w * c].copy_from_slice(&src[src_row..src_row + crop_w * c]);
71    }
72    Tensor::from_vec(vec![crop_h, crop_w, c], out).map_err(Into::into)
73}
74
75/// Standard ImageNet preprocessing pipeline.
76///
77/// 1. Resize shortest side to 256 (bilinear)
78/// 2. Center crop to 224×224
79/// 3. Normalize with ImageNet mean/std
80/// 4. Convert HWC → CHW
81///
82/// Input: `[H, W, 3]` float32 in `[0, 1]`.
83/// Output: `[3, 224, 224]` float32 normalized.
84pub fn imagenet_preprocess(image: &Tensor) -> Result<Tensor, ImgProcError> {
85    const IMAGENET_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
86    const IMAGENET_STD: [f32; 3] = [0.229, 0.224, 0.225];
87
88    let (h, w, _c) = hwc_shape(image)?;
89
90    // Step 1: Resize shortest side to 256.
91    let (new_h, new_w) = if h < w {
92        (256, (256 * w + h / 2) / h)
93    } else {
94        ((256 * h + w / 2) / w, 256)
95    };
96    let resized = resize_bilinear(image, new_h, new_w)?;
97
98    // Step 2: Center crop to 224×224.
99    let cropped = center_crop(&resized, 224)?;
100
101    // Step 3: Normalize.
102    let normalized = normalize_image(&cropped, &IMAGENET_MEAN, &IMAGENET_STD)?;
103
104    // Step 4: HWC → CHW.
105    hwc_to_chw(&normalized)
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    /// HWC → CHW → HWC must reproduce the original data exactly.
    #[test]
    fn test_hwc_to_chw_roundtrip() {
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let hwc = Tensor::from_vec(vec![2, 4, 3], data.clone()).unwrap();

        let chw = hwc_to_chw(&hwc).unwrap();
        assert_eq!(chw.shape(), &[3, 2, 4]);

        let back = chw_to_hwc(&chw).unwrap();
        assert_eq!(back.shape(), &[2, 4, 3]);
        assert_eq!(back.data(), &data[..]);
    }

    /// Both dimensions larger than `size`: result is `size × size`.
    #[test]
    fn test_center_crop_exact() {
        let img = Tensor::from_vec(vec![10, 10, 3], vec![0.5f32; 300]).unwrap();
        let cropped = center_crop(&img, 6).unwrap();
        assert_eq!(cropped.shape(), &[6, 6, 3]);
    }

    /// Both dimensions smaller than `size`: the image is returned unchanged.
    #[test]
    fn test_center_crop_smaller_than_size() {
        let img = Tensor::from_vec(vec![4, 4, 3], vec![0.5f32; 48]).unwrap();
        let cropped = center_crop(&img, 10).unwrap();
        assert_eq!(cropped.shape(), &[4, 4, 3]);
    }

    /// Only the width exceeds `size`: only the width is cropped.
    #[test]
    fn test_center_crop_one_dimension_larger() {
        let img = Tensor::from_vec(vec![4, 10, 3], vec![0.5f32; 120]).unwrap();
        let cropped = center_crop(&img, 6).unwrap();
        assert_eq!(cropped.shape(), &[4, 6, 3]);
    }

    /// The crop window is centered: a 2×2 crop of a 4×4 image keeps rows 1..3
    /// and columns 1..3.
    #[test]
    fn test_center_crop_values() {
        let data: Vec<f32> = (0..48).map(|i| i as f32).collect();
        let img = Tensor::from_vec(vec![4, 4, 3], data).unwrap();
        let cropped = center_crop(&img, 2).unwrap();
        assert_eq!(cropped.shape(), &[2, 2, 3]);
        // Pixel (y, x) starts at flat index (y * 4 + x) * 3.
        let expected = [
            15.0f32, 16.0, 17.0, 18.0, 19.0, 20.0, // row 1, columns 1..3
            27.0, 28.0, 29.0, 30.0, 31.0, 32.0, // row 2, columns 1..3
        ];
        assert_eq!(cropped.data(), &expected[..]);
    }

    #[test]
    fn test_normalize_image() {
        // 1×1×3 image whose single pixel equals the per-channel mean.
        let img = Tensor::from_vec(vec![1, 1, 3], vec![0.485, 0.456, 0.406]).unwrap();
        let norm = normalize_image(&img, &[0.485, 0.456, 0.406], &[0.229, 0.224, 0.225]).unwrap();
        // After normalization all channels should be ~0.
        for &v in norm.data() {
            assert!(v.abs() < 1e-5, "expected ~0, got {v}");
        }
    }

    #[test]
    fn test_imagenet_preprocess_shape() {
        // 300×400×3 image.
        let img = Tensor::from_vec(vec![300, 400, 3], vec![0.5f32; 300 * 400 * 3]).unwrap();
        let result = imagenet_preprocess(&img).unwrap();
        assert_eq!(result.shape(), &[3, 224, 224]);
    }
}