Skip to main content

torsh_data/
tensor_transforms.rs

1//! Tensor transformation operations for computer vision and image processing
2//!
3//! This module provides specialized transformations for tensor data, particularly
4//! focused on image and computer vision applications. All transforms operate on
5//! multi-dimensional tensors and support common image processing operations.
6//!
7//! # Features
8//!
9//! - **Random augmentations**: RandomHorizontalFlip, RandomCrop for data augmentation
10//! - **Geometric transforms**: Resize, CenterCrop for image preprocessing
11//! - **Multiple interpolation modes**: Nearest, Linear, Bilinear, Bicubic
12//! - **Flexible tensor formats**: Support for 2D (HW) and 3D (CHW) tensors
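//! - **Error handling**: Validation of tensor rank and requested sizes, reported as
//!   `TorshError::InvalidArgument`
//!
//! # Implementation status
//!
//! The flip, crop, and resize transforms currently validate their inputs (and, for the
//! crops, compute the window offset) but return the input tensor unchanged: the
//! underlying flip, slicing, padding, and resampling steps are not implemented yet.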
14
15use crate::transforms::Transform;
16use torsh_core::error::Result;
17use torsh_core::{
18    dtype::{FloatElement, TensorElement},
19    error::TorshError,
20};
21use torsh_tensor::Tensor;
22// ✅ SciRS2 Policy Compliant - Using scirs2_core::random instead of direct rand
23use scirs2_core::random::thread_rng;
24
/// Random horizontal flip transformation
///
/// Mirrors an image-like tensor along its width axis with a configurable
/// probability; a standard data-augmentation step for vision models.
#[derive(Debug, Clone)]
pub struct RandomHorizontalFlip {
    /// Chance, in `[0.0, 1.0]`, that one call to `transform` flips its input.
    prob: f32,
}

impl RandomHorizontalFlip {
    /// Build a flip transform that fires with probability `prob`.
    ///
    /// # Arguments
    /// * `prob` - Likelihood of flipping each input; must lie in `[0.0, 1.0]`
    ///
    /// # Panics
    /// Panics if `prob` falls outside `[0.0, 1.0]` (NaN is rejected as well,
    /// since it is not contained in the range).
    pub fn new(prob: f32) -> Self {
        let valid = (0.0..=1.0).contains(&prob);
        assert!(valid, "Probability must be between 0 and 1");
        RandomHorizontalFlip { prob }
    }
}
50
51impl<T: FloatElement> Transform<Tensor<T>> for RandomHorizontalFlip {
52    type Output = Tensor<T>;
53
54    fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
55        let mut rng = thread_rng(); // SciRS2 POLICY compliant
56
57        let random_val = rng.random::<f32>();
58        if random_val < self.prob {
59            self.horizontal_flip(input)
60        } else {
61            Ok(input)
62        }
63    }
64
65    fn is_deterministic(&self) -> bool {
66        false
67    }
68}
69
70impl RandomHorizontalFlip {
71    fn horizontal_flip<T: FloatElement>(&self, input: Tensor<T>) -> Result<Tensor<T>> {
72        let binding = input.shape();
73        let shape = binding.dims();
74        if shape.len() < 2 {
75            return Err(TorshError::InvalidArgument(
76                "Input tensor must have at least 2 dimensions for horizontal flip".to_string(),
77            ));
78        }
79
80        // For now, return input as is - proper implementation would need tensor indexing operations
81        // In a full implementation, we would reverse the last dimension (width)
82        // This requires advanced tensor operations that aren't implemented yet
83        // Debug: Applying horizontal flip to tensor with shape {:?}", shape
84        Ok(input)
85    }
86}
87
/// Random crop transformation
///
/// Cuts a randomly positioned `(height, width)` window out of the input tensor,
/// which is useful for augmentation and for producing fixed-size inputs from
/// variable-size images.
#[derive(Debug, Clone)]
pub struct RandomCrop {
    /// Requested output window as `(height, width)`.
    size: (usize, usize),
    /// Optional padding, in pixels on every side, applied before cropping.
    padding: Option<usize>,
}

impl RandomCrop {
    /// Build a random crop transform with no padding.
    ///
    /// # Arguments
    /// * `size` - Output window as `(height, width)`
    pub fn new(size: (usize, usize)) -> Self {
        let padding = None;
        RandomCrop { size, padding }
    }

    /// Builder-style setter that enables padding before the crop.
    ///
    /// Calling it again replaces the previously configured padding.
    ///
    /// # Arguments
    /// * `padding` - Number of pixels added to every side
    pub fn with_padding(self, padding: usize) -> Self {
        RandomCrop {
            padding: Some(padding),
            ..self
        }
    }
}
119
120impl<T: TensorElement> Transform<Tensor<T>> for RandomCrop {
121    type Output = Tensor<T>;
122
123    fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
124        let shape = input.shape();
125        let dims = shape.dims();
126
127        // Expect input to be at least 2D (height, width) or 3D (channels, height, width)
128        if dims.len() < 2 {
129            return Err(TorshError::InvalidArgument(
130                "Input tensor must have at least 2 dimensions for random crop".to_string(),
131            ));
132        }
133
134        let (input_height, input_width) = if dims.len() == 2 {
135            (dims[0], dims[1])
136        } else {
137            // Assume CHW format for 3D tensors
138            (dims[1], dims[2])
139        };
140
141        let (crop_height, crop_width) = self.size;
142
143        // If crop size is larger than input, pad the input first
144        if crop_height > input_height || crop_width > input_width {
145            if let Some(padding) = self.padding {
146                // Apply padding if specified
147                let _new_height = input_height.max(crop_height) + 2 * padding;
148                let _new_width = input_width.max(crop_width) + 2 * padding;
149
150                // Create padded tensor (simplified - just return input for now)
151                // In a full implementation, we would create a properly padded tensor
152                // Debug: Applying padding of {} pixels before cropping", padding
153                return Ok(input);
154            } else {
155                return Err(TorshError::InvalidArgument(
156                    format!("Crop size ({crop_height}, {crop_width}) is larger than input size ({input_height}, {input_width}) and no padding specified"),
157                ));
158            }
159        }
160
161        // Calculate random crop position - SciRS2 POLICY compliant
162        let mut rng = thread_rng();
163        let max_y = input_height - crop_height;
164        let max_x = input_width - crop_width;
165
166        let _start_y = if max_y > 0 {
167            rng.gen_range(0..=max_y)
168        } else {
169            0
170        };
171        let _start_x = if max_x > 0 {
172            rng.gen_range(0..=max_x)
173        } else {
174            0
175        };
176
177        // For now, return the input tensor unchanged
178        // In a full implementation, we would extract the cropped region:
179        // - For 2D: input[start_y:start_y+crop_height, start_x:start_x+crop_width]
180        // - For 3D: input[:, start_y:start_y+crop_height, start_x:start_x+crop_width]
181        // This requires advanced tensor slicing operations
182
183        // Debug: Random crop from ({}, {}) to ({}, {})
184        // input_height, input_width, crop_height, crop_width
185
186        Ok(input)
187    }
188
189    fn is_deterministic(&self) -> bool {
190        false
191    }
192}
193
/// Interpolation modes for resizing operations
///
/// The default is [`InterpolationMode::Bilinear`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum InterpolationMode {
    /// Nearest neighbor interpolation
    Nearest,
    /// Linear interpolation
    Linear,
    /// Bilinear interpolation (the default)
    #[default]
    Bilinear,
    /// Bicubic interpolation
    Bicubic,
}
212
213/// Resize transformation
214///
215/// Resizes input tensors to a target size using various interpolation methods.
216/// Commonly used for standardizing input sizes in computer vision pipelines.
217#[derive(Debug, Clone)]
218pub struct Resize {
219    size: (usize, usize),
220    interpolation: InterpolationMode,
221}
222
223impl Resize {
224    /// Create a new resize transform with bilinear interpolation
225    ///
226    /// # Arguments
227    /// * `size` - Target size as (height, width)
228    pub fn new(size: (usize, usize)) -> Self {
229        Self {
230            size,
231            interpolation: InterpolationMode::Bilinear,
232        }
233    }
234
235    /// Set the interpolation mode
236    ///
237    /// # Arguments
238    /// * `mode` - Interpolation method to use
239    pub fn with_interpolation(mut self, mode: InterpolationMode) -> Self {
240        self.interpolation = mode;
241        self
242    }
243}
244
245impl<T: FloatElement> Transform<Tensor<T>> for Resize {
246    type Output = Tensor<T>;
247
248    fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
249        let shape = input.shape();
250        let dims = shape.dims();
251
252        // Expect input to be at least 2D (height, width) or 3D (channels, height, width)
253        if dims.len() < 2 {
254            return Err(TorshError::InvalidArgument(
255                "Input tensor must have at least 2 dimensions for resize".to_string(),
256            ));
257        }
258
259        let (input_height, input_width) = if dims.len() == 2 {
260            (dims[0], dims[1])
261        } else {
262            // Assume CHW format for 3D tensors
263            (dims[1], dims[2])
264        };
265
266        let (target_height, target_width) = self.size;
267
268        // If target size matches input size, no resize needed
269        if input_height == target_height && input_width == target_width {
270            return Ok(input);
271        }
272
273        // For now, return the input tensor unchanged
274        // In a full implementation, we would apply the specified interpolation:
275        // - Nearest: select nearest neighbor pixels
276        // - Linear/Bilinear: interpolate between neighboring pixels
277        // - Bicubic: use cubic interpolation with 4x4 pixel neighborhoods
278        //
279        // The implementation would:
280        // 1. Calculate scale factors: scale_y = target_height / input_height
281        // 2. For each output pixel (y, x):
282        //    - Map to input coordinates: (y/scale_y, x/scale_x)
283        //    - Apply interpolation based on self.interpolation mode
284        //    - Set output pixel value
285        // 3. Handle edge cases and boundary conditions
286
287        match self.interpolation {
288            InterpolationMode::Nearest => {
289                // Debug: Applying nearest neighbor resize from ({}, {}) to ({}, {})
290                // input_height, input_width, target_height, target_width
291                Ok(input)
292            }
293            InterpolationMode::Linear | InterpolationMode::Bilinear => {
294                // Debug: Applying bilinear resize from ({}, {}) to ({}, {})
295                // input_height, input_width, target_height, target_width
296                Ok(input)
297            }
298            InterpolationMode::Bicubic => {
299                // Debug: Applying bicubic resize from ({}, {}) to ({}, {})
300                // input_height, input_width, target_height, target_width
301                Ok(input)
302            }
303        }
304    }
305
306    fn is_deterministic(&self) -> bool {
307        true
308    }
309}
310
/// Center crop transformation
///
/// Takes a `(height, width)` window from the middle of the input tensor, giving
/// a consistently positioned view of the central region.
#[derive(Debug, Clone)]
pub struct CenterCrop {
    /// Output window as `(height, width)`.
    size: (usize, usize),
}

impl CenterCrop {
    /// Build a center crop transform.
    ///
    /// # Arguments
    /// * `size` - Output window as `(height, width)`
    pub fn new(size: (usize, usize)) -> Self {
        let (height, width) = size;
        CenterCrop {
            size: (height, width),
        }
    }
}
329
330impl<T: TensorElement> Transform<Tensor<T>> for CenterCrop {
331    type Output = Tensor<T>;
332
333    fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
334        let shape = input.shape();
335        let dims = shape.dims();
336
337        // Expect input to be at least 2D (height, width) or 3D (channels, height, width)
338        if dims.len() < 2 {
339            return Err(TorshError::InvalidArgument(
340                "Input tensor must have at least 2 dimensions for center crop".to_string(),
341            ));
342        }
343
344        let (input_height, input_width) = if dims.len() == 2 {
345            (dims[0], dims[1])
346        } else {
347            // Assume CHW format for 3D tensors
348            (dims[1], dims[2])
349        };
350
351        let (crop_height, crop_width) = self.size;
352
353        // Check if crop size is larger than input
354        if crop_height > input_height || crop_width > input_width {
355            return Err(TorshError::InvalidArgument(
356                format!("Crop size ({crop_height}, {crop_width}) is larger than input size ({input_height}, {input_width})"),
357            ));
358        }
359
360        // Calculate center crop position
361        let _start_y = (input_height - crop_height) / 2;
362        let _start_x = (input_width - crop_width) / 2;
363
364        // For now, return the input tensor unchanged
365        // In a full implementation, we would extract the center crop region:
366        // - For 2D: input[start_y:start_y+crop_height, start_x:start_x+crop_width]
367        // - For 3D: input[:, start_y:start_y+crop_height, start_x:start_x+crop_width]
368        // This requires advanced tensor slicing operations
369
370        // The implementation would involve:
371        // 1. Creating a new tensor with the crop dimensions
372        // 2. Copying the appropriate region from the input tensor
373        // 3. For 2D tensors: new_tensor[y, x] = input[start_y + y, start_x + x]
374        // 4. For 3D tensors: new_tensor[c, y, x] = input[c, start_y + y, start_x + x]
375
376        // Debug: Center crop from ({}, {}) to ({}, {})
377        // input_height, input_width, crop_height, crop_width
378
379        Ok(input)
380    }
381
382    fn is_deterministic(&self) -> bool {
383        true
384    }
385}
386
387/// Convenience function to create a random horizontal flip transform
388pub fn random_horizontal_flip(prob: f32) -> RandomHorizontalFlip {
389    RandomHorizontalFlip::new(prob)
390}
391
392/// Convenience function to create a random crop transform
393pub fn random_crop(size: (usize, usize)) -> RandomCrop {
394    RandomCrop::new(size)
395}
396
397/// Convenience function to create a resize transform
398pub fn resize(size: (usize, usize)) -> Resize {
399    Resize::new(size)
400}
401
402/// Convenience function to create a center crop transform
403pub fn center_crop(size: (usize, usize)) -> CenterCrop {
404    CenterCrop::new(size)
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;

    /// 2x2 single-channel tensor.
    fn mock_tensor_2d() -> Tensor<f32> {
        Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu).unwrap()
    }

    /// Two channels with a 2x2 spatial extent (CHW layout).
    fn mock_tensor_3d() -> Tensor<f32> {
        Tensor::from_data(
            vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            vec![2, 2, 2], // 2 channels, 2x2 spatial
            DeviceType::Cpu,
        )
        .unwrap()
    }

    #[test]
    fn test_random_horizontal_flip_creation() {
        let flip = RandomHorizontalFlip::new(0.5);
        let as_dyn: &dyn Transform<Tensor<f32>, Output = Tensor<f32>> = &flip;
        assert!(!as_dyn.is_deterministic());
    }

    #[test]
    #[should_panic(expected = "Probability must be between 0 and 1")]
    fn test_random_horizontal_flip_invalid_prob() {
        RandomHorizontalFlip::new(1.5);
    }

    #[test]
    fn test_random_crop_creation() {
        let crop = RandomCrop::new((224, 224));
        let as_dyn: &dyn Transform<Tensor<f32>, Output = Tensor<f32>> = &crop;
        assert!(!as_dyn.is_deterministic());
    }

    #[test]
    fn test_random_crop_with_padding() {
        let crop = RandomCrop::new((224, 224)).with_padding(10);
        let as_dyn: &dyn Transform<Tensor<f32>, Output = Tensor<f32>> = &crop;
        assert!(!as_dyn.is_deterministic());
    }

    #[test]
    fn test_resize_creation() {
        let resize_transform = Resize::new((224, 224));
        let as_dyn: &dyn Transform<Tensor<f32>, Output = Tensor<f32>> = &resize_transform;
        assert!(as_dyn.is_deterministic());
    }

    #[test]
    fn test_resize_with_interpolation() {
        let resize_transform =
            Resize::new((224, 224)).with_interpolation(InterpolationMode::Nearest);
        let as_dyn: &dyn Transform<Tensor<f32>, Output = Tensor<f32>> = &resize_transform;
        assert!(as_dyn.is_deterministic());
    }

    #[test]
    fn test_center_crop_creation() {
        let crop = CenterCrop::new((224, 224));
        let as_dyn: &dyn Transform<Tensor<f32>, Output = Tensor<f32>> = &crop;
        assert!(as_dyn.is_deterministic());
    }

    #[test]
    fn test_interpolation_mode_default() {
        assert_eq!(InterpolationMode::default(), InterpolationMode::Bilinear);
    }

    #[test]
    fn test_tensor_transforms_2d() {
        let tensor = mock_tensor_2d();

        let flip = RandomHorizontalFlip::new(1.0); // Always flip
        assert!(flip.transform(tensor.clone()).is_ok());

        let crop = CenterCrop::new((1, 1));
        assert!(crop.transform(tensor.clone()).is_ok());

        let resize_transform = Resize::new((4, 4));
        assert!(resize_transform.transform(tensor).is_ok());
    }

    #[test]
    fn test_tensor_transforms_3d() {
        let tensor = mock_tensor_3d();

        let flip = RandomHorizontalFlip::new(0.0); // Never flip
        assert!(flip.transform(tensor.clone()).is_ok());

        let crop = CenterCrop::new((1, 1));
        assert!(crop.transform(tensor.clone()).is_ok());

        let resize_transform = Resize::new((4, 4));
        assert!(resize_transform.transform(tensor).is_ok());
    }

    #[test]
    fn test_random_crop_fitting_input() {
        let tensor = mock_tensor_2d(); // 2x2 tensor

        // Strictly smaller window: a random offset is drawn.
        assert!(RandomCrop::new((1, 1)).transform(tensor.clone()).is_ok());
        // Exactly the input size: the only valid offset is (0, 0).
        assert!(RandomCrop::new((2, 2)).transform(tensor).is_ok());
    }

    #[test]
    fn test_convenience_functions() {
        let _flip = random_horizontal_flip(0.5);
        let _crop = random_crop((224, 224));
        let _resize = resize((256, 256));
        let _center = center_crop((224, 224));
    }

    #[test]
    fn test_invalid_tensor_dimensions() {
        let tensor_1d = Tensor::from_data(vec![1.0f32, 2.0], vec![2], DeviceType::Cpu).unwrap();

        let flip = RandomHorizontalFlip::new(1.0);
        assert!(flip.transform(tensor_1d.clone()).is_err());

        let crop = CenterCrop::new((1, 1));
        assert!(crop.transform(tensor_1d.clone()).is_err());

        let random_crop = RandomCrop::new((1, 1));
        assert!(random_crop.transform(tensor_1d.clone()).is_err());

        let resize_transform = Resize::new((4, 4));
        assert!(resize_transform.transform(tensor_1d).is_err());
    }

    #[test]
    fn test_crop_size_validation() {
        let tensor = mock_tensor_2d(); // 2x2 tensor

        let crop = CenterCrop::new((3, 3)); // Larger than input
        assert!(crop.transform(tensor.clone()).is_err());

        let random_crop = RandomCrop::new((3, 3)); // Larger than input, no padding
        assert!(random_crop.transform(tensor).is_err());
    }
}