//! Image/tensor conversion and augmentation transforms (torsh_data/vision/image/transforms.rs).

1use crate::transforms::Transform;
2#[cfg(feature = "image-support")]
3use image::{DynamicImage, GenericImageView, ImageBuffer};
4use scirs2_core::RngExt;
5use torsh_core::error::{Result, TorshError};
6use torsh_tensor::Tensor;
7
8/// Transform to convert image to tensor
9pub struct ImageToTensor;
10
11impl Transform<DynamicImage> for ImageToTensor {
12    type Output = Tensor<f32>;
13
14    fn transform(&self, input: DynamicImage) -> Result<Self::Output> {
15        #[cfg(feature = "image-support")]
16        {
17            let rgb_image = input.to_rgb8();
18            let (width, height) = rgb_image.dimensions();
19
20            // Convert to CHW format (channels, height, width)
21            let mut data = Vec::with_capacity((width * height * 3) as usize);
22
23            // Extract channels separately
24            for channel in 0..3 {
25                for y in 0..height {
26                    for x in 0..width {
27                        let pixel = rgb_image.get_pixel(x, y);
28                        let value = pixel[channel] as f32 / 255.0;
29                        data.push(value);
30                    }
31                }
32            }
33
34            Tensor::from_data(
35                data,
36                vec![3, height as usize, width as usize],
37                torsh_core::device::DeviceType::Cpu,
38            )
39        }
40
41        #[cfg(not(feature = "image-support"))]
42        {
43            Err(TorshError::UnsupportedOperation {
44                op: "image to tensor conversion".to_string(),
45                dtype: "DynamicImage".to_string(),
46            })
47        }
48    }
49}
50
51/// Transform to convert tensor to image
52pub struct TensorToImage;
53
54impl Transform<Tensor<f32>> for TensorToImage {
55    type Output = DynamicImage;
56
57    fn transform(&self, input: Tensor<f32>) -> Result<Self::Output> {
58        #[cfg(feature = "image-support")]
59        {
60            let shape = input.shape();
61            if shape.ndim() != 3 {
62                return Err(TorshError::InvalidShape(
63                    "Expected 3D tensor (C, H, W)".to_string(),
64                ));
65            }
66
67            let dims = shape.dims();
68            let (channels, height, width) = (dims[0], dims[1], dims[2]);
69
70            if channels != 3 {
71                return Err(TorshError::InvalidShape(
72                    "Expected 3 channels for RGB image".to_string(),
73                ));
74            }
75
76            let data = input.to_vec()?;
77            let mut img_data = Vec::with_capacity(width * height * 3);
78
79            // Convert from CHW to HWC format
80            for y in 0..height {
81                for x in 0..width {
82                    for c in 0..3 {
83                        let idx = c * height * width + y * width + x;
84                        let value = (data[idx] * 255.0).clamp(0.0, 255.0) as u8;
85                        img_data.push(value);
86                    }
87                }
88            }
89
90            let img_buffer = ImageBuffer::from_raw(width as u32, height as u32, img_data)
91                .ok_or_else(|| {
92                    TorshError::InvalidArgument("Failed to create image buffer".to_string())
93                })?;
94
95            Ok(DynamicImage::ImageRgb8(img_buffer))
96        }
97
98        #[cfg(not(feature = "image-support"))]
99        {
100            Err(TorshError::UnsupportedOperation {
101                op: "tensor to image conversion".to_string(),
102                dtype: "Tensor<f32>".to_string(),
103            })
104        }
105    }
106}
107
108/// Compose multiple transforms
109pub struct Compose<T> {
110    transforms: Vec<Box<dyn Transform<T, Output = T>>>,
111}
112
113impl<T: 'static> Compose<T> {
114    pub fn new() -> Self {
115        Self {
116            transforms: Vec::new(),
117        }
118    }
119
120    pub fn add_transform<F>(mut self, transform: F) -> Self
121    where
122        F: Transform<T, Output = T> + 'static,
123    {
124        self.transforms.push(Box::new(transform));
125        self
126    }
127
128    pub fn add_boxed(mut self, transform: Box<dyn Transform<T, Output = T>>) -> Self {
129        self.transforms.push(transform);
130        self
131    }
132}
133
134impl<T> Transform<T> for Compose<T> {
135    type Output = T;
136
137    fn transform(&self, mut input: T) -> Result<Self::Output> {
138        for transform in &self.transforms {
139            input = transform.transform(input)?;
140        }
141        Ok(input)
142    }
143}
144
145impl<T: 'static> Default for Compose<T> {
146    fn default() -> Self {
147        Self::new()
148    }
149}
150
151/// Random horizontal flip
152pub struct RandomHorizontalFlip {
153    prob: f32,
154}
155
156impl RandomHorizontalFlip {
157    pub fn new(prob: f32) -> Self {
158        Self { prob }
159    }
160}
161
162impl Transform<DynamicImage> for RandomHorizontalFlip {
163    type Output = DynamicImage;
164
165    fn transform(&self, input: DynamicImage) -> Result<Self::Output> {
166        #[cfg(feature = "image-support")]
167        {
168            // ✅ SciRS2 Policy Compliant - Using scirs2_core::random instead of direct rand
169            #[allow(unused_imports)] // Rng trait needed for random() method
170            use scirs2_core::random::{Random, Rng};
171            let mut rng = Random::seed(0);
172            if rng.random::<f32>() < self.prob {
173                Ok(input.fliph())
174            } else {
175                Ok(input)
176            }
177        }
178
179        #[cfg(not(feature = "image-support"))]
180        {
181            Err(TorshError::UnsupportedOperation {
182                op: "random horizontal flip".to_string(),
183                dtype: "DynamicImage".to_string(),
184            })
185        }
186    }
187}
188
189/// Random vertical flip
190pub struct RandomVerticalFlip {
191    prob: f32,
192}
193
194impl RandomVerticalFlip {
195    pub fn new(prob: f32) -> Self {
196        Self { prob }
197    }
198}
199
200impl Transform<DynamicImage> for RandomVerticalFlip {
201    type Output = DynamicImage;
202
203    fn transform(&self, input: DynamicImage) -> Result<Self::Output> {
204        #[cfg(feature = "image-support")]
205        {
206            // ✅ SciRS2 Policy Compliant - Using scirs2_core::random instead of direct rand
207            #[allow(unused_imports)] // Rng trait needed for random() method
208            use scirs2_core::random::{Random, Rng};
209            let mut rng = Random::seed(0);
210            if rng.random::<f32>() < self.prob {
211                Ok(input.flipv())
212            } else {
213                Ok(input)
214            }
215        }
216
217        #[cfg(not(feature = "image-support"))]
218        {
219            Err(TorshError::UnsupportedOperation {
220                op: "random vertical flip".to_string(),
221                dtype: "DynamicImage".to_string(),
222            })
223        }
224    }
225}
226
227/// Random rotation
228pub struct RandomRotation {
229    degrees: f32,
230}
231
232impl RandomRotation {
233    pub fn new(degrees: f32) -> Self {
234        Self { degrees }
235    }
236}
237
238impl Transform<DynamicImage> for RandomRotation {
239    type Output = DynamicImage;
240
241    fn transform(&self, input: DynamicImage) -> Result<Self::Output> {
242        #[cfg(all(feature = "image-support", feature = "imageproc"))]
243        {
244            // ✅ SciRS2 Policy Compliant - Using scirs2_core::random instead of direct rand
245            #[allow(unused_imports)] // Rng trait needed for random() method
246            use scirs2_core::random::{Random, Rng};
247            let mut rng = Random::seed(0);
248            let angle_deg = rng.gen_range(-self.degrees..=self.degrees);
249            let angle_rad = angle_deg.to_radians();
250
251            // Convert to RGB8 for processing
252            let rgb_image = input.to_rgb8();
253
254            // Perform rotation using imageproc
255            // Use imageproc's rotation function
256            let rotated = imageproc::geometric_transformations::rotate_about_center(
257                &rgb_image,
258                angle_rad,
259                imageproc::geometric_transformations::Interpolation::Bilinear,
260                image::Rgb([0u8, 0u8, 0u8]), // Black background
261            );
262
263            Ok(DynamicImage::ImageRgb8(rotated))
264        }
265
266        #[cfg(all(feature = "image-support", not(feature = "imageproc")))]
267        {
268            // ✅ SciRS2 Policy Compliant - Using scirs2_core::random instead of direct rand
269            #[allow(unused_imports)] // Rng trait needed for random() method
270            use scirs2_core::random::{Random, Rng};
271            let mut rng = Random::seed(0);
272            let _angle = rng.gen_range(-self.degrees..=self.degrees);
273            // imageproc not available, return input unchanged
274            // In production, you might want to log a warning here
275            Ok(input)
276        }
277
278        #[cfg(not(feature = "image-support"))]
279        {
280            Err(TorshError::UnsupportedOperation {
281                op: "random rotation".to_string(),
282                dtype: "DynamicImage".to_string(),
283            })
284        }
285    }
286}
287
288/// Common vision transforms
289pub mod transforms {
290    use super::*;
291    use crate::transforms::Transform;
292
293    /// Resize image
294    pub struct Resize {
295        size: (u32, u32),
296    }
297
298    impl Resize {
299        pub fn new(size: (u32, u32)) -> Self {
300            Self { size }
301        }
302    }
303
304    impl Transform<DynamicImage> for Resize {
305        type Output = DynamicImage;
306
307        fn transform(&self, input: DynamicImage) -> Result<Self::Output> {
308            #[cfg(feature = "image-support")]
309            {
310                Ok(input.resize_exact(
311                    self.size.0,
312                    self.size.1,
313                    image::imageops::FilterType::Lanczos3,
314                ))
315            }
316
317            #[cfg(not(feature = "image-support"))]
318            {
319                Err(TorshError::UnsupportedOperation {
320                    op: "image resize".to_string(),
321                    dtype: "DynamicImage".to_string(),
322                })
323            }
324        }
325    }
326
327    /// Center crop image
328    pub struct CenterCrop {
329        size: (u32, u32),
330    }
331
332    impl CenterCrop {
333        pub fn new(size: (u32, u32)) -> Self {
334            Self { size }
335        }
336    }
337
338    impl Transform<DynamicImage> for CenterCrop {
339        type Output = DynamicImage;
340
341        fn transform(&self, input: DynamicImage) -> Result<Self::Output> {
342            #[cfg(feature = "image-support")]
343            {
344                let (width, height) = input.dimensions();
345                let (crop_width, crop_height) = self.size;
346
347                if crop_width > width || crop_height > height {
348                    return Err(TorshError::InvalidArgument(
349                        "Crop size cannot be larger than image size".to_string(),
350                    ));
351                }
352
353                let x = (width - crop_width) / 2;
354                let y = (height - crop_height) / 2;
355
356                Ok(input.crop_imm(x, y, crop_width, crop_height))
357            }
358
359            #[cfg(not(feature = "image-support"))]
360            {
361                Err(TorshError::UnsupportedOperation {
362                    op: "image center crop".to_string(),
363                    dtype: "DynamicImage".to_string(),
364                })
365            }
366        }
367    }
368
369    /// Normalize image values
370    pub struct Normalize {
371        mean: [f32; 3],
372        std: [f32; 3],
373    }
374
375    impl Normalize {
376        pub fn new(mean: [f32; 3], std: [f32; 3]) -> Self {
377            Self { mean, std }
378        }
379
380        /// ImageNet normalization
381        pub fn imagenet() -> Self {
382            Self::new([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
383        }
384    }
385
386    impl Transform<Tensor<f32>> for Normalize {
387        type Output = Tensor<f32>;
388
389        fn transform(&self, input: Tensor<f32>) -> Result<Self::Output> {
390            // Apply ImageNet-style normalization per channel
391            // Assumes input tensor is in CHW format (channels, height, width)
392            let shape_ref = input.shape();
393            let shape = shape_ref.dims();
394
395            if shape.len() != 3 {
396                return Err(TorshError::InvalidShape(format!(
397                    "Expected 3D tensor (C, H, W), got shape {shape:?}"
398                )));
399            }
400
401            let channels = shape[0];
402            if channels != 3 {
403                return Err(TorshError::InvalidShape(format!(
404                    "Expected 3 channels for RGB image, got {channels}"
405                )));
406            }
407
408            // Get tensor data and apply normalization
409            let mut data = input.to_vec()?;
410            let height = shape[1];
411            let width = shape[2];
412            let channel_size = height * width;
413
414            // Apply per-channel normalization: (pixel - mean) / std
415            for c in 0..3 {
416                let channel_start = c * channel_size;
417                let channel_end = channel_start + channel_size;
418                let mean = self.mean[c];
419                let std = self.std[c];
420
421                for pixel in &mut data[channel_start..channel_end] {
422                    *pixel = (*pixel - mean) / std;
423                }
424            }
425
426            // Create normalized tensor with same shape and device
427            Tensor::from_data(data, shape.to_vec(), input.device())
428        }
429    }
430}