// torsh_functional/image.rs

//! Image Processing Operations
//!
//! This module provides comprehensive image processing operations for computer vision
//! and deep learning applications. All operations are designed to work with tensors
//! in standard image formats (CHW or NCHW).
//!
//! # Mathematical Foundation
//!
//! ## Image Representation
//!
//! Images are represented as tensors with dimensions:
//! - **2D**: `[H, W]` - Grayscale images
//! - **3D**: `[C, H, W]` - Multi-channel images (RGB, etc.)
//! - **4D**: `[N, C, H, W]` - Batched images
//!
//! where:
//! - `N` = batch size
//! - `C` = number of channels
//! - `H` = height (rows)
//! - `W` = width (columns)
//!
//! ## Interpolation Methods
//!
//! ### Nearest Neighbor
//! ```text
//! I_out(x, y) = I_in(round(x * scale_x), round(y * scale_y))
//! ```
//! - **Complexity**: O(1) per pixel
//! - **Quality**: Blocky, preserves sharp edges
//! - **Use case**: Fast resizing, pixel art
//!
//! ### Bilinear Interpolation
//! ```text
//! I_out(x, y) = Σᵢ Σⱼ w(i,j) * I_in(x+i, y+j)
//!
//! where w(i,j) = (1 - dx) * (1 - dy)  for i=0, j=0
//!              = dx * (1 - dy)        for i=1, j=0
//!              = (1 - dx) * dy        for i=0, j=1
//!              = dx * dy              for i=1, j=1
//! ```
//! - **Complexity**: O(4) per pixel (4 neighbor lookups)
//! - **Quality**: Smooth, good for photos
//! - **Use case**: Standard resizing, transformations
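//!
//! As a minimal standalone sketch (plain `f32` slice in row-major `[H, W]` order;
//! the helper name is illustrative, not part of this module's API), one bilinear
//! lookup is:
//! ```rust
//! /// Sample `img` (h x w, row-major) at fractional coordinates (x, y).
//! fn bilinear_sample(img: &[f32], h: usize, w: usize, x: f32, y: f32) -> f32 {
//!     let x0 = (x.floor().max(0.0) as usize).min(w - 1);
//!     let y0 = (y.floor().max(0.0) as usize).min(h - 1);
//!     let x1 = (x0 + 1).min(w - 1);
//!     let y1 = (y0 + 1).min(h - 1);
//!     let dx = x - x0 as f32;
//!     let dy = y - y0 as f32;
//!     // Weighted sum of the 2x2 neighborhood, matching w(i,j) above
//!     (1.0 - dx) * (1.0 - dy) * img[y0 * w + x0]
//!         + dx * (1.0 - dy) * img[y0 * w + x1]
//!         + (1.0 - dx) * dy * img[y1 * w + x0]
//!         + dx * dy * img[y1 * w + x1]
//! }
//! ```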
//!
//! ### Bicubic Interpolation
//! ```text
//! I_out(x, y) = Σᵢ₌₀³ Σⱼ₌₀³ w(i,j) * I_in(x+i-1, y+j-1)
//!
//! where w(i,j) uses the cubic kernel:
//! k(t) = { (a+2)|t|³ - (a+3)|t|² + 1           for |t| ≤ 1
//!        { a|t|³ - 5a|t|² + 8a|t| - 4a         for 1 < |t| < 2
//!        { 0                                    for |t| ≥ 2
//!
//! with a = -0.5 (most common)
//! ```
//! - **Complexity**: O(16) per pixel (4×4 neighborhood)
//! - **Quality**: High quality, smooth gradients
//! - **Use case**: High-quality resizing, professional graphics
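//!
//! The cubic kernel itself is straightforward to write out (standalone sketch
//! with a = -0.5, i.e. the Catmull-Rom variant; not an exported function):
//! ```rust
//! /// Keys cubic kernel with a = -0.5; `t` is the distance to the sample point.
//! fn cubic_kernel(t: f32) -> f32 {
//!     let a = -0.5f32;
//!     let t = t.abs();
//!     if t <= 1.0 {
//!         (a + 2.0) * t.powi(3) - (a + 3.0) * t.powi(2) + 1.0
//!     } else if t < 2.0 {
//!         a * t.powi(3) - 5.0 * a * t.powi(2) + 8.0 * a * t - 4.0 * a
//!     } else {
//!         0.0
//!     }
//! }
//! ```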
//!
//! ## Color Space Conversions
//!
//! ### RGB to Grayscale
//! ```text
//! Y = 0.299 * R + 0.587 * G + 0.114 * B  (luminosity method)
//! ```
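//!
//! In code, the weighting is a one-liner (sketch; not an exported function):
//! ```rust
//! /// ITU-R BT.601 luma weights, matching the formula above.
//! fn rgb_to_gray(r: f32, g: f32, b: f32) -> f32 {
//!     0.299 * r + 0.587 * g + 0.114 * b
//! }
//! ```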
//!
//! ### RGB to HSV
//! ```text
//! V = max(R, G, B)
//! S = (V - min(R, G, B)) / V  if V ≠ 0, else 0
//! H = { 60° × (((G - B) / (V - min)) mod 6)  if V = R
//!     { 60° × ((B - R) / (V - min) + 2)      if V = G
//!     { 60° × ((R - G) / (V - min) + 4)      if V = B
//! ```
//!
//! ## Filtering Operations
//!
//! ### Gaussian Filter
//! ```text
//! G(x, y) = (1 / (2πσ²)) * exp(-(x² + y²) / (2σ²))
//!
//! Kernel size: typically ⌈6σ⌉ to capture 99.7% of the distribution
//! ```
//! - **Purpose**: Blur, noise reduction, anti-aliasing
//! - **Properties**: Linear, separable, isotropic
//!
//! ### Sobel Filter
//! ```text
//! Gₓ = [-1  0  1]      Gᵧ = [-1 -2 -1]
//!      [-2  0  2]           [ 0  0  0]
//!      [-1  0  1]           [ 1  2  1]
//!
//! Magnitude = √(Gₓ² + Gᵧ²)
//! Direction = atan2(Gᵧ, Gₓ)
//! ```
//! - **Purpose**: Edge detection, gradient computation
//! - **Properties**: First derivative approximation
//!
//! ### Laplacian Filter
//! ```text
//! L = [ 0  1  0]      or    L = [ 1  1  1]
//!     [ 1 -4  1]            [ 1 -8  1]
//!     [ 0  1  0]            [ 1  1  1]
//!
//! ∇²I ≈ ∂²I/∂x² + ∂²I/∂y²
//! ```
//! - **Purpose**: Edge detection, feature enhancement
//! - **Properties**: Second derivative, rotation invariant
//!
//! ## Morphological Operations
//!
//! ### Erosion
//! ```text
//! (A ⊖ B)(x, y) = min{A(x+i, y+j) | (i,j) ∈ B}
//! ```
//! - **Effect**: Shrinks bright regions, removes small objects
//!
//! ### Dilation
//! ```text
//! (A ⊕ B)(x, y) = max{A(x+i, y+j) | (i,j) ∈ B}
//! ```
//! - **Effect**: Expands bright regions, fills holes
//!
//! ### Opening
//! ```text
//! A ∘ B = (A ⊖ B) ⊕ B
//! ```
//! - **Effect**: Removes small objects while preserving shape
//!
//! ### Closing
//! ```text
//! A • B = (A ⊕ B) ⊖ B
//! ```
//! - **Effect**: Fills small holes while preserving shape
//!
//! # Performance Characteristics
//!
//! | Operation | Complexity | Memory | Notes |
//! |-----------|------------|--------|-------|
//! | Resize (nearest) | O(N×C×H×W) | O(output) | Fastest |
//! | Resize (bilinear) | O(4×N×C×H×W) | O(output) | Good quality/speed |
//! | Resize (bicubic) | O(16×N×C×H×W) | O(output) | Best quality |
//! | Gaussian blur | O(N×C×H×W×k²) | O(output + kernel) | Separable: O(2k) |
//! | Sobel | O(9×N×C×H×W) | O(output) | Fixed 3×3 kernel |
//! | Color conversion | O(N×H×W) | O(output) | Element-wise |
//! | Morphology | O(N×C×H×W×k²) | O(output) | Depends on structuring element |
//!
//! # Common Use Cases
//!
//! ## Data Augmentation
//! ```rust
//! use torsh_functional::image::{resize, InterpolationMode};
//! use torsh_functional::random_ops::randn;
//!
//! fn example() -> Result<(), Box<dyn std::error::Error>> {
//!     let image = randn(&[1, 3, 256, 256], None, None, None)?;
//!
//!     // Resize for different input sizes
//!     let resized = resize(&image, (224, 224), InterpolationMode::Bilinear, true)?;
//!     Ok(())
//! }
//! ```
//!
//! ## Preprocessing Pipelines
//! ```rust
//! use torsh_functional::image::gaussian_blur;
//! use torsh_functional::random_ops::randn;
//!
//! fn example() -> Result<(), Box<dyn std::error::Error>> {
//!     let image = randn(&[1, 3, 32, 32], None, None, None)?;
//!
//!     // Apply Gaussian blur for smoothing
//!     let smoothed = gaussian_blur(&image, 3, 1.5)?;
//!     Ok(())
//! }
//! ```
//!
//! ## Feature Extraction
//! ```rust
//! use torsh_functional::image::{sobel_filter, resize, SobelDirection, InterpolationMode};
//! use torsh_functional::random_ops::randn;
//!
//! fn example() -> Result<(), Box<dyn std::error::Error>> {
//!     let image = randn(&[1, 3, 256, 256], None, None, None)?;
//!
//!     // Edge detection for feature maps
//!     let edges = sobel_filter(&image, SobelDirection::Both)?;
//!
//!     // Multi-scale analysis
//!     let pyramid = vec![
//!         resize(&image, (224, 224), InterpolationMode::Bilinear, false)?,
//!         resize(&image, (112, 112), InterpolationMode::Bilinear, false)?,
//!         resize(&image, (56, 56), InterpolationMode::Bilinear, false)?,
//!     ];
//!     Ok(())
//! }
//! ```
//!
//! # Advanced Algorithms
//!
//! ## Separable Filtering
//!
//! Many 2D filters can be decomposed into 1D operations:
//! ```text
//! K₂D = K₁D_vertical ⊗ K₁D_horizontal
//!
//! Complexity reduction: O(k²) → O(2k) per pixel
//! ```
//! **Examples**: Gaussian blur, box filter, motion blur
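//!
//! A two-pass sketch on a plain grayscale buffer (illustrative helper name, zero
//! padding at the borders; this module's `gaussian_blur` currently applies the
//! full 2D kernel instead):
//! ```rust
//! /// Hypothetical separable pass: convolve rows, then columns, with a 1D kernel.
//! fn separable_filter(img: &[f32], h: usize, w: usize, k: &[f32]) -> Vec<f32> {
//!     let r = (k.len() / 2) as i32;
//!     let mut tmp = vec![0.0f32; h * w];
//!     let mut out = vec![0.0f32; h * w];
//!     for y in 0..h {
//!         for x in 0..w {
//!             // Horizontal pass: O(k) per pixel
//!             let mut s = 0.0;
//!             for (i, &kv) in k.iter().enumerate() {
//!                 let xi = x as i32 + i as i32 - r;
//!                 if xi >= 0 && xi < w as i32 {
//!                     s += kv * img[y * w + xi as usize];
//!                 }
//!             }
//!             tmp[y * w + x] = s;
//!         }
//!     }
//!     for y in 0..h {
//!         for x in 0..w {
//!             // Vertical pass: another O(k), giving O(2k) instead of O(k²)
//!             let mut s = 0.0;
//!             for (i, &kv) in k.iter().enumerate() {
//!                 let yi = y as i32 + i as i32 - r;
//!                 if yi >= 0 && yi < h as i32 {
//!                     s += kv * tmp[yi as usize * w + x];
//!                 }
//!             }
//!             out[y * w + x] = s;
//!         }
//!     }
//!     out
//! }
//! ```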
//!
//! ## Image Pyramid
//!
//! Multi-scale representation for coarse-to-fine processing:
//! ```text
//! Level 0: Original image I₀
//! Level k: Iₖ = downsample(Iₖ₋₁) by factor 2
//! ```
//! **Applications**:
//! - Object detection at multiple scales
//! - Feature matching (SIFT, SURF)
//! - Image blending (Laplacian pyramids)
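//!
//! Built with this module's own `resize`, a pyramid is a short loop (sketch;
//! the `randn` call follows the module-level examples):
//! ```rust
//! use torsh_functional::image::{resize, InterpolationMode};
//! use torsh_functional::random_ops::randn;
//!
//! fn example() -> Result<(), Box<dyn std::error::Error>> {
//!     let image = randn(&[1, 3, 256, 256], None, None, None)?;
//!
//!     // Halve height and width at each level: 256 -> 128 -> 64 -> 32
//!     let mut pyramid = vec![image];
//!     let (mut h, mut w) = (256usize, 256usize);
//!     for _ in 0..3 {
//!         h /= 2;
//!         w /= 2;
//!         let next = resize(pyramid.last().unwrap(), (h, w), InterpolationMode::Bilinear, true)?;
//!         pyramid.push(next);
//!     }
//!     Ok(())
//! }
//! ```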
//!
//! ## Integral Images (Summed Area Tables)
//!
//! Fast computation of rectangular region sums:
//! ```text
//! II(x, y) = Σᵢ≤ₓ Σⱼ≤y I(i, j)
//!
//! Rectangle sum = II(x₂,y₂) - II(x₁,y₂) - II(x₂,y₁) + II(x₁,y₁)
//! ```
//! **Complexity**: O(1) per query after O(HW) preprocessing
//!
//! **Applications**:
//! - Box filtering
//! - Haar-like features (face detection)
//! - Adaptive thresholding
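//!
//! A standalone sketch (the table is padded with a zero first row and column so
//! queries need no bounds special-casing; helper names are illustrative):
//! ```rust
//! /// Build a (h+1) x (w+1) summed area table with a zero first row/column.
//! fn integral_image(img: &[f32], h: usize, w: usize) -> Vec<f32> {
//!     let s = w + 1;
//!     let mut ii = vec![0.0f32; (h + 1) * s];
//!     for y in 0..h {
//!         for x in 0..w {
//!             // Current pixel + sum above + sum left - double-counted corner
//!             ii[(y + 1) * s + (x + 1)] =
//!                 img[y * w + x] + ii[y * s + (x + 1)] + ii[(y + 1) * s + x] - ii[y * s + x];
//!         }
//!     }
//!     ii
//! }
//!
//! /// Sum over rows y0..y1 and columns x0..x1 (half-open), in O(1) per query.
//! fn box_sum(ii: &[f32], w: usize, x0: usize, y0: usize, x1: usize, y1: usize) -> f32 {
//!     let s = w + 1;
//!     ii[y1 * s + x1] - ii[y0 * s + x1] - ii[y1 * s + x0] + ii[y0 * s + x0]
//! }
//! ```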
//!
//! ## Frequency Domain Processing
//!
//! Using Fourier transforms for global operations:
//! ```text
//! I_filtered = ℱ⁻¹(ℱ(I) · H)
//! ```
//! where H is the frequency domain filter.
//!
//! **Advantages**:
//! - O(n log n) complexity for large kernels (via FFT)
//! - Ideal for global operations (deconvolution, frequency-based filtering)
//!
//! ## Bilateral Filtering
//!
//! Edge-preserving smoothing:
//! ```text
//! BF(x) = (1/W) Σₚ G_σₛ(‖p-x‖) · G_σᵣ(|I(p)-I(x)|) · I(p)
//! ```
//! where:
//! - G_σₛ: Spatial Gaussian (distance weight)
//! - G_σᵣ: Range Gaussian (intensity similarity weight)
//!
//! **Properties**: Smooths while preserving edges
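//!
//! A brute-force sketch on a grayscale buffer (O(r²) per pixel; the helper name
//! and window radius parameter are illustrative):
//! ```rust
//! /// Hypothetical bilateral filter over a (2r+1) x (2r+1) window.
//! fn bilateral(img: &[f32], h: usize, w: usize, r: i32, sigma_s: f32, sigma_r: f32) -> Vec<f32> {
//!     let mut out = vec![0.0f32; h * w];
//!     for y in 0..h as i32 {
//!         for x in 0..w as i32 {
//!             let center = img[y as usize * w + x as usize];
//!             let (mut acc, mut norm) = (0.0f32, 0.0f32);
//!             for dy in -r..=r {
//!                 for dx in -r..=r {
//!                     let (py, px) = (y + dy, x + dx);
//!                     if py < 0 || py >= h as i32 || px < 0 || px >= w as i32 {
//!                         continue;
//!                     }
//!                     let p = img[py as usize * w + px as usize];
//!                     // Spatial weight times range weight, as in BF(x) above
//!                     let ws = (-((dx * dx + dy * dy) as f32) / (2.0 * sigma_s * sigma_s)).exp();
//!                     let wr = (-((p - center) * (p - center)) / (2.0 * sigma_r * sigma_r)).exp();
//!                     acc += ws * wr * p;
//!                     norm += ws * wr;
//!                 }
//!             }
//!             out[y as usize * w + x as usize] = acc / norm;
//!         }
//!     }
//!     out
//! }
//! ```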
//!
//! ## Non-Maximum Suppression
//!
//! For edge thinning in edge detection:
//! ```text
//! Keep a pixel only if it is a local maximum along the gradient direction
//! ```
//! An essential step in Canny edge detection.
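//!
//! A sketch with the gradient direction quantized to four neighbor axes; `mag`
//! and `dir` would come from the Sobel magnitude/direction formulas above:
//! ```rust
//! /// Hypothetical 4-direction non-maximum suppression on interior pixels.
//! fn non_max_suppress(mag: &[f32], dir: &[f32], h: usize, w: usize) -> Vec<f32> {
//!     let mut out = vec![0.0f32; h * w];
//!     for y in 1..h - 1 {
//!         for x in 1..w - 1 {
//!             let i = y * w + x;
//!             // Fold the angle into [0, 180) and pick the matching neighbor pair
//!             let angle = dir[i].to_degrees().rem_euclid(180.0);
//!             let (a, b) = if angle < 22.5 || angle >= 157.5 {
//!                 (mag[i - 1], mag[i + 1]) // horizontal neighbors
//!             } else if angle < 67.5 {
//!                 (mag[i - w + 1], mag[i + w - 1]) // 45° diagonal
//!             } else if angle < 112.5 {
//!                 (mag[i - w], mag[i + w]) // vertical neighbors
//!             } else {
//!                 (mag[i - w - 1], mag[i + w + 1]) // 135° diagonal
//!             };
//!             // Keep only local maxima along the gradient direction
//!             out[i] = if mag[i] >= a && mag[i] >= b { mag[i] } else { 0.0 };
//!         }
//!     }
//!     out
//! }
//! ```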
//!
//! # Computer Vision Applications
//!
//! ## Image Classification Preprocessing
//! 1. Resize to fixed size (224×224 for ImageNet)
//! 2. Normalize: μ=0, σ=1 per channel
//! 3. Data augmentation: random crops, flips, color jitter
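//!
//! Steps 1-2 as a sketch (per-channel normalization is done on the raw buffer
//! here, since this module does not export a normalize op; the mean/std values
//! are the conventional ImageNet statistics):
//! ```rust
//! use torsh_functional::image::{resize, InterpolationMode};
//! use torsh_functional::random_ops::randn;
//! use torsh_tensor::Tensor;
//!
//! fn example() -> Result<(), Box<dyn std::error::Error>> {
//!     let image = randn(&[1, 3, 256, 256], None, None, None)?;
//!     let resized = resize(&image, (224, 224), InterpolationMode::Bilinear, true)?;
//!
//!     // Conventional ImageNet per-channel statistics
//!     let mean = [0.485f32, 0.456, 0.406];
//!     let std = [0.229f32, 0.224, 0.225];
//!
//!     let mut data = resized.to_vec()?;
//!     let plane = 224 * 224;
//!     for c in 0..3 {
//!         for v in &mut data[c * plane..(c + 1) * plane] {
//!             *v = (*v - mean[c]) / std[c];
//!         }
//!     }
//!     let normalized = Tensor::from_data(data, vec![1, 3, 224, 224], resized.device())?;
//!     Ok(())
//! }
//! ```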
//!
//! ## Object Detection Preprocessing
//! 1. Multi-scale processing (image pyramids)
//! 2. Aspect ratio preservation with padding
//! 3. Anchor-based or anchor-free coordinate systems
//!
//! ## Semantic Segmentation
//! 1. High-resolution input preservation
//! 2. Multi-scale feature extraction
//! 3. Skip connections for fine details
//!
//! ## Style Transfer
//! 1. Content loss: Feature matching in conv layers
//! 2. Style loss: Gram matrix matching
//! 3. Color space considerations (RGB vs LAB)
//!
//! # Best Practices
//!
//! 1. **Interpolation Selection**:
//!    - Nearest: Masks, labels, pixel art, segmentation maps
//!    - Bilinear: General purpose, good balance of speed/quality
//!    - Bicubic: High quality, when quality matters more than speed
//!    - Lanczos: Highest quality, slowest, professional graphics
//!
//! 2. **Anti-aliasing**:
//!    - Always enable when downsampling by more than 2× to avoid moiré patterns
//!    - Use Gaussian pre-filtering for high-quality downsampling
//!
//! 3. **Color Space Selection**:
//!    - RGB: Display, color manipulation, neural network input
//!    - HSV: Hue/saturation adjustments, color-based segmentation
//!    - LAB: Perceptually uniform, better for style transfer
//!    - Grayscale: Edge detection, classical CV algorithms, faster processing
//!
//! 4. **Memory Efficiency**:
//!    - Process in batches for large datasets
//!    - Use in-place operations where possible
//!    - Consider image pyramids for multi-scale processing
//!
//! 5. **Numerical Stability** (see the sketch after this list):
//!    - Normalize pixel values to [0, 1] or [-1, 1]
//!    - Use double precision for accumulation in filters
//!    - Clamp outputs to valid range after operations
//!
//! 6. **Padding Strategies**:
//!    - Zero padding: Fast, but introduces boundary artifacts
//!    - Reflection: Good for natural images
//!    - Replication: Reduces boundary artifacts
//!    - Circular: For periodic patterns
//!
//! 7. **Performance Optimization**:
//!    - Use separable filters when possible
//!    - Exploit SIMD for element-wise operations
//!    - Pre-compute lookup tables for repeated operations
//!    - Cache-friendly memory access patterns
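//!
//! The accumulation and clamping points from item 5, as a sketch on a raw buffer:
//! ```rust
//! /// Sum a kernel response in f64 to limit rounding drift, then clamp to [0, 1].
//! fn accumulate_and_clamp(samples: &[f32], weights: &[f32]) -> f32 {
//!     let acc: f64 = samples
//!         .iter()
//!         .zip(weights)
//!         .map(|(&s, &w)| s as f64 * w as f64)
//!         .sum();
//!     (acc as f32).clamp(0.0, 1.0)
//! }
//! ```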

use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;

/// Resize tensor using different interpolation methods
///
/// Note: `antialias` is currently accepted but not applied; no pre-filtering is
/// performed before downsampling.
pub fn resize(
    input: &Tensor,
    size: (usize, usize),
    mode: InterpolationMode,
    antialias: bool,
) -> TorshResult<Tensor> {
    let shape = input.shape();
    if shape.ndim() < 3 {
        return Err(TorshError::invalid_argument_with_context(
            "Input tensor must have at least 3 dimensions (C, H, W)",
            "resize",
        ));
    }

    let dims = shape.dims();
    let channels = dims[dims.len() - 3];
    let in_height = dims[dims.len() - 2];
    let in_width = dims[dims.len() - 1];
    let (out_height, out_width) = size;

    // Handle batch dimensions
    let batch_dims: Vec<usize> = dims[..dims.len() - 3].to_vec();
    let batch_size = batch_dims.iter().product::<usize>();

    let input_data = input.to_vec()?;
    let mut output_data = vec![0.0f32; batch_size * channels * out_height * out_width];

    let scale_h = in_height as f32 / out_height as f32;
    let scale_w = in_width as f32 / out_width as f32;

    match mode {
        InterpolationMode::Nearest => {
            for b in 0..batch_size {
                for c in 0..channels {
                    for oh in 0..out_height {
                        for ow in 0..out_width {
                            let ih = (oh as f32 * scale_h).round() as usize;
                            let iw = (ow as f32 * scale_w).round() as usize;

                            let ih = ih.min(in_height - 1);
                            let iw = iw.min(in_width - 1);

                            let in_idx = ((b * channels + c) * in_height + ih) * in_width + iw;
                            let out_idx = ((b * channels + c) * out_height + oh) * out_width + ow;

                            output_data[out_idx] = input_data[in_idx];
                        }
                    }
                }
            }
        }
        InterpolationMode::Bilinear => {
            for b in 0..batch_size {
                for c in 0..channels {
                    for oh in 0..out_height {
                        for ow in 0..out_width {
                            // Half-pixel-center mapping from output to input coordinates
                            let fh = (oh as f32 + 0.5) * scale_h - 0.5;
                            let fw = (ow as f32 + 0.5) * scale_w - 0.5;

                            let ih_low = fh.floor() as i32;
                            let iw_low = fw.floor() as i32;

                            let wh = fh - ih_low as f32;
                            let ww = fw - iw_low as f32;

                            let mut value = 0.0f32;

                            // Weighted sum over the 2x2 neighborhood
                            for dh in 0..2 {
                                for dw in 0..2 {
                                    let ih = ih_low + dh;
                                    let iw = iw_low + dw;

                                    if ih >= 0
                                        && ih < in_height as i32
                                        && iw >= 0
                                        && iw < in_width as i32
                                    {
                                        let weight = (if dh == 0 { 1.0 - wh } else { wh })
                                            * (if dw == 0 { 1.0 - ww } else { ww });

                                        let in_idx = ((b * channels + c) * in_height + ih as usize)
                                            * in_width
                                            + iw as usize;
                                        value += weight * input_data[in_idx];
                                    }
                                }
                            }

                            let out_idx = ((b * channels + c) * out_height + oh) * out_width + ow;
                            output_data[out_idx] = value;
                        }
                    }
                }
            }
        }
        InterpolationMode::Bicubic | InterpolationMode::Area => {
            // Bicubic and area modes are not yet implemented; fall back to bilinear.
            return resize(input, size, InterpolationMode::Bilinear, antialias);
        }
    }

    let mut output_shape = batch_dims;
    output_shape.extend_from_slice(&[channels, out_height, out_width]);

    Tensor::from_data(output_data, output_shape, input.device())
}

/// Interpolation modes for resizing
#[derive(Debug, Clone, Copy)]
pub enum InterpolationMode {
    Nearest,
    Bilinear,
    Bicubic,
    Area,
}

/// Apply Gaussian blur to image tensor
///
/// `kernel_size` should be odd so the kernel is centered; with padding of
/// `kernel_size / 2`, the output then keeps the input's spatial size.
pub fn gaussian_blur(input: &Tensor, kernel_size: usize, sigma: f32) -> TorshResult<Tensor> {
    let shape = input.shape();
    if shape.ndim() < 3 {
        return Err(TorshError::invalid_argument_with_context(
            "Input tensor must have at least 3 dimensions (C, H, W)",
            "gaussian_blur",
        ));
    }

    // Create Gaussian kernel
    let radius = kernel_size / 2;
    let mut kernel = vec![0.0f32; kernel_size * kernel_size];
    let mut sum = 0.0f32;

    for i in 0..kernel_size {
        for j in 0..kernel_size {
            let x = i as i32 - radius as i32;
            let y = j as i32 - radius as i32;
            let val = (-((x * x + y * y) as f32) / (2.0 * sigma * sigma)).exp();
            kernel[i * kernel_size + j] = val;
            sum += val;
        }
    }

    // Normalize kernel so the weights sum to 1 (preserves mean intensity)
    for val in &mut kernel {
        *val /= sum;
    }

    // Apply convolution with the Gaussian kernel
    apply_convolution(input, &kernel, kernel_size, 1, radius)
}

/// Apply Sobel edge detection
pub fn sobel_filter(input: &Tensor, direction: SobelDirection) -> TorshResult<Tensor> {
    let kernel = match direction {
        SobelDirection::X => vec![-1.0, 0.0, 1.0, -2.0, 0.0, 2.0, -1.0, 0.0, 1.0],
        SobelDirection::Y => vec![-1.0, -2.0, -1.0, 0.0, 0.0, 0.0, 1.0, 2.0, 1.0],
        SobelDirection::Both => {
            // For both directions, compute magnitude
            let x_result = sobel_filter(input, SobelDirection::X)?;
            let y_result = sobel_filter(input, SobelDirection::Y)?;
            return compute_gradient_magnitude(&x_result, &y_result);
        }
    };

    apply_convolution(input, &kernel, 3, 1, 1)
}

/// Sobel filter directions
#[derive(Debug, Clone, Copy)]
pub enum SobelDirection {
    X,
    Y,
    Both,
}

/// Apply Laplacian filter for edge detection
pub fn laplacian_filter(input: &Tensor) -> TorshResult<Tensor> {
    // Negated 4-neighbor Laplacian (center +4): the sign is flipped relative to
    // the kernel shown in the module docs, which changes only the sign of the
    // response, not the location of detected edges.
    let kernel = vec![0.0, -1.0, 0.0, -1.0, 4.0, -1.0, 0.0, -1.0, 0.0];

    apply_convolution(input, &kernel, 3, 1, 1)
}

/// Apply morphological erosion
pub fn erosion(input: &Tensor, kernel_size: usize, iterations: usize) -> TorshResult<Tensor> {
    let mut result = input.clone();

    for _ in 0..iterations {
        result = apply_morphological_op(&result, kernel_size, MorphOp::Erosion)?;
    }

    Ok(result)
}

/// Apply morphological dilation
pub fn dilation(input: &Tensor, kernel_size: usize, iterations: usize) -> TorshResult<Tensor> {
    let mut result = input.clone();

    for _ in 0..iterations {
        result = apply_morphological_op(&result, kernel_size, MorphOp::Dilation)?;
    }

    Ok(result)
}

/// Apply morphological opening (erosion followed by dilation)
pub fn opening(input: &Tensor, kernel_size: usize) -> TorshResult<Tensor> {
    let eroded = erosion(input, kernel_size, 1)?;
    dilation(&eroded, kernel_size, 1)
}

/// Apply morphological closing (dilation followed by erosion)
pub fn closing(input: &Tensor, kernel_size: usize) -> TorshResult<Tensor> {
    let dilated = dilation(input, kernel_size, 1)?;
    erosion(&dilated, kernel_size, 1)
}

/// Convert RGB to HSV color space
pub fn rgb_to_hsv(input: &Tensor) -> TorshResult<Tensor> {
    let shape = input.shape();
    if shape.ndim() < 3 {
        return Err(TorshError::invalid_argument_with_context(
            "Input tensor must have at least 3 dimensions (C, H, W)",
            "rgb_to_hsv",
        ));
    }

    let dims = shape.dims();
    if dims[dims.len() - 3] != 3 {
        return Err(TorshError::invalid_argument_with_context(
            "Input tensor must have 3 channels for RGB",
            "rgb_to_hsv",
        ));
    }

    let input_data = input.to_vec()?;
    let mut output_data = vec![0.0f32; input_data.len()];

    let batch_size = dims[..dims.len() - 3].iter().product::<usize>();
    let height = dims[dims.len() - 2];
    let width = dims[dims.len() - 1];

    for b in 0..batch_size {
        for h in 0..height {
            for w in 0..width {
                let r_idx = ((b * 3 + 0) * height + h) * width + w;
                let g_idx = ((b * 3 + 1) * height + h) * width + w;
                let b_idx = ((b * 3 + 2) * height + h) * width + w;

                let r = input_data[r_idx];
                let g = input_data[g_idx];
                let b_val = input_data[b_idx];

                let max_val = r.max(g).max(b_val);
                let min_val = r.min(g).min(b_val);
                let delta = max_val - min_val;

                // Value
                let v = max_val;

                // Saturation
                let s = if max_val == 0.0 { 0.0 } else { delta / max_val };

                // Hue
                let h_val = if delta == 0.0 {
                    0.0
                } else if max_val == r {
                    60.0 * (((g - b_val) / delta) % 6.0)
                } else if max_val == g {
                    60.0 * ((b_val - r) / delta + 2.0)
                } else {
                    60.0 * ((r - g) / delta + 4.0)
                };
                // Rust's `%` keeps the sign of the dividend, so wrap negative
                // angles into [0, 360) before normalizing
                let h_val = if h_val < 0.0 { h_val + 360.0 } else { h_val };

                output_data[r_idx] = h_val / 360.0; // Normalize hue to [0, 1]
                output_data[g_idx] = s;
                output_data[b_idx] = v;
            }
        }
    }

    Tensor::from_data(output_data, dims.to_vec(), input.device())
}

/// Convert HSV to RGB color space
pub fn hsv_to_rgb(input: &Tensor) -> TorshResult<Tensor> {
    let shape = input.shape();
    if shape.ndim() < 3 {
        return Err(TorshError::invalid_argument_with_context(
            "Input tensor must have at least 3 dimensions (C, H, W)",
            "hsv_to_rgb",
        ));
    }

    let dims = shape.dims();
    if dims[dims.len() - 3] != 3 {
        return Err(TorshError::invalid_argument_with_context(
            "Input tensor must have 3 channels for HSV",
            "hsv_to_rgb",
        ));
    }

    let input_data = input.to_vec()?;
    let mut output_data = vec![0.0f32; input_data.len()];

    let batch_size = dims[..dims.len() - 3].iter().product::<usize>();
    let height = dims[dims.len() - 2];
    let width = dims[dims.len() - 1];

    for b in 0..batch_size {
        for h in 0..height {
            for w in 0..width {
                let h_idx = ((b * 3 + 0) * height + h) * width + w;
                let s_idx = ((b * 3 + 1) * height + h) * width + w;
                let v_idx = ((b * 3 + 2) * height + h) * width + w;

                let h_val = input_data[h_idx] * 360.0; // Denormalize hue
                let s = input_data[s_idx];
                let v = input_data[v_idx];

                let c = v * s;
                let x = c * (1.0 - ((h_val / 60.0) % 2.0 - 1.0).abs());
                let m = v - c;

                let (r_prime, g_prime, b_prime) = if h_val < 60.0 {
                    (c, x, 0.0)
                } else if h_val < 120.0 {
                    (x, c, 0.0)
                } else if h_val < 180.0 {
                    (0.0, c, x)
                } else if h_val < 240.0 {
                    (0.0, x, c)
                } else if h_val < 300.0 {
                    (x, 0.0, c)
                } else {
                    (c, 0.0, x)
                };

                output_data[h_idx] = r_prime + m;
                output_data[s_idx] = g_prime + m;
                output_data[v_idx] = b_prime + m;
            }
        }
    }

    Tensor::from_data(output_data, dims.to_vec(), input.device())
}

/// Apply affine transformation to image
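///
/// The matrix `[a, b, c, d, e, f]` is the row-major 2×3 transform
/// `[[a, b, c], [d, e, f]]` mapping input to output coordinates; it is inverted
/// internally for backward mapping, so it must be non-singular.
///
/// A usage sketch (a 90° rotation about the origin, translated back into frame;
/// the `randn` call follows the module-level examples):
///
/// ```rust
/// use torsh_functional::image::affine_transform;
/// use torsh_functional::random_ops::randn;
///
/// fn example() -> Result<(), Box<dyn std::error::Error>> {
///     let image = randn(&[1, 3, 64, 64], None, None, None)?;
///     // [[cos θ, -sin θ, tx], [sin θ, cos θ, ty]] with θ = 90° and tx = 63
///     let rotate = [0.0, -1.0, 63.0, 1.0, 0.0, 0.0];
///     let rotated = affine_transform(&image, &rotate, None, 0.0)?;
///     Ok(())
/// }
/// ```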
pub fn affine_transform(
    input: &Tensor,
    matrix: &[f32; 6], // [a, b, c, d, e, f] for transformation [[a, b, c], [d, e, f]]
    output_size: Option<(usize, usize)>,
    fill_value: f32,
) -> TorshResult<Tensor> {
    let shape = input.shape();
    if shape.ndim() < 3 {
        return Err(TorshError::invalid_argument_with_context(
            "Input tensor must have at least 3 dimensions (C, H, W)",
            "affine_transform",
        ));
    }

    let dims = shape.dims();
    let channels = dims[dims.len() - 3];
    let in_height = dims[dims.len() - 2];
    let in_width = dims[dims.len() - 1];

    let (out_height, out_width) = output_size.unwrap_or((in_height, in_width));

    // Compute inverse transformation matrix for backward mapping
    let [a, b, c, d, e, f] = *matrix;
    let det = a * e - b * d;

    if det.abs() < 1e-6 {
        return Err(TorshError::invalid_argument_with_context(
            "Affine transformation matrix is singular",
            "affine_transform",
        ));
    }

    let inv_det = 1.0 / det;
    let inv_a = e * inv_det;
    let inv_b = -b * inv_det;
    let inv_c = (b * f - c * e) * inv_det;
    let inv_d = -d * inv_det;
    let inv_e = a * inv_det;
    let inv_f = (c * d - a * f) * inv_det;

    let batch_dims: Vec<usize> = dims[..dims.len() - 3].to_vec();
    let batch_size = batch_dims.iter().product::<usize>();

    let input_data = input.to_vec()?;
    let mut output_data = vec![fill_value; batch_size * channels * out_height * out_width];

    for b in 0..batch_size {
        for c in 0..channels {
            for oh in 0..out_height {
                for ow in 0..out_width {
                    // Apply inverse transformation
                    let x_out = ow as f32;
                    let y_out = oh as f32;

                    let x_in = inv_a * x_out + inv_b * y_out + inv_c;
                    let y_in = inv_d * x_out + inv_e * y_out + inv_f;

                    // Bilinear interpolation, only where the full 2x2 neighborhood
                    // lies inside the input; other pixels keep `fill_value`
                    if x_in >= 0.0
                        && x_in < in_width as f32 - 1.0
                        && y_in >= 0.0
                        && y_in < in_height as f32 - 1.0
                    {
                        let x0 = x_in.floor() as usize;
                        let y0 = y_in.floor() as usize;
                        let x1 = x0 + 1;
                        let y1 = y0 + 1;

                        let wx = x_in - x0 as f32;
                        let wy = y_in - y0 as f32;

                        let idx00 = ((b * channels + c) * in_height + y0) * in_width + x0;
                        let idx01 = ((b * channels + c) * in_height + y0) * in_width + x1;
                        let idx10 = ((b * channels + c) * in_height + y1) * in_width + x0;
                        let idx11 = ((b * channels + c) * in_height + y1) * in_width + x1;

                        let val = (1.0 - wx) * (1.0 - wy) * input_data[idx00]
                            + wx * (1.0 - wy) * input_data[idx01]
                            + (1.0 - wx) * wy * input_data[idx10]
                            + wx * wy * input_data[idx11];

                        let out_idx = ((b * channels + c) * out_height + oh) * out_width + ow;
                        output_data[out_idx] = val;
                    }
                }
            }
        }
    }

    let mut output_shape = batch_dims;
    output_shape.extend_from_slice(&[channels, out_height, out_width]);

    Tensor::from_data(output_data, output_shape, input.device())
}

// Helper functions

/// Apply convolution with a given kernel
///
/// Out-of-bounds taps are skipped, which amounts to zero padding at the borders.
fn apply_convolution(
    input: &Tensor,
    kernel: &[f32],
    kernel_size: usize,
    stride: usize,
    padding: usize,
) -> TorshResult<Tensor> {
    let shape = input.shape();
    let dims = shape.dims();

    let batch_dims: Vec<usize> = dims[..dims.len() - 3].to_vec();
    let batch_size = batch_dims.iter().product::<usize>();
    let channels = dims[dims.len() - 3];
    let in_height = dims[dims.len() - 2];
    let in_width = dims[dims.len() - 1];

    let out_height = (in_height + 2 * padding - kernel_size) / stride + 1;
    let out_width = (in_width + 2 * padding - kernel_size) / stride + 1;

    let input_data = input.to_vec()?;
    let mut output_data = vec![0.0f32; batch_size * channels * out_height * out_width];

    for b in 0..batch_size {
        for c in 0..channels {
            for oh in 0..out_height {
                for ow in 0..out_width {
                    let mut sum = 0.0f32;

                    for kh in 0..kernel_size {
                        for kw in 0..kernel_size {
                            let ih = oh * stride + kh;
                            let iw = ow * stride + kw;

                            if ih >= padding
                                && ih < in_height + padding
                                && iw >= padding
                                && iw < in_width + padding
                            {
                                let real_ih = ih - padding;
                                let real_iw = iw - padding;

                                if real_ih < in_height && real_iw < in_width {
                                    let in_idx = ((b * channels + c) * in_height + real_ih)
                                        * in_width
                                        + real_iw;
                                    let kernel_idx = kh * kernel_size + kw;
                                    sum += input_data[in_idx] * kernel[kernel_idx];
                                }
                            }
                        }
                    }

                    let out_idx = ((b * channels + c) * out_height + oh) * out_width + ow;
                    output_data[out_idx] = sum;
                }
            }
        }
    }

    let mut output_shape = batch_dims;
    output_shape.extend_from_slice(&[channels, out_height, out_width]);

    Tensor::from_data(output_data, output_shape, input.device())
}

/// Morphological operation types
#[derive(Debug, Clone, Copy)]
enum MorphOp {
    Erosion,
    Dilation,
}

/// Apply morphological operation
fn apply_morphological_op(input: &Tensor, kernel_size: usize, op: MorphOp) -> TorshResult<Tensor> {
    let shape = input.shape();
    let dims = shape.dims();

    let batch_dims: Vec<usize> = dims[..dims.len() - 3].to_vec();
    let batch_size = batch_dims.iter().product::<usize>();
    let channels = dims[dims.len() - 3];
    let height = dims[dims.len() - 2];
    let width = dims[dims.len() - 1];

    let radius = kernel_size / 2;
    let input_data = input.to_vec()?;
    let mut output_data = vec![0.0f32; input_data.len()];

    for b in 0..batch_size {
        for c in 0..channels {
            for h in 0..height {
                for w in 0..width {
                    let mut result = match op {
                        MorphOp::Erosion => f32::INFINITY,
                        MorphOp::Dilation => f32::NEG_INFINITY,
                    };

                    for kh in 0..kernel_size {
                        for kw in 0..kernel_size {
                            let ih = h as i32 + kh as i32 - radius as i32;
                            let iw = w as i32 + kw as i32 - radius as i32;

                            if ih >= 0 && ih < height as i32 && iw >= 0 && iw < width as i32 {
                                let in_idx = ((b * channels + c) * height + ih as usize) * width
                                    + iw as usize;
                                let val = input_data[in_idx];

                                match op {
                                    MorphOp::Erosion => result = result.min(val),
                                    MorphOp::Dilation => result = result.max(val),
                                }
                            }
                        }
                    }

                    let out_idx = ((b * channels + c) * height + h) * width + w;
                    output_data[out_idx] = result;
                }
            }
        }
    }

    Tensor::from_data(output_data, dims.to_vec(), input.device())
}

/// Compute gradient magnitude from X and Y gradients
fn compute_gradient_magnitude(grad_x: &Tensor, grad_y: &Tensor) -> TorshResult<Tensor> {
    let grad_x_data = grad_x.to_vec()?;
    let grad_y_data = grad_y.to_vec()?;

    let magnitude_data: Vec<f32> = grad_x_data
        .iter()
        .zip(grad_y_data.iter())
        .map(|(&gx, &gy)| (gx * gx + gy * gy).sqrt())
        .collect();

    Tensor::from_data(
        magnitude_data,
        grad_x.shape().dims().to_vec(),
        grad_x.device(),
    )
}

#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::ones;

    #[test]
    fn test_resize_nearest() {
        let input = ones(&[1, 3, 4, 4]).unwrap(); // Batch=1, Channels=3, Height=4, Width=4
        let result = resize(&input, (2, 2), InterpolationMode::Nearest, false).unwrap();
        assert_eq!(result.shape().dims(), &[1, 3, 2, 2]);
    }

    #[test]
    fn test_gaussian_blur() {
        let input = ones(&[1, 3, 5, 5]).unwrap();
        let result = gaussian_blur(&input, 3, 1.0).unwrap();
        assert_eq!(result.shape().dims(), &[1, 3, 5, 5]); // Size maintained due to padding
    }

    #[test]
    fn test_sobel_filter() {
        let input = ones(&[1, 1, 5, 5]).unwrap();
        let result = sobel_filter(&input, SobelDirection::X).unwrap();
        assert_eq!(result.shape().dims(), &[1, 1, 5, 5]); // Size maintained due to padding
    }

    #[test]
    fn test_rgb_to_hsv_conversion() {
        let input = ones(&[1, 3, 2, 2]).unwrap(); // RGB image
        let hsv = rgb_to_hsv(&input).unwrap();
        let rgb_back = hsv_to_rgb(&hsv).unwrap();
        assert_eq!(rgb_back.shape().dims(), &[1, 3, 2, 2]);
    }

    #[test]
    fn test_morphological_operations() {
        let input = ones(&[1, 1, 5, 5]).unwrap();
        let eroded = erosion(&input, 3, 1).unwrap();
        let dilated = dilation(&input, 3, 1).unwrap();
        assert_eq!(eroded.shape().dims(), &[1, 1, 5, 5]);
        assert_eq!(dilated.shape().dims(), &[1, 1, 5, 5]);
    }
}