// torsh_functional/image.rs
1//! Image Processing Operations
2//!
3//! This module provides comprehensive image processing operations for computer vision
4//! and deep learning applications. All operations are designed to work with tensors
5//! in standard image formats (CHW or NCHW).
6//!
7//! # Mathematical Foundation
8//!
9//! ## Image Representation
10//!
11//! Images are represented as tensors with dimensions:
12//! - **2D**: `[H, W]` - Grayscale images
13//! - **3D**: `[C, H, W]` - Multi-channel images (RGB, etc.)
14//! - **4D**: `[N, C, H, W]` - Batched images
15//!
16//! where:
17//! - `N` = batch size
18//! - `C` = number of channels
19//! - `H` = height (rows)
20//! - `W` = width (columns)
21//!
22//! ## Interpolation Methods
23//!
24//! ### Nearest Neighbor
25//! ```text
26//! I_out(x, y) = I_in(round(x * scale_x), round(y * scale_y))
27//! ```
28//! - **Complexity**: O(1) per pixel
29//! - **Quality**: Blocky, preserves sharp edges
30//! - **Use case**: Fast resizing, pixel art
31//!
32//! ### Bilinear Interpolation
33//! ```text
34//! I_out(x, y) = Σᵢ Σⱼ w(i,j) * I_in(x+i, y+j)
35//!
36//! where w(i,j) = (1 - dx) * (1 - dy) for i=0, j=0
37//! = dx * (1 - dy) for i=1, j=0
38//! = (1 - dx) * dy for i=0, j=1
39//! = dx * dy for i=1, j=1
40//! ```
41//! - **Complexity**: O(4) per pixel (4 neighbor lookups)
42//! - **Quality**: Smooth, good for photos
43//! - **Use case**: Standard resizing, transformations
44//!
45//! ### Bicubic Interpolation
46//! ```text
47//! I_out(x, y) = Σᵢ₌₀³ Σⱼ₌₀³ w(i,j) * I_in(x+i-1, y+j-1)
48//!
49//! where w(i,j) uses cubic kernel:
50//! k(t) = { (a+2)|t|³ - (a+3)|t|² + 1 for |t| ≤ 1
51//! { a|t|³ - 5a|t|² + 8a|t| - 4a for 1 < |t| < 2
52//! { 0 for |t| ≥ 2
53//!
54//! with a = -0.5 (most common)
55//! ```
56//! - **Complexity**: O(16) per pixel (4×4 neighborhood)
57//! - **Quality**: High quality, smooth gradients
58//! - **Use case**: High-quality resizing, professional graphics
59//!
60//! ## Color Space Conversions
61//!
62//! ### RGB to Grayscale
63//! ```text
64//! Y = 0.299 * R + 0.587 * G + 0.114 * B (luminosity method)
65//! ```
66//!
67//! ### RGB to HSV
68//! ```text
69//! V = max(R, G, B)
70//! S = (V - min(R, G, B)) / V if V ≠ 0, else 0
71//! H = { 60° × (G - B) / (V - min) if V = R
72//! { 60° × (2 + (B - R) / (V - min)) if V = G
73//! { 60° × (4 + (R - G) / (V - min)) if V = B
74//! ```
75//!
76//! ## Filtering Operations
77//!
78//! ### Gaussian Filter
79//! ```text
80//! G(x, y) = (1 / (2πσ²)) * exp(-(x² + y²) / (2σ²))
81//!
82//! Kernel size: typically ⌈6σ⌉ to capture 99.7% of distribution
83//! ```
84//! - **Purpose**: Blur, noise reduction, anti-aliasing
85//! - **Properties**: Linear, separable, isotropic
86//!
87//! ### Sobel Filter
88//! ```text
89//! Gₓ = [-1 0 1] Gᵧ = [-1 -2 -1]
90//! [-2 0 2] [ 0 0 0]
91//! [-1 0 1] [ 1 2 1]
92//!
93//! Magnitude = √(Gₓ² + Gᵧ²)
94//! Direction = atan2(Gᵧ, Gₓ)
95//! ```
96//! - **Purpose**: Edge detection, gradient computation
97//! - **Properties**: First derivative approximation
98//!
99//! ### Laplacian Filter
100//! ```text
101//! L = [ 0 1 0] or L = [ 1 1 1]
102//! [ 1 -4 1] [ 1 -8 1]
103//! [ 0 1 0] [ 1 1 1]
104//!
105//! ∇²I ≈ ∂²I/∂x² + ∂²I/∂y²
106//! ```
107//! - **Purpose**: Edge detection, feature enhancement
108//! - **Properties**: Second derivative, rotation invariant
109//!
110//! ## Morphological Operations
111//!
112//! ### Erosion
113//! ```text
114//! (A ⊖ B)(x, y) = min{A(x+i, y+j) | (i,j) ∈ B}
115//! ```
116//! - **Effect**: Shrinks bright regions, removes small objects
117//!
118//! ### Dilation
119//! ```text
120//! (A ⊕ B)(x, y) = max{A(x+i, y+j) | (i,j) ∈ B}
121//! ```
122//! - **Effect**: Expands bright regions, fills holes
123//!
124//! ### Opening
125//! ```text
126//! A ∘ B = (A ⊖ B) ⊕ B
127//! ```
128//! - **Effect**: Removes small objects while preserving shape
129//!
130//! ### Closing
131//! ```text
132//! A • B = (A ⊕ B) ⊖ B
133//! ```
134//! - **Effect**: Fills small holes while preserving shape
135//!
136//! # Performance Characteristics
137//!
138//! | Operation | Complexity | Memory | Notes |
139//! |-----------|------------|--------|-------|
140//! | Resize (nearest) | O(N×C×H×W) | O(output) | Fastest |
141//! | Resize (bilinear) | O(4×N×C×H×W) | O(output) | Good quality/speed |
142//! | Resize (bicubic) | O(16×N×C×H×W) | O(output) | Best quality |
143//! | Gaussian blur | O(N×C×H×W×k²) | O(output + kernel) | Separable: O(2k) |
144//! | Sobel | O(9×N×C×H×W) | O(output) | Fixed 3×3 kernel |
145//! | Color conversion | O(N×H×W) | O(output) | Element-wise |
146//! | Morphology | O(N×C×H×W×k²) | O(output) | Depends on structuring element |
147//!
148//! # Common Use Cases
149//!
150//! ## Data Augmentation
151//! ```rust
152//! use torsh_functional::image::{resize, InterpolationMode};
153//! use torsh_functional::random_ops::randn;
154//!
155//! fn example() -> Result<(), Box<dyn std::error::Error>> {
156//! let image = randn(&[1, 3, 256, 256], None, None, None)?;
157//!
158//! // Resize for different input sizes
159//! let resized = resize(&image, (224, 224), InterpolationMode::Bilinear, true)?;
160//! Ok(())
161//! }
162//! ```
163//!
164//! ## Preprocessing Pipelines
165//! ```rust
166//! use torsh_functional::image::gaussian_blur;
167//! use torsh_functional::random_ops::randn;
168//!
169//! fn example() -> Result<(), Box<dyn std::error::Error>> {
170//! let image = randn(&[1, 3, 32, 32], None, None, None)?;
171//!
172//! // Apply Gaussian blur for smoothing
173//! let smoothed = gaussian_blur(&image, 3, 1.5)?;
174//! Ok(())
175//! }
176//! ```
177//!
178//! ## Feature Extraction
179//! ```rust
180//! use torsh_functional::image::{sobel_filter, resize, SobelDirection, InterpolationMode};
181//! use torsh_functional::random_ops::randn;
182//!
183//! fn example() -> Result<(), Box<dyn std::error::Error>> {
184//! let image = randn(&[1, 3, 256, 256], None, None, None)?;
185//!
186//! // Edge detection for feature maps
187//! let edges = sobel_filter(&image, SobelDirection::Both)?;
188//!
189//! // Multi-scale analysis
190//! let pyramid = vec![
191//! resize(&image, (224, 224), InterpolationMode::Bilinear, false)?,
192//! resize(&image, (112, 112), InterpolationMode::Bilinear, false)?,
193//! resize(&image, (56, 56), InterpolationMode::Bilinear, false)?,
194//! ];
195//! Ok(())
196//! }
197//! ```
198//!
199//! # Advanced Algorithms
200//!
201//! ## Separable Filtering
202//!
203//! Many 2D filters can be decomposed into 1D operations:
204//! ```text
205//! K₂D = K₁D_vertical ⊗ K₁D_horizontal
206//!
207//! Complexity reduction: O(k²) → O(2k) per pixel
208//! ```
209//! **Examples**: Gaussian blur, box filter, motion blur
210//!
211//! ## Image Pyramid
212//!
213//! Multi-scale representation for coarse-to-fine processing:
214//! ```text
215//! Level 0: Original image I₀
216//! Level k: Iₖ = downsample(Iₖ₋₁) by factor 2
217//! ```
218//! **Applications**:
219//! - Object detection at multiple scales
220//! - Feature matching (SIFT, SURF)
221//! - Image blending (Laplacian pyramids)
222//!
223//! ## Integral Images (Summed Area Tables)
224//!
225//! Fast computation of rectangular region sums:
226//! ```text
227//! II(x, y) = Σᵢ≤ₓ Σⱼ≤y I(i, j)
228//!
229//! Rectangle sum = II(x₂,y₂) - II(x₁,y₂) - II(x₂,y₁) + II(x₁,y₁)
230//! ```
231//! **Complexity**: O(1) per query after O(HW) preprocessing
232//!
233//! **Applications**:
234//! - Box filtering
235//! - Haar-like features (face detection)
236//! - Adaptive thresholding
237//!
238//! ## Frequency Domain Processing
239//!
240//! Using Fourier transforms for global operations:
241//! ```text
242//! I_filtered = ℱ⁻¹(ℱ(I) · H)
243//! ```
244//! where H is the frequency domain filter.
245//!
246//! **Advantages**:
247//! - O(n log n) complexity for large kernels (via FFT)
248//! - Ideal for global operations (deconvolution, frequency-based filtering)
249//!
250//! ## Bilateral Filtering
251//!
252//! Edge-preserving smoothing:
253//! ```text
254//! BF(x) = (1/W) Σₚ G_σₛ(‖p-x‖) · G_σᵣ(|I(p)-I(x)|) · I(p)
255//! ```
256//! where:
257//! - G_σₛ: Spatial Gaussian (distance weight)
258//! - G_σᵣ: Range Gaussian (intensity similarity weight)
259//!
260//! **Properties**: Smooths while preserving edges
261//!
262//! ## Non-Maximum Suppression
263//!
264//! For edge thinning in edge detection:
265//! ```text
266//! Keep pixel if it's local maximum along gradient direction
267//! ```
268//! Essential step in Canny edge detection.
269//!
270//! # Computer Vision Applications
271//!
272//! ## Image Classification Preprocessing
273//! 1. Resize to fixed size (224×224 for ImageNet)
274//! 2. Normalize: μ=0, σ=1 per channel
275//! 3. Data augmentation: random crops, flips, color jitter
276//!
277//! ## Object Detection Preprocessing
278//! 1. Multi-scale processing (image pyramids)
279//! 2. Aspect ratio preservation with padding
280//! 3. Anchor-based or anchor-free coordinate systems
281//!
282//! ## Semantic Segmentation
283//! 1. High-resolution input preservation
284//! 2. Multi-scale feature extraction
285//! 3. Skip connections for fine details
286//!
287//! ## Style Transfer
288//! 1. Content loss: Feature matching in conv layers
289//! 2. Style loss: Gram matrix matching
290//! 3. Color space considerations (RGB vs LAB)
291//!
292//! # Best Practices
293//!
294//! 1. **Interpolation Selection**:
295//! - Nearest: Masks, labels, pixel art, segmentation maps
296//! - Bilinear: General purpose, good balance of speed/quality
297//! - Bicubic: High quality, when quality matters more than speed
298//! - Lanczos: Highest quality, slowest, professional graphics
299//!
300//! 2. **Anti-aliasing**:
301//! - Always enable when downsampling > 2× to avoid Moiré patterns
302//! - Use Gaussian pre-filtering for high-quality downsampling
303//!
304//! 3. **Color Space Selection**:
305//! - RGB: Display, color manipulation, neural network input
306//! - HSV: Hue/saturation adjustments, color-based segmentation
307//! - LAB: Perceptually uniform, better for style transfer
308//! - Grayscale: Edge detection, classical CV algorithms, faster processing
309//!
310//! 4. **Memory Efficiency**:
311//! - Process in batches for large datasets
312//! - Use in-place operations where possible
313//! - Consider image pyramids for multi-scale processing
314//!
315//! 5. **Numerical Stability**:
316//! - Normalize pixel values to [0, 1] or [-1, 1]
317//! - Use double precision for accumulation in filters
318//! - Clamp outputs to valid range after operations
319//!
320//! 6. **Padding Strategies**:
321//! - Zero padding: Fast, but introduces boundary artifacts
322//! - Reflection: Good for natural images
323//! - Replication: Reduces boundary artifacts
324//! - Circular: For periodic patterns
325//!
326//! 7. **Performance Optimization**:
327//! - Use separable filters when possible
328//! - Exploit SIMD for element-wise operations
329//! - Pre-compute lookup tables for repeated operations
330//! - Cache-friendly memory access patterns
331
332use torsh_core::{Result as TorshResult, TorshError};
333use torsh_tensor::Tensor;
334
335/// Resize tensor using different interpolation methods
336pub fn resize(
337 input: &Tensor,
338 size: (usize, usize),
339 mode: InterpolationMode,
340 antialias: bool,
341) -> TorshResult<Tensor> {
342 let shape = input.shape();
343 if shape.ndim() < 3 {
344 return Err(TorshError::invalid_argument_with_context(
345 "Input tensor must have at least 3 dimensions (C, H, W)",
346 "resize",
347 ));
348 }
349
350 let dims = shape.dims();
351 let channels = dims[dims.len() - 3];
352 let in_height = dims[dims.len() - 2];
353 let in_width = dims[dims.len() - 1];
354 let (out_height, out_width) = size;
355
356 // Handle batch dimensions
357 let batch_dims: Vec<usize> = dims[..dims.len() - 3].to_vec();
358 let batch_size = batch_dims.iter().product::<usize>();
359
360 let input_data = input.to_vec()?;
361 let mut output_data = vec![0.0f32; batch_size * channels * out_height * out_width];
362
363 let scale_h = in_height as f32 / out_height as f32;
364 let scale_w = in_width as f32 / out_width as f32;
365
366 match mode {
367 InterpolationMode::Nearest => {
368 for b in 0..batch_size {
369 for c in 0..channels {
370 for oh in 0..out_height {
371 for ow in 0..out_width {
372 let ih = (oh as f32 * scale_h).round() as usize;
373 let iw = (ow as f32 * scale_w).round() as usize;
374
375 let ih = ih.min(in_height - 1);
376 let iw = iw.min(in_width - 1);
377
378 let in_idx = ((b * channels + c) * in_height + ih) * in_width + iw;
379 let out_idx = ((b * channels + c) * out_height + oh) * out_width + ow;
380
381 output_data[out_idx] = input_data[in_idx];
382 }
383 }
384 }
385 }
386 }
387 InterpolationMode::Bilinear => {
388 for b in 0..batch_size {
389 for c in 0..channels {
390 for oh in 0..out_height {
391 for ow in 0..out_width {
392 let fh = (oh as f32 + 0.5) * scale_h - 0.5;
393 let fw = (ow as f32 + 0.5) * scale_w - 0.5;
394
395 let ih_low = fh.floor() as i32;
396 let iw_low = fw.floor() as i32;
397 let _ih_high = ih_low + 1;
398 let _iw_high = iw_low + 1;
399
400 let wh = fh - ih_low as f32;
401 let ww = fw - iw_low as f32;
402
403 let mut value = 0.0f32;
404
405 // Bilinear interpolation
406 for dh in 0..2 {
407 for dw in 0..2 {
408 let ih = ih_low + dh;
409 let iw = iw_low + dw;
410
411 if ih >= 0
412 && ih < in_height as i32
413 && iw >= 0
414 && iw < in_width as i32
415 {
416 let weight = if dh == 0 { 1.0 - wh } else { wh }
417 * if dw == 0 { 1.0 - ww } else { ww };
418
419 let in_idx = ((b * channels + c) * in_height + ih as usize)
420 * in_width
421 + iw as usize;
422 value += weight * input_data[in_idx];
423 }
424 }
425 }
426
427 let out_idx = ((b * channels + c) * out_height + oh) * out_width + ow;
428 output_data[out_idx] = value;
429 }
430 }
431 }
432 }
433 }
434 InterpolationMode::Bicubic | InterpolationMode::Area => {
435 // Simplified implementation - use bilinear for now
436 return resize(input, size, InterpolationMode::Bilinear, antialias);
437 }
438 }
439
440 let mut output_shape = batch_dims;
441 output_shape.extend_from_slice(&[channels, out_height, out_width]);
442
443 Tensor::from_data(output_data, output_shape, input.device())
444}
445
/// Interpolation modes for resizing
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InterpolationMode {
    /// Nearest-neighbor sampling: fastest, blocky output.
    Nearest,
    /// Bilinear interpolation over a 2x2 neighborhood: good speed/quality balance.
    Bilinear,
    /// Bicubic interpolation (currently falls back to bilinear in `resize`).
    Bicubic,
    /// Area (averaging) resampling (currently falls back to bilinear in `resize`).
    Area,
}
454
455/// Apply Gaussian blur to image tensor
456pub fn gaussian_blur(input: &Tensor, kernel_size: usize, sigma: f32) -> TorshResult<Tensor> {
457 let shape = input.shape();
458 if shape.ndim() < 3 {
459 return Err(TorshError::invalid_argument_with_context(
460 "Input tensor must have at least 3 dimensions (C, H, W)",
461 "gaussian_blur",
462 ));
463 }
464
465 // Create Gaussian kernel
466 let radius = kernel_size / 2;
467 let mut kernel = vec![0.0f32; kernel_size * kernel_size];
468 let mut sum = 0.0f32;
469
470 for i in 0..kernel_size {
471 for j in 0..kernel_size {
472 let x = i as i32 - radius as i32;
473 let y = j as i32 - radius as i32;
474 let val = (-((x * x + y * y) as f32) / (2.0 * sigma * sigma)).exp();
475 kernel[i * kernel_size + j] = val;
476 sum += val;
477 }
478 }
479
480 // Normalize kernel
481 for val in &mut kernel {
482 *val /= sum;
483 }
484
485 // Apply convolution with the Gaussian kernel
486 apply_convolution(input, &kernel, kernel_size, 1, radius)
487}
488
489/// Apply Sobel edge detection
490pub fn sobel_filter(input: &Tensor, direction: SobelDirection) -> TorshResult<Tensor> {
491 let kernel = match direction {
492 SobelDirection::X => vec![-1.0, 0.0, 1.0, -2.0, 0.0, 2.0, -1.0, 0.0, 1.0],
493 SobelDirection::Y => vec![-1.0, -2.0, -1.0, 0.0, 0.0, 0.0, 1.0, 2.0, 1.0],
494 SobelDirection::Both => {
495 // For both directions, compute magnitude
496 let x_result = sobel_filter(input, SobelDirection::X)?;
497 let y_result = sobel_filter(input, SobelDirection::Y)?;
498 return compute_gradient_magnitude(&x_result, &y_result);
499 }
500 };
501
502 apply_convolution(input, &kernel, 3, 1, 1)
503}
504
/// Sobel filter directions
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SobelDirection {
    /// Horizontal gradient G_x (responds to vertical edges).
    X,
    /// Vertical gradient G_y (responds to horizontal edges).
    Y,
    /// Gradient magnitude sqrt(G_x^2 + G_y^2) combining both axes.
    Both,
}
512
513/// Apply Laplacian filter for edge detection
514pub fn laplacian_filter(input: &Tensor) -> TorshResult<Tensor> {
515 let kernel = vec![0.0, -1.0, 0.0, -1.0, 4.0, -1.0, 0.0, -1.0, 0.0];
516
517 apply_convolution(input, &kernel, 3, 1, 1)
518}
519
520/// Apply morphological erosion
521pub fn erosion(input: &Tensor, kernel_size: usize, iterations: usize) -> TorshResult<Tensor> {
522 let mut result = input.clone();
523
524 for _ in 0..iterations {
525 result = apply_morphological_op(&result, kernel_size, MorphOp::Erosion)?;
526 }
527
528 Ok(result)
529}
530
531/// Apply morphological dilation
532pub fn dilation(input: &Tensor, kernel_size: usize, iterations: usize) -> TorshResult<Tensor> {
533 let mut result = input.clone();
534
535 for _ in 0..iterations {
536 result = apply_morphological_op(&result, kernel_size, MorphOp::Dilation)?;
537 }
538
539 Ok(result)
540}
541
542/// Apply morphological opening (erosion followed by dilation)
543pub fn opening(input: &Tensor, kernel_size: usize) -> TorshResult<Tensor> {
544 let eroded = erosion(input, kernel_size, 1)?;
545 dilation(&eroded, kernel_size, 1)
546}
547
548/// Apply morphological closing (dilation followed by erosion)
549pub fn closing(input: &Tensor, kernel_size: usize) -> TorshResult<Tensor> {
550 let dilated = dilation(input, kernel_size, 1)?;
551 erosion(&dilated, kernel_size, 1)
552}
553
554/// Convert RGB to HSV color space
555pub fn rgb_to_hsv(input: &Tensor) -> TorshResult<Tensor> {
556 let shape = input.shape();
557 if shape.ndim() < 3 {
558 return Err(TorshError::invalid_argument_with_context(
559 "Input tensor must have at least 3 dimensions (C, H, W)",
560 "rgb_to_hsv",
561 ));
562 }
563
564 let dims = shape.dims();
565 if dims[dims.len() - 3] != 3 {
566 return Err(TorshError::invalid_argument_with_context(
567 "Input tensor must have 3 channels for RGB",
568 "rgb_to_hsv",
569 ));
570 }
571
572 let input_data = input.to_vec()?;
573 let mut output_data = vec![0.0f32; input_data.len()];
574
575 let batch_size = dims[..dims.len() - 3].iter().product::<usize>();
576 let height = dims[dims.len() - 2];
577 let width = dims[dims.len() - 1];
578
579 for b in 0..batch_size {
580 for h in 0..height {
581 for w in 0..width {
582 let r_idx = ((b * 3 + 0) * height + h) * width + w;
583 let g_idx = ((b * 3 + 1) * height + h) * width + w;
584 let b_idx = ((b * 3 + 2) * height + h) * width + w;
585
586 let r = input_data[r_idx];
587 let g = input_data[g_idx];
588 let b_val = input_data[b_idx];
589
590 let max_val = r.max(g).max(b_val);
591 let min_val = r.min(g).min(b_val);
592 let delta = max_val - min_val;
593
594 // Value
595 let v = max_val;
596
597 // Saturation
598 let s = if max_val == 0.0 { 0.0 } else { delta / max_val };
599
600 // Hue
601 let h_val = if delta == 0.0 {
602 0.0
603 } else if max_val == r {
604 60.0 * (((g - b_val) / delta) % 6.0)
605 } else if max_val == g {
606 60.0 * ((b_val - r) / delta + 2.0)
607 } else {
608 60.0 * ((r - g) / delta + 4.0)
609 };
610
611 output_data[r_idx] = h_val / 360.0; // Normalize hue to [0, 1]
612 output_data[g_idx] = s;
613 output_data[b_idx] = v;
614 }
615 }
616 }
617
618 Tensor::from_data(output_data, dims.to_vec(), input.device())
619}
620
/// Convert HSV to RGB color space
///
/// Expects `input` in `[..., 3, H, W]` layout with channel order `[H, S, V]`,
/// where hue is normalized to `[0, 1]` (as produced by `rgb_to_hsv`) and
/// saturation/value are in `[0, 1]`.
///
/// # Errors
/// Returns an error if `input` has fewer than 3 dimensions or does not have
/// exactly 3 channels.
pub fn hsv_to_rgb(input: &Tensor) -> TorshResult<Tensor> {
    let shape = input.shape();
    if shape.ndim() < 3 {
        return Err(TorshError::invalid_argument_with_context(
            "Input tensor must have at least 3 dimensions (C, H, W)",
            "hsv_to_rgb",
        ));
    }

    let dims = shape.dims();
    if dims[dims.len() - 3] != 3 {
        return Err(TorshError::invalid_argument_with_context(
            "Input tensor must have 3 channels for HSV",
            "hsv_to_rgb",
        ));
    }

    let input_data = input.to_vec()?;
    let mut output_data = vec![0.0f32; input_data.len()];

    // Leading dimensions (if any) are treated as a flattened batch.
    let batch_size = dims[..dims.len() - 3].iter().product::<usize>();
    let height = dims[dims.len() - 2];
    let width = dims[dims.len() - 1];

    for b in 0..batch_size {
        for h in 0..height {
            for w in 0..width {
                // Channel-major indices for the H, S, V planes of this pixel.
                let h_idx = ((b * 3 + 0) * height + h) * width + w;
                let s_idx = ((b * 3 + 1) * height + h) * width + w;
                let v_idx = ((b * 3 + 2) * height + h) * width + w;

                let h_val = input_data[h_idx] * 360.0; // Denormalize hue
                let s = input_data[s_idx];
                let v = input_data[v_idx];

                // Standard HSV -> RGB decomposition: chroma `c`, intermediate
                // component `x`, and additive offset `m`.
                let c = v * s;
                let x = c * (1.0 - ((h_val / 60.0) % 2.0 - 1.0).abs());
                let m = v - c;

                // Select the (R', G', B') triple for the 60-degree hue sector.
                // NOTE(review): hue outside [0, 360) lands in the first or last
                // branch and yields incorrect colors — callers are expected to
                // supply normalized hue; confirm upstream guarantees this.
                let (r_prime, g_prime, b_prime) = if h_val < 60.0 {
                    (c, x, 0.0)
                } else if h_val < 120.0 {
                    (x, c, 0.0)
                } else if h_val < 180.0 {
                    (0.0, c, x)
                } else if h_val < 240.0 {
                    (0.0, x, c)
                } else if h_val < 300.0 {
                    (x, 0.0, c)
                } else {
                    (c, 0.0, x)
                };

                // Output reuses the same channel slots, now as R, G, B.
                output_data[h_idx] = r_prime + m;
                output_data[s_idx] = g_prime + m;
                output_data[v_idx] = b_prime + m;
            }
        }
    }

    Tensor::from_data(output_data, dims.to_vec(), input.device())
}
684
/// Apply affine transformation to image
///
/// Uses backward mapping: for each output pixel, the inverse of `matrix`
/// locates the corresponding source coordinate, which is then sampled with
/// bilinear interpolation. Output pixels whose source coordinate falls
/// outside the input remain at `fill_value`.
///
/// # Arguments
/// * `input` - Image tensor with at least 3 dimensions (C, H, W)
/// * `matrix` - Row-major 2x3 affine matrix (see inline comment)
/// * `output_size` - Optional `(height, width)`; defaults to the input size
/// * `fill_value` - Value used where the transform has no source pixel
///
/// # Errors
/// Returns an error if `input` has fewer than 3 dimensions or if the linear
/// part of `matrix` is (near-)singular (|det| < 1e-6) and cannot be inverted.
pub fn affine_transform(
    input: &Tensor,
    matrix: &[f32; 6], // [a, b, c, d, e, f] for transformation [[a, b, c], [d, e, f]]
    output_size: Option<(usize, usize)>,
    fill_value: f32,
) -> TorshResult<Tensor> {
    let shape = input.shape();
    if shape.ndim() < 3 {
        return Err(TorshError::invalid_argument_with_context(
            "Input tensor must have at least 3 dimensions (C, H, W)",
            "affine_transform",
        ));
    }

    let dims = shape.dims();
    let channels = dims[dims.len() - 3];
    let in_height = dims[dims.len() - 2];
    let in_width = dims[dims.len() - 1];

    let (out_height, out_width) = output_size.unwrap_or((in_height, in_width));

    // Compute inverse transformation matrix for backward mapping
    // (`c` and `f` are the translation terms; note that `c` is later
    // shadowed by the channel loop variable, after it is last used here).
    let [a, b, c, d, e, f] = *matrix;
    let det = a * e - b * d;

    if det.abs() < 1e-6 {
        return Err(TorshError::invalid_argument_with_context(
            "Affine transformation matrix is singular",
            "affine_transform",
        ));
    }

    // Closed-form inverse of the 2x3 affine matrix (inverse of the 2x2
    // linear part, then the transformed negative translation).
    let inv_det = 1.0 / det;
    let inv_a = e * inv_det;
    let inv_b = -b * inv_det;
    let inv_c = (b * f - c * e) * inv_det;
    let inv_d = -d * inv_det;
    let inv_e = a * inv_det;
    let inv_f = (c * d - a * f) * inv_det;

    // Leading dimensions (if any) are treated as a flattened batch.
    let batch_dims: Vec<usize> = dims[..dims.len() - 3].to_vec();
    let batch_size = batch_dims.iter().product::<usize>();

    let input_data = input.to_vec()?;
    // Pre-fill with `fill_value`; only pixels with a valid source are written.
    let mut output_data = vec![fill_value; batch_size * channels * out_height * out_width];

    for b in 0..batch_size {
        for c in 0..channels {
            for oh in 0..out_height {
                for ow in 0..out_width {
                    // Apply inverse transformation
                    let x_out = ow as f32;
                    let y_out = oh as f32;

                    let x_in = inv_a * x_out + inv_b * y_out + inv_c;
                    let y_in = inv_d * x_out + inv_e * y_out + inv_f;

                    // Bilinear interpolation
                    // NOTE(review): the strict `< size - 1` bounds mean source
                    // coordinates in the last row/column are left as
                    // fill_value rather than interpolated — confirm intended.
                    if x_in >= 0.0
                        && x_in < in_width as f32 - 1.0
                        && y_in >= 0.0
                        && y_in < in_height as f32 - 1.0
                    {
                        let x0 = x_in.floor() as usize;
                        let y0 = y_in.floor() as usize;
                        let x1 = x0 + 1;
                        let y1 = y0 + 1;

                        // Fractional offsets = interpolation weights.
                        let wx = x_in - x0 as f32;
                        let wy = y_in - y0 as f32;

                        // Flat indices of the 2x2 neighborhood corners.
                        let idx00 = ((b * channels + c) * in_height + y0) * in_width + x0;
                        let idx01 = ((b * channels + c) * in_height + y0) * in_width + x1;
                        let idx10 = ((b * channels + c) * in_height + y1) * in_width + x0;
                        let idx11 = ((b * channels + c) * in_height + y1) * in_width + x1;

                        let val = (1.0 - wx) * (1.0 - wy) * input_data[idx00]
                            + wx * (1.0 - wy) * input_data[idx01]
                            + (1.0 - wx) * wy * input_data[idx10]
                            + wx * wy * input_data[idx11];

                        let out_idx = ((b * channels + c) * out_height + oh) * out_width + ow;
                        output_data[out_idx] = val;
                    }
                }
            }
        }
    }

    let mut output_shape = batch_dims;
    output_shape.extend_from_slice(&[channels, out_height, out_width]);

    Tensor::from_data(output_data, output_shape, input.device())
}
780
781// Helper functions
782
/// Apply convolution with a given kernel
///
/// Performs a per-channel 2D cross-correlation (the kernel is applied
/// without flipping, which is equivalent to convolution only for symmetric
/// kernels) with implicit zero padding of `padding` pixels on each side.
/// Output spatial size per dimension is
/// `(in + 2 * padding - kernel_size) / stride + 1`.
fn apply_convolution(
    input: &Tensor,
    kernel: &[f32],
    kernel_size: usize,
    stride: usize,
    padding: usize,
) -> TorshResult<Tensor> {
    let shape = input.shape();
    let dims = shape.dims();

    // Leading dimensions (if any) are treated as a flattened batch.
    let batch_dims: Vec<usize> = dims[..dims.len() - 3].to_vec();
    let batch_size = batch_dims.iter().product::<usize>();
    let channels = dims[dims.len() - 3];
    let in_height = dims[dims.len() - 2];
    let in_width = dims[dims.len() - 1];

    let out_height = (in_height + 2 * padding - kernel_size) / stride + 1;
    let out_width = (in_width + 2 * padding - kernel_size) / stride + 1;

    let input_data = input.to_vec()?;
    let mut output_data = vec![0.0f32; batch_size * channels * out_height * out_width];

    for b in 0..batch_size {
        for c in 0..channels {
            for oh in 0..out_height {
                for ow in 0..out_width {
                    let mut sum = 0.0f32;

                    for kh in 0..kernel_size {
                        for kw in 0..kernel_size {
                            // (ih, iw) are coordinates in the padded input.
                            let ih = oh * stride + kh;
                            let iw = ow * stride + kw;

                            // Skip taps that fall in the padding band; their
                            // contribution is zero (zero padding).
                            if ih >= padding
                                && ih < in_height + padding
                                && iw >= padding
                                && iw < in_width + padding
                            {
                                // Translate back to unpadded input coordinates.
                                let real_ih = ih - padding;
                                let real_iw = iw - padding;

                                if real_ih < in_height && real_iw < in_width {
                                    let in_idx = ((b * channels + c) * in_height + real_ih)
                                        * in_width
                                        + real_iw;
                                    let kernel_idx = kh * kernel_size + kw;
                                    sum += input_data[in_idx] * kernel[kernel_idx];
                                }
                            }
                        }
                    }

                    let out_idx = ((b * channels + c) * out_height + oh) * out_width + ow;
                    output_data[out_idx] = sum;
                }
            }
        }
    }

    let mut output_shape = batch_dims;
    output_shape.extend_from_slice(&[channels, out_height, out_width]);

    Tensor::from_data(output_data, output_shape, input.device())
}
848
/// Morphological operation types
///
/// Selects whether `apply_morphological_op` takes the neighborhood minimum
/// (erosion) or maximum (dilation).
#[derive(Debug, Clone, Copy)]
enum MorphOp {
    /// Neighborhood minimum: shrinks bright regions.
    Erosion,
    /// Neighborhood maximum: expands bright regions.
    Dilation,
}
855
/// Apply morphological operation
///
/// Slides a square `kernel_size x kernel_size` window centered on each pixel
/// and replaces the pixel with the window's minimum (erosion) or maximum
/// (dilation). Window positions outside the image are skipped, so the output
/// keeps the input's spatial size.
fn apply_morphological_op(input: &Tensor, kernel_size: usize, op: MorphOp) -> TorshResult<Tensor> {
    let shape = input.shape();
    let dims = shape.dims();

    // Leading dimensions (if any) are treated as a flattened batch.
    let batch_dims: Vec<usize> = dims[..dims.len() - 3].to_vec();
    let batch_size = batch_dims.iter().product::<usize>();
    let channels = dims[dims.len() - 3];
    let height = dims[dims.len() - 2];
    let width = dims[dims.len() - 1];

    let radius = kernel_size / 2;
    let input_data = input.to_vec()?;
    let mut output_data = vec![0.0f32; input_data.len()];

    for b in 0..batch_size {
        for c in 0..channels {
            for h in 0..height {
                for w in 0..width {
                    // Identity element of the reduction: +inf for min (erosion),
                    // -inf for max (dilation).
                    let mut result = match op {
                        MorphOp::Erosion => f32::INFINITY,
                        MorphOp::Dilation => f32::NEG_INFINITY,
                    };

                    for kh in 0..kernel_size {
                        for kw in 0..kernel_size {
                            // Neighbor coordinate, window centered on (h, w).
                            let ih = h as i32 + kh as i32 - radius as i32;
                            let iw = w as i32 + kw as i32 - radius as i32;

                            if ih >= 0 && ih < height as i32 && iw >= 0 && iw < width as i32 {
                                let in_idx = ((b * channels + c) * height + ih as usize) * width
                                    + iw as usize;
                                let val = input_data[in_idx];

                                match op {
                                    MorphOp::Erosion => result = result.min(val),
                                    MorphOp::Dilation => result = result.max(val),
                                }
                            }
                        }
                    }

                    let out_idx = ((b * channels + c) * height + h) * width + w;
                    output_data[out_idx] = result;
                }
            }
        }
    }

    Tensor::from_data(output_data, dims.to_vec(), input.device())
}
907
908/// Compute gradient magnitude from X and Y gradients
909fn compute_gradient_magnitude(grad_x: &Tensor, grad_y: &Tensor) -> TorshResult<Tensor> {
910 let grad_x_data = grad_x.to_vec()?;
911 let grad_y_data = grad_y.to_vec()?;
912
913 let magnitude_data: Vec<f32> = grad_x_data
914 .iter()
915 .zip(grad_y_data.iter())
916 .map(|(&gx, &gy)| (gx * gx + gy * gy).sqrt())
917 .collect();
918
919 Tensor::from_data(
920 magnitude_data,
921 grad_x.shape().dims().to_vec(),
922 grad_x.device(),
923 )
924}
925
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::ones;

    // Shape-level smoke test: nearest-neighbor resize halves H and W.
    #[test]
    fn test_resize_nearest() {
        let input = ones(&[1, 3, 4, 4]).unwrap(); // Batch=1, Channels=3, Height=4, Width=4
        let result = resize(&input, (2, 2), InterpolationMode::Nearest, false).unwrap();
        assert_eq!(result.shape().dims(), &[1, 3, 2, 2]);
    }

    // A 3x3 blur with padding = 1 must preserve spatial dimensions.
    #[test]
    fn test_gaussian_blur() {
        let input = ones(&[1, 3, 5, 5]).unwrap();
        let result = gaussian_blur(&input, 3, 1.0).unwrap();
        assert_eq!(result.shape().dims(), &[1, 3, 5, 5]); // Size maintained due to padding
    }

    // Sobel uses a fixed 3x3 kernel with padding = 1, so size is preserved.
    #[test]
    fn test_sobel_filter() {
        let input = ones(&[1, 1, 5, 5]).unwrap();
        let result = sobel_filter(&input, SobelDirection::X).unwrap();
        assert_eq!(result.shape().dims(), &[1, 1, 5, 5]); // Size maintained due to padding
    }

    // Round-trip RGB -> HSV -> RGB must preserve the tensor shape.
    #[test]
    fn test_rgb_to_hsv_conversion() {
        let input = ones(&[1, 3, 2, 2]).unwrap(); // RGB image
        let hsv = rgb_to_hsv(&input).unwrap();
        let rgb_back = hsv_to_rgb(&hsv).unwrap();
        assert_eq!(rgb_back.shape().dims(), &[1, 3, 2, 2]);
    }

    // Erosion and dilation operate in place spatially (same output size).
    #[test]
    fn test_morphological_operations() {
        let input = ones(&[1, 1, 5, 5]).unwrap();
        let eroded = erosion(&input, 3, 1).unwrap();
        let dilated = dilation(&input, 3, 1).unwrap();
        assert_eq!(eroded.shape().dims(), &[1, 1, 5, 5]);
        assert_eq!(dilated.shape().dims(), &[1, 1, 5, 5]);
    }
}
968}