1use yscv_tensor::Tensor;
2
3use super::super::ImgProcError;
4use super::super::shape::hwc_shape;
5use super::resize::resize_bilinear;
6
7pub fn hwc_to_chw(image: &Tensor) -> Result<Tensor, ImgProcError> {
9 let (h, w, c) = hwc_shape(image)?;
10 let src = image.data();
11 let mut out = vec![0.0f32; c * h * w];
12 for y in 0..h {
13 for x in 0..w {
14 let hwc_base = (y * w + x) * c;
15 for ch in 0..c {
16 out[ch * h * w + y * w + x] = src[hwc_base + ch];
17 }
18 }
19 }
20 Tensor::from_vec(vec![c, h, w], out).map_err(Into::into)
21}
22
23pub fn chw_to_hwc(image: &Tensor) -> Result<Tensor, ImgProcError> {
25 let shape = image.shape();
26 if shape.len() != 3 {
27 return Err(ImgProcError::InvalidImageShape {
28 expected_rank: 3,
29 got: shape.to_vec(),
30 });
31 }
32 let (c, h, w) = (shape[0], shape[1], shape[2]);
33 let src = image.data();
34 let mut out = vec![0.0f32; h * w * c];
35 for y in 0..h {
36 for x in 0..w {
37 let hwc_base = (y * w + x) * c;
38 for ch in 0..c {
39 out[hwc_base + ch] = src[ch * h * w + y * w + x];
40 }
41 }
42 }
43 Tensor::from_vec(vec![h, w, c], out).map_err(Into::into)
44}
45
46pub fn normalize_image(image: &Tensor, mean: &[f32], std: &[f32]) -> Result<Tensor, ImgProcError> {
48 super::normalize::normalize(image, mean, std)
50}
51
52pub fn center_crop(image: &Tensor, size: usize) -> Result<Tensor, ImgProcError> {
56 let (h, w, c) = hwc_shape(image)?;
57 if h <= size && w <= size {
58 return Ok(image.clone());
59 }
60 let crop_h = size.min(h);
61 let crop_w = size.min(w);
62 let y_off = (h - crop_h) / 2;
63 let x_off = (w - crop_w) / 2;
64
65 let src = image.data();
66 let mut out = vec![0.0f32; crop_h * crop_w * c];
67 for y in 0..crop_h {
68 let src_row = ((y_off + y) * w + x_off) * c;
69 let dst_row = y * crop_w * c;
70 out[dst_row..dst_row + crop_w * c].copy_from_slice(&src[src_row..src_row + crop_w * c]);
71 }
72 Tensor::from_vec(vec![crop_h, crop_w, c], out).map_err(Into::into)
73}
74
75pub fn imagenet_preprocess(image: &Tensor) -> Result<Tensor, ImgProcError> {
85 const IMAGENET_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
86 const IMAGENET_STD: [f32; 3] = [0.229, 0.224, 0.225];
87
88 let (h, w, _c) = hwc_shape(image)?;
89
90 let (new_h, new_w) = if h < w {
92 (256, (256 * w + h / 2) / h)
93 } else {
94 ((256 * h + w / 2) / w, 256)
95 };
96 let resized = resize_bilinear(image, new_h, new_w)?;
97
98 let cropped = center_crop(&resized, 224)?;
100
101 let normalized = normalize_image(&cropped, &IMAGENET_MEAN, &IMAGENET_STD)?;
103
104 hwc_to_chw(&normalized)
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hwc_to_chw_roundtrip() {
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let hwc = Tensor::from_vec(vec![2, 4, 3], data.clone()).unwrap();

        let chw = hwc_to_chw(&hwc).unwrap();
        assert_eq!(chw.shape(), &[3, 2, 4]);
        // Pin the actual layout: the channel-0 plane holds every third interleaved value,
        // in row-major pixel order. A round trip alone could hide a consistent mistake.
        assert_eq!(
            &chw.data()[..8],
            &[0.0f32, 3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0]
        );

        let back = chw_to_hwc(&chw).unwrap();
        assert_eq!(back.shape(), &[2, 4, 3]);
        assert_eq!(back.data(), &data[..]);
    }

    #[test]
    fn test_chw_to_hwc_rejects_wrong_rank() {
        let flat = Tensor::from_vec(vec![2, 3], vec![0.0f32; 6]).unwrap();
        assert!(chw_to_hwc(&flat).is_err());
    }

    #[test]
    fn test_center_crop_exact() {
        let img = Tensor::from_vec(vec![10, 10, 3], vec![0.5f32; 300]).unwrap();
        let cropped = center_crop(&img, 6).unwrap();
        assert_eq!(cropped.shape(), &[6, 6, 3]);
    }

    #[test]
    fn test_center_crop_takes_center_pixels() {
        // 4x4 single-channel image holding 0..16 in row-major order.
        let data: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let img = Tensor::from_vec(vec![4, 4, 1], data).unwrap();
        let cropped = center_crop(&img, 2).unwrap();
        assert_eq!(cropped.shape(), &[2, 2, 1]);
        assert_eq!(cropped.data(), &[5.0f32, 6.0, 9.0, 10.0]);
    }

    #[test]
    fn test_center_crop_one_axis_smaller() {
        // Height already fits, so only the width is cropped, starting at column 1.
        let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
        let img = Tensor::from_vec(vec![2, 6, 1], data).unwrap();
        let cropped = center_crop(&img, 4).unwrap();
        assert_eq!(cropped.shape(), &[2, 4, 1]);
        assert_eq!(
            cropped.data(),
            &[1.0f32, 2.0, 3.0, 4.0, 7.0, 8.0, 9.0, 10.0]
        );
    }

    #[test]
    fn test_center_crop_smaller_than_size() {
        let img = Tensor::from_vec(vec![4, 4, 3], vec![0.5f32; 48]).unwrap();
        let cropped = center_crop(&img, 10).unwrap();
        assert_eq!(cropped.shape(), &[4, 4, 3]);
    }

    #[test]
    fn test_normalize_image() {
        let img = Tensor::from_vec(vec![1, 1, 3], vec![0.485, 0.456, 0.406]).unwrap();
        let norm = normalize_image(&img, &[0.485, 0.456, 0.406], &[0.229, 0.224, 0.225]).unwrap();
        for &v in norm.data() {
            assert!(v.abs() < 1e-5, "expected ~0, got {v}");
        }
    }

    #[test]
    fn test_imagenet_preprocess_shape() {
        let img = Tensor::from_vec(vec![300, 400, 3], vec![0.5f32; 300 * 400 * 3]).unwrap();
        let result = imagenet_preprocess(&img).unwrap();
        assert_eq!(result.shape(), &[3, 224, 224]);
    }

    #[test]
    fn test_imagenet_preprocess_shape_portrait() {
        // Exercises the `h >= w` branch of the shorter-side resize.
        let img = Tensor::from_vec(vec![400, 300, 3], vec![0.5f32; 400 * 300 * 3]).unwrap();
        let result = imagenet_preprocess(&img).unwrap();
        assert_eq!(result.shape(), &[3, 224, 224]);
    }
}