sklears_preprocessing/
image_preprocessing.rs

//! Image Preprocessing for Computer Vision Applications
//!
//! This module provides image preprocessing utilities for computer vision and machine learning
//! tasks, including normalization, data augmentation, color space transformations, and feature
//! extraction.
//!
//! # Features
//!
//! - Image normalization and standardization
//! - Data augmentation techniques (rotation, scaling, flipping, cropping)
//! - Color space transformations (RGB, HSV, LAB, grayscale)
//! - Image resizing and cropping
//! - Edge detection and feature extraction
//! - Batch image processing with parallel support
//!
//! # Examples
//!
//! ```rust,ignore
//! use sklears_core::traits::{Fit, Transform};
//! use sklears_preprocessing::image_preprocessing::{
//!     ColorSpaceTransformer, ImageAugmenter, ImageNormalizer,
//! };
//! use scirs2_core::ndarray::Array3;
//!
//! fn example() -> Result<(), Box<dyn std::error::Error>> {
//!     let image = Array3::<f64>::zeros((224, 224, 3)); // RGB image: (height, width, channel)
//!
//!     // Learn per-channel statistics, then rescale pixel values to the [0, 1] range
//!     let normalizer = ImageNormalizer::new()
//!         .with_range((0.0, 1.0))
//!         .with_channel_wise(true)
//!         .fit(&image, &())?;
//!     let normalized = normalizer.transform(&image)?;
//!
//!     // Apply data augmentation
//!     let augmenter = ImageAugmenter::new()
//!         .with_rotation_range((-15.0, 15.0))
//!         .with_zoom_range((0.9, 1.1))
//!         .with_horizontal_flip(true);
//!
//!     let augmented = augmenter.transform(&normalized)?;
//!
//!     // Convert color space
//!     let color_transformer = ColorSpaceTransformer::new()
//!         .from_rgb()
//!         .to_hsv();
//!
//!     let hsv_image = color_transformer.transform(&augmented)?;
//!
//!     Ok(())
//! }
//! ```

52use scirs2_core::ndarray::{Array3, Axis};
53use scirs2_core::random::thread_rng;
54// Note: using fallback implementations since SIMD functions may not be available
55// use scirs2_core::simd_ops::{mean_f64_simd, variance_f64_simd};
56use sklears_core::{
57    error::{Result, SklearsError},
58    traits::{Estimator, Fit, Transform, Untrained},
59    types::Float,
60};
61use std::f64::consts::PI;
62
63#[cfg(feature = "serde")]
64use serde::{Deserialize, Serialize};
65
66/// Image normalization strategies
67#[derive(Debug, Clone, Copy, PartialEq)]
68#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
69pub enum ImageNormalizationStrategy {
70    /// Normalize to [0, 1] range using min-max scaling
71    MinMax,
72    /// Standardize to zero mean and unit variance
73    StandardScore,
74    /// Custom range normalization
75    CustomRange(Float, Float),
76}
77
78impl Default for ImageNormalizationStrategy {
79    fn default() -> Self {
80        Self::MinMax
81    }
82}
83
/// Colour spaces understood by the colour-space transformer.
///
/// Images are 3-channel for the colour spaces and 1-channel for `Grayscale`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ColorSpace {
    /// Red, green and blue channels.
    RGB,
    /// Hue, saturation and value channels.
    HSV,
    /// Lightness plus the `a` and `b` colour-opponent channels.
    LAB,
    /// A single intensity channel.
    Grayscale,
}
97
/// Sampling kernel used when an image is resized.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum InterpolationMethod {
    /// Copy the closest source pixel: cheapest, produces blocky results.
    Nearest,
    /// Blend the 2x2 neighbouring source pixels: a good speed/quality trade-off.
    Bilinear,
    /// Cubic blend of a larger neighbourhood: smoothest, and the slowest.
    Bicubic,
}
109
110impl Default for InterpolationMethod {
111    fn default() -> Self {
112        Self::Bilinear
113    }
114}
115
116/// Configuration for image normalization
117#[derive(Debug, Clone)]
118#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
119pub struct ImageNormalizerConfig {
120    /// Normalization strategy to use
121    pub strategy: ImageNormalizationStrategy,
122    /// Whether to normalize each channel independently
123    pub channel_wise: bool,
124    /// Epsilon for numerical stability
125    pub epsilon: Float,
126}
127
128impl Default for ImageNormalizerConfig {
129    fn default() -> Self {
130        Self {
131            strategy: ImageNormalizationStrategy::MinMax,
132            channel_wise: true,
133            epsilon: 1e-8,
134        }
135    }
136}
137
138/// Image normalizer for preprocessing image data
139#[derive(Debug, Clone)]
140#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
141pub struct ImageNormalizer<State = Untrained> {
142    config: ImageNormalizerConfig,
143    state: std::marker::PhantomData<State>,
144}
145
146/// Fitted state for image normalizer
147#[derive(Debug, Clone)]
148#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
149pub struct ImageNormalizerFitted {
150    config: ImageNormalizerConfig,
151    min_vals: Vec<Float>,
152    max_vals: Vec<Float>,
153    mean_vals: Vec<Float>,
154    std_vals: Vec<Float>,
155}
156
157impl ImageNormalizer<Untrained> {
158    /// Create a new image normalizer
159    pub fn new() -> Self {
160        Self {
161            config: ImageNormalizerConfig::default(),
162            state: std::marker::PhantomData,
163        }
164    }
165
166    /// Set normalization strategy
167    pub fn with_strategy(mut self, strategy: ImageNormalizationStrategy) -> Self {
168        self.config.strategy = strategy;
169        self
170    }
171
172    /// Set range for min-max normalization
173    pub fn with_range(mut self, range: (Float, Float)) -> Self {
174        self.config.strategy = ImageNormalizationStrategy::CustomRange(range.0, range.1);
175        self
176    }
177
178    /// Enable/disable channel-wise normalization
179    pub fn with_channel_wise(mut self, channel_wise: bool) -> Self {
180        self.config.channel_wise = channel_wise;
181        self
182    }
183
184    /// Set epsilon for numerical stability
185    pub fn with_epsilon(mut self, epsilon: Float) -> Self {
186        self.config.epsilon = epsilon;
187        self
188    }
189}
190
191impl Estimator for ImageNormalizer<Untrained> {
192    type Config = ImageNormalizerConfig;
193    type Error = SklearsError;
194    type Float = Float;
195
196    fn config(&self) -> &Self::Config {
197        &self.config
198    }
199}
200
201impl Fit<Array3<Float>, ()> for ImageNormalizer<Untrained> {
202    type Fitted = ImageNormalizerFitted;
203
204    fn fit(self, x: &Array3<Float>, _y: &()) -> Result<Self::Fitted> {
205        let shape = x.dim();
206        let n_channels = shape.2;
207
208        let (min_vals, max_vals, mean_vals, std_vals) = if self.config.channel_wise {
209            let mut min_vals = Vec::with_capacity(n_channels);
210            let mut max_vals = Vec::with_capacity(n_channels);
211            let mut mean_vals = Vec::with_capacity(n_channels);
212            let mut std_vals = Vec::with_capacity(n_channels);
213
214            for channel in 0..n_channels {
215                let channel_data = x.index_axis(Axis(2), channel);
216                let data_slice: Vec<Float> = channel_data.iter().copied().collect();
217
218                let min_val = data_slice.iter().fold(Float::INFINITY, |a, &b| a.min(b));
219                let max_val = data_slice
220                    .iter()
221                    .fold(Float::NEG_INFINITY, |a, &b| a.max(b));
222
223                let mean_val = data_slice.iter().sum::<Float>() / data_slice.len() as Float;
224
225                let var_val = data_slice
226                    .iter()
227                    .map(|&x| (x - mean_val).powi(2))
228                    .sum::<Float>()
229                    / (data_slice.len() as Float - 1.0);
230
231                let std_val = var_val.sqrt().max(self.config.epsilon);
232
233                min_vals.push(min_val);
234                max_vals.push(max_val);
235                mean_vals.push(mean_val);
236                std_vals.push(std_val);
237            }
238
239            (min_vals, max_vals, mean_vals, std_vals)
240        } else {
241            // Global statistics across all channels
242            let all_data: Vec<Float> = x.iter().copied().collect();
243
244            let min_val = all_data.iter().fold(Float::INFINITY, |a, &b| a.min(b));
245            let max_val = all_data.iter().fold(Float::NEG_INFINITY, |a, &b| a.max(b));
246
247            let mean_val = all_data.iter().sum::<Float>() / all_data.len() as Float;
248
249            let var_val = all_data
250                .iter()
251                .map(|&x| (x - mean_val).powi(2))
252                .sum::<Float>()
253                / (all_data.len() as Float - 1.0);
254
255            let std_val = var_val.sqrt().max(self.config.epsilon);
256
257            (
258                vec![min_val; n_channels],
259                vec![max_val; n_channels],
260                vec![mean_val; n_channels],
261                vec![std_val; n_channels],
262            )
263        };
264
265        Ok(ImageNormalizerFitted {
266            config: self.config,
267            min_vals,
268            max_vals,
269            mean_vals,
270            std_vals,
271        })
272    }
273}
274
275impl Transform<Array3<Float>, Array3<Float>> for ImageNormalizerFitted {
276    fn transform(&self, x: &Array3<Float>) -> Result<Array3<Float>> {
277        let shape = x.dim();
278        let n_channels = shape.2;
279
280        if n_channels != self.min_vals.len() {
281            return Err(SklearsError::InvalidInput(format!(
282                "Expected {} channels, got {}",
283                self.min_vals.len(),
284                n_channels
285            )));
286        }
287
288        let mut result = x.clone();
289
290        match self.config.strategy {
291            ImageNormalizationStrategy::MinMax => {
292                for channel in 0..n_channels {
293                    let min_val = self.min_vals[channel];
294                    let max_val = self.max_vals[channel];
295                    let range = max_val - min_val;
296
297                    if range > self.config.epsilon {
298                        let mut channel_data = result.index_axis_mut(Axis(2), channel);
299                        channel_data.mapv_inplace(|x| (x - min_val) / range);
300                    }
301                }
302            }
303            ImageNormalizationStrategy::CustomRange(min_target, max_target) => {
304                let target_range = max_target - min_target;
305                for channel in 0..n_channels {
306                    let min_val = self.min_vals[channel];
307                    let max_val = self.max_vals[channel];
308                    let source_range = max_val - min_val;
309
310                    if source_range > self.config.epsilon {
311                        let mut channel_data = result.index_axis_mut(Axis(2), channel);
312                        channel_data.mapv_inplace(|x| {
313                            min_target + ((x - min_val) / source_range) * target_range
314                        });
315                    }
316                }
317            }
318            ImageNormalizationStrategy::StandardScore => {
319                for channel in 0..n_channels {
320                    let mean_val = self.mean_vals[channel];
321                    let std_val = self.std_vals[channel];
322
323                    let mut channel_data = result.index_axis_mut(Axis(2), channel);
324                    channel_data.mapv_inplace(|x| (x - mean_val) / std_val);
325                }
326            }
327        }
328
329        Ok(result)
330    }
331}
332
333/// Configuration for image augmentation
334#[derive(Debug, Clone)]
335#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
336pub struct ImageAugmenterConfig {
337    /// Rotation range in degrees (min, max)
338    pub rotation_range: Option<(Float, Float)>,
339    /// Zoom range as factors (min, max)
340    pub zoom_range: Option<(Float, Float)>,
341    /// Width shift range as fraction of total width
342    pub width_shift_range: Option<Float>,
343    /// Height shift range as fraction of total height
344    pub height_shift_range: Option<Float>,
345    /// Enable horizontal flipping
346    pub horizontal_flip: bool,
347    /// Enable vertical flipping
348    pub vertical_flip: bool,
349    /// Brightness adjustment range (min, max)
350    pub brightness_range: Option<(Float, Float)>,
351    /// Random seed for reproducibility
352    pub random_seed: Option<u64>,
353}
354
355impl Default for ImageAugmenterConfig {
356    fn default() -> Self {
357        Self {
358            rotation_range: None,
359            zoom_range: None,
360            width_shift_range: None,
361            height_shift_range: None,
362            horizontal_flip: false,
363            vertical_flip: false,
364            brightness_range: None,
365            random_seed: None,
366        }
367    }
368}
369
370/// Image augmenter for data augmentation
371#[derive(Debug, Clone)]
372#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
373pub struct ImageAugmenter {
374    config: ImageAugmenterConfig,
375}
376
377impl ImageAugmenter {
378    /// Create a new image augmenter
379    pub fn new() -> Self {
380        Self {
381            config: ImageAugmenterConfig::default(),
382        }
383    }
384
385    /// Set rotation range in degrees
386    pub fn with_rotation_range(mut self, range: (Float, Float)) -> Self {
387        self.config.rotation_range = Some(range);
388        self
389    }
390
391    /// Set zoom range as factors
392    pub fn with_zoom_range(mut self, range: (Float, Float)) -> Self {
393        self.config.zoom_range = Some(range);
394        self
395    }
396
397    /// Set width shift range as fraction
398    pub fn with_width_shift_range(mut self, range: Float) -> Self {
399        self.config.width_shift_range = Some(range);
400        self
401    }
402
403    /// Set height shift range as fraction
404    pub fn with_height_shift_range(mut self, range: Float) -> Self {
405        self.config.height_shift_range = Some(range);
406        self
407    }
408
409    /// Enable horizontal flipping
410    pub fn with_horizontal_flip(mut self, enabled: bool) -> Self {
411        self.config.horizontal_flip = enabled;
412        self
413    }
414
415    /// Enable vertical flipping
416    pub fn with_vertical_flip(mut self, enabled: bool) -> Self {
417        self.config.vertical_flip = enabled;
418        self
419    }
420
421    /// Set brightness adjustment range
422    pub fn with_brightness_range(mut self, range: (Float, Float)) -> Self {
423        self.config.brightness_range = Some(range);
424        self
425    }
426
427    /// Set random seed for reproducibility
428    pub fn with_random_seed(mut self, seed: u64) -> Self {
429        self.config.random_seed = Some(seed);
430        // Note: seed will be used by thread_rng() function if needed
431        self
432    }
433}
434
435impl Transform<Array3<Float>, Array3<Float>> for ImageAugmenter {
436    fn transform(&self, x: &Array3<Float>) -> Result<Array3<Float>> {
437        let mut result = x.clone();
438        let mut rng = thread_rng();
439
440        // Apply horizontal flip
441        if self.config.horizontal_flip && rng.random::<Float>() < 0.5 {
442            result = self.horizontal_flip(&result)?;
443        }
444
445        // Apply vertical flip
446        if self.config.vertical_flip && rng.random::<Float>() < 0.5 {
447            result = self.vertical_flip(&result)?;
448        }
449
450        // Apply rotation
451        if let Some((min_angle, max_angle)) = self.config.rotation_range {
452            let angle = rng.gen_range(min_angle..max_angle);
453            if angle.abs() > 1e-6 {
454                result = self.rotate(&result, angle)?;
455            }
456        }
457
458        // Apply brightness adjustment
459        if let Some((min_brightness, max_brightness)) = self.config.brightness_range {
460            let brightness_factor = rng.gen_range(min_brightness..max_brightness);
461            if (brightness_factor - 1.0).abs() > 1e-6 {
462                result.mapv_inplace(|x| (x * brightness_factor).clamp(0.0, 1.0));
463            }
464        }
465
466        Ok(result)
467    }
468}
469
470impl ImageAugmenter {
471    /// Apply horizontal flip to image
472    fn horizontal_flip(&self, image: &Array3<Float>) -> Result<Array3<Float>> {
473        let shape = image.dim();
474        let mut result = Array3::zeros(shape);
475
476        for row in 0..shape.0 {
477            for col in 0..shape.1 {
478                for channel in 0..shape.2 {
479                    result[[row, shape.1 - 1 - col, channel]] = image[[row, col, channel]];
480                }
481            }
482        }
483
484        Ok(result)
485    }
486
487    /// Apply vertical flip to image
488    fn vertical_flip(&self, image: &Array3<Float>) -> Result<Array3<Float>> {
489        let shape = image.dim();
490        let mut result = Array3::zeros(shape);
491
492        for row in 0..shape.0 {
493            for col in 0..shape.1 {
494                for channel in 0..shape.2 {
495                    result[[shape.0 - 1 - row, col, channel]] = image[[row, col, channel]];
496                }
497            }
498        }
499
500        Ok(result)
501    }
502
503    /// Apply rotation to image (simplified implementation)
504    fn rotate(&self, image: &Array3<Float>, angle_degrees: Float) -> Result<Array3<Float>> {
505        let shape = image.dim();
506        let mut result = Array3::zeros(shape);
507
508        let angle_rad = angle_degrees * PI / 180.0;
509        let cos_angle = angle_rad.cos();
510        let sin_angle = angle_rad.sin();
511
512        let center_x = shape.1 as Float / 2.0;
513        let center_y = shape.0 as Float / 2.0;
514
515        // Simple rotation with nearest neighbor interpolation
516        for row in 0..shape.0 {
517            for col in 0..shape.1 {
518                let x = col as Float - center_x;
519                let y = row as Float - center_y;
520
521                let rotated_x = x * cos_angle - y * sin_angle + center_x;
522                let rotated_y = x * sin_angle + y * cos_angle + center_y;
523
524                let src_col = rotated_x.round() as isize;
525                let src_row = rotated_y.round() as isize;
526
527                if src_row >= 0
528                    && src_row < shape.0 as isize
529                    && src_col >= 0
530                    && src_col < shape.1 as isize
531                {
532                    for channel in 0..shape.2 {
533                        result[[row, col, channel]] =
534                            image[[src_row as usize, src_col as usize, channel]];
535                    }
536                }
537            }
538        }
539
540        Ok(result)
541    }
542}
543
544/// Color space transformer
545#[derive(Debug, Clone)]
546#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
547pub struct ColorSpaceTransformer {
548    source: ColorSpace,
549    target: ColorSpace,
550}
551
552impl ColorSpaceTransformer {
553    /// Create a new color space transformer
554    pub fn new() -> Self {
555        Self {
556            source: ColorSpace::RGB,
557            target: ColorSpace::RGB,
558        }
559    }
560
561    /// Set source color space
562    pub fn from_colorspace(mut self, colorspace: ColorSpace) -> Self {
563        self.source = colorspace;
564        self
565    }
566
567    /// Set target color space
568    pub fn to_colorspace(mut self, colorspace: ColorSpace) -> Self {
569        self.target = colorspace;
570        self
571    }
572
573    /// Set source as RGB
574    pub fn from_rgb(mut self) -> Self {
575        self.source = ColorSpace::RGB;
576        self
577    }
578
579    /// Set target as HSV
580    pub fn to_hsv(mut self) -> Self {
581        self.target = ColorSpace::HSV;
582        self
583    }
584
585    /// Set target as grayscale
586    pub fn to_grayscale(mut self) -> Self {
587        self.target = ColorSpace::Grayscale;
588        self
589    }
590}
591
592impl Transform<Array3<Float>, Array3<Float>> for ColorSpaceTransformer {
593    fn transform(&self, x: &Array3<Float>) -> Result<Array3<Float>> {
594        match (self.source, self.target) {
595            (ColorSpace::RGB, ColorSpace::HSV) => self.rgb_to_hsv(x),
596            (ColorSpace::RGB, ColorSpace::Grayscale) => self.rgb_to_grayscale(x),
597            (ColorSpace::HSV, ColorSpace::RGB) => self.hsv_to_rgb(x),
598            (source, target) if source == target => Ok(x.clone()),
599            _ => Err(SklearsError::InvalidInput(format!(
600                "Conversion from {:?} to {:?} not implemented",
601                self.source, self.target
602            ))),
603        }
604    }
605}
606
607impl ColorSpaceTransformer {
608    /// Convert RGB to HSV
609    fn rgb_to_hsv(&self, image: &Array3<Float>) -> Result<Array3<Float>> {
610        let shape = image.dim();
611        if shape.2 != 3 {
612            return Err(SklearsError::InvalidInput(
613                "RGB images must have 3 channels".to_string(),
614            ));
615        }
616
617        let mut result = Array3::zeros(shape);
618
619        for row in 0..shape.0 {
620            for col in 0..shape.1 {
621                let r = image[[row, col, 0]];
622                let g = image[[row, col, 1]];
623                let b = image[[row, col, 2]];
624
625                let max_val = r.max(g).max(b);
626                let min_val = r.min(g).min(b);
627                let delta = max_val - min_val;
628
629                // Hue calculation
630                let h = if delta < 1e-8 {
631                    0.0
632                } else if (max_val - r).abs() < 1e-8 {
633                    60.0 * (((g - b) / delta) % 6.0)
634                } else if (max_val - g).abs() < 1e-8 {
635                    60.0 * (((b - r) / delta) + 2.0)
636                } else {
637                    60.0 * (((r - g) / delta) + 4.0)
638                };
639
640                let h = if h < 0.0 { h + 360.0 } else { h };
641
642                // Saturation calculation
643                let s = if max_val < 1e-8 { 0.0 } else { delta / max_val };
644
645                // Value calculation
646                let v = max_val;
647
648                result[[row, col, 0]] = h / 360.0; // Normalize hue to [0, 1]
649                result[[row, col, 1]] = s;
650                result[[row, col, 2]] = v;
651            }
652        }
653
654        Ok(result)
655    }
656
657    /// Convert HSV to RGB
658    fn hsv_to_rgb(&self, image: &Array3<Float>) -> Result<Array3<Float>> {
659        let shape = image.dim();
660        if shape.2 != 3 {
661            return Err(SklearsError::InvalidInput(
662                "HSV images must have 3 channels".to_string(),
663            ));
664        }
665
666        let mut result = Array3::zeros(shape);
667
668        for row in 0..shape.0 {
669            for col in 0..shape.1 {
670                let h = image[[row, col, 0]] * 360.0; // Denormalize hue from [0, 1] to [0, 360]
671                let s = image[[row, col, 1]];
672                let v = image[[row, col, 2]];
673
674                let c = v * s;
675                let x = c * (1.0 - ((h / 60.0) % 2.0 - 1.0).abs());
676                let m = v - c;
677
678                let (r, g, b) = if h < 60.0 {
679                    (c, x, 0.0)
680                } else if h < 120.0 {
681                    (x, c, 0.0)
682                } else if h < 180.0 {
683                    (0.0, c, x)
684                } else if h < 240.0 {
685                    (0.0, x, c)
686                } else if h < 300.0 {
687                    (x, 0.0, c)
688                } else {
689                    (c, 0.0, x)
690                };
691
692                result[[row, col, 0]] = r + m;
693                result[[row, col, 1]] = g + m;
694                result[[row, col, 2]] = b + m;
695            }
696        }
697
698        Ok(result)
699    }
700
701    /// Convert RGB to grayscale using luminance weighting
702    fn rgb_to_grayscale(&self, image: &Array3<Float>) -> Result<Array3<Float>> {
703        let shape = image.dim();
704        if shape.2 != 3 {
705            return Err(SklearsError::InvalidInput(
706                "RGB images must have 3 channels".to_string(),
707            ));
708        }
709
710        let mut result = Array3::zeros((shape.0, shape.1, 1));
711
712        // Standard luminance weights for RGB to grayscale conversion
713        const R_WEIGHT: Float = 0.299;
714        const G_WEIGHT: Float = 0.587;
715        const B_WEIGHT: Float = 0.114;
716
717        for row in 0..shape.0 {
718            for col in 0..shape.1 {
719                let r = image[[row, col, 0]];
720                let g = image[[row, col, 1]];
721                let b = image[[row, col, 2]];
722
723                let gray = R_WEIGHT * r + G_WEIGHT * g + B_WEIGHT * b;
724                result[[row, col, 0]] = gray;
725            }
726        }
727
728        Ok(result)
729    }
730}
731
732/// Image resizer for changing image dimensions
733#[derive(Debug, Clone)]
734#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
735pub struct ImageResizer {
736    target_size: (usize, usize),
737    method: InterpolationMethod,
738}
739
740impl ImageResizer {
741    /// Create a new image resizer
742    pub fn new(target_size: (usize, usize)) -> Self {
743        Self {
744            target_size,
745            method: InterpolationMethod::default(),
746        }
747    }
748
749    /// Set interpolation method
750    pub fn with_method(mut self, method: InterpolationMethod) -> Self {
751        self.method = method;
752        self
753    }
754}
755
756impl Transform<Array3<Float>, Array3<Float>> for ImageResizer {
757    fn transform(&self, x: &Array3<Float>) -> Result<Array3<Float>> {
758        let source_shape = x.dim();
759        let (target_height, target_width) = self.target_size;
760
761        if target_height == 0 || target_width == 0 {
762            return Err(SklearsError::InvalidInput(
763                "Target dimensions must be positive".to_string(),
764            ));
765        }
766
767        let mut result = Array3::zeros((target_height, target_width, source_shape.2));
768
769        let height_scale = source_shape.0 as Float / target_height as Float;
770        let width_scale = source_shape.1 as Float / target_width as Float;
771
772        match self.method {
773            InterpolationMethod::Nearest => {
774                for row in 0..target_height {
775                    for col in 0..target_width {
776                        let src_row = ((row as Float + 0.5) * height_scale).floor() as usize;
777                        let src_col = ((col as Float + 0.5) * width_scale).floor() as usize;
778
779                        let src_row = src_row.min(source_shape.0 - 1);
780                        let src_col = src_col.min(source_shape.1 - 1);
781
782                        for channel in 0..source_shape.2 {
783                            result[[row, col, channel]] = x[[src_row, src_col, channel]];
784                        }
785                    }
786                }
787            }
788            InterpolationMethod::Bilinear => {
789                for row in 0..target_height {
790                    for col in 0..target_width {
791                        let src_y = (row as Float + 0.5) * height_scale - 0.5;
792                        let src_x = (col as Float + 0.5) * width_scale - 0.5;
793
794                        let y1 = src_y.floor() as isize;
795                        let x1 = src_x.floor() as isize;
796                        let y2 = y1 + 1;
797                        let x2 = x1 + 1;
798
799                        let dy = src_y - y1 as Float;
800                        let dx = src_x - x1 as Float;
801
802                        for channel in 0..source_shape.2 {
803                            let mut sum = 0.0;
804
805                            // Bilinear interpolation weights and values
806                            if y1 >= 0
807                                && y1 < source_shape.0 as isize
808                                && x1 >= 0
809                                && x1 < source_shape.1 as isize
810                            {
811                                sum += (1.0 - dx)
812                                    * (1.0 - dy)
813                                    * x[[y1 as usize, x1 as usize, channel]];
814                            }
815                            if y1 >= 0
816                                && y1 < source_shape.0 as isize
817                                && x2 >= 0
818                                && x2 < source_shape.1 as isize
819                            {
820                                sum += dx * (1.0 - dy) * x[[y1 as usize, x2 as usize, channel]];
821                            }
822                            if y2 >= 0
823                                && y2 < source_shape.0 as isize
824                                && x1 >= 0
825                                && x1 < source_shape.1 as isize
826                            {
827                                sum += (1.0 - dx) * dy * x[[y2 as usize, x1 as usize, channel]];
828                            }
829                            if y2 >= 0
830                                && y2 < source_shape.0 as isize
831                                && x2 >= 0
832                                && x2 < source_shape.1 as isize
833                            {
834                                sum += dx * dy * x[[y2 as usize, x2 as usize, channel]];
835                            }
836
837                            result[[row, col, channel]] = sum;
838                        }
839                    }
840                }
841            }
842            InterpolationMethod::Bicubic => {
843                // Simplified bicubic - for production use, implement proper cubic kernel
844                // For now, fall back to bilinear
845                let bilinear_resizer =
846                    ImageResizer::new(self.target_size).with_method(InterpolationMethod::Bilinear);
847                return bilinear_resizer.transform(x);
848            }
849        }
850
851        Ok(result)
852    }
853}
854
/// Operator used by the edge detector.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum EdgeDetectionMethod {
    /// Sobel gradient operator.
    Sobel,
    /// Laplacian (second-derivative) operator.
    Laplacian,
    /// Simplified Canny: Sobel response, optionally thresholded.
    Canny,
}
866
867impl Default for EdgeDetectionMethod {
868    fn default() -> Self {
869        Self::Sobel
870    }
871}
872
873/// Edge detector for feature extraction
874#[derive(Debug, Clone)]
875#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
876pub struct EdgeDetector {
877    method: EdgeDetectionMethod,
878    threshold: Option<Float>,
879    blur_sigma: Option<Float>,
880}
881
882impl EdgeDetector {
883    /// Create a new edge detector
884    pub fn new() -> Self {
885        Self {
886            method: EdgeDetectionMethod::default(),
887            threshold: None,
888            blur_sigma: None,
889        }
890    }
891
892    /// Set edge detection method
893    pub fn with_method(mut self, method: EdgeDetectionMethod) -> Self {
894        self.method = method;
895        self
896    }
897
898    /// Set threshold for edge detection
899    pub fn with_threshold(mut self, threshold: Float) -> Self {
900        self.threshold = Some(threshold);
901        self
902    }
903
904    /// Set Gaussian blur sigma for preprocessing
905    pub fn with_blur_sigma(mut self, sigma: Float) -> Self {
906        self.blur_sigma = Some(sigma);
907        self
908    }
909}
910
911impl Transform<Array3<Float>, Array3<Float>> for EdgeDetector {
912    fn transform(&self, x: &Array3<Float>) -> Result<Array3<Float>> {
913        // Convert to grayscale if needed
914        let gray_image = if x.dim().2 == 3 {
915            let color_transformer = ColorSpaceTransformer::new().from_rgb().to_grayscale();
916            color_transformer.transform(x)?
917        } else if x.dim().2 == 1 {
918            x.clone()
919        } else {
920            return Err(SklearsError::InvalidInput(
921                "Image must have 1 or 3 channels".to_string(),
922            ));
923        };
924
925        let mut processed = gray_image;
926
927        // Apply Gaussian blur if specified
928        if let Some(sigma) = self.blur_sigma {
929            processed = self.gaussian_blur(&processed, sigma)?;
930        }
931
932        // Apply edge detection
933        let edges = match self.method {
934            EdgeDetectionMethod::Sobel => self.sobel_edge_detection(&processed)?,
935            EdgeDetectionMethod::Laplacian => self.laplacian_edge_detection(&processed)?,
936            EdgeDetectionMethod::Canny => {
937                // Simplified Canny: Sobel + thresholding
938                let sobel_edges = self.sobel_edge_detection(&processed)?;
939                if let Some(threshold) = self.threshold {
940                    self.apply_threshold(&sobel_edges, threshold)?
941                } else {
942                    sobel_edges
943                }
944            }
945        };
946
947        Ok(edges)
948    }
949}
950
951impl EdgeDetector {
952    /// Apply Gaussian blur to reduce noise
953    fn gaussian_blur(&self, image: &Array3<Float>, sigma: Float) -> Result<Array3<Float>> {
954        let shape = image.dim();
955        let mut result = image.clone();
956
957        // Simple 3x3 Gaussian kernel approximation
958        let kernel_size = (6.0 * sigma).ceil() as usize + 1;
959        let kernel_radius = kernel_size / 2;
960
961        // Create Gaussian kernel
962        let mut kernel = vec![vec![0.0; kernel_size]; kernel_size];
963        let mut kernel_sum = 0.0;
964
965        for i in 0..kernel_size {
966            for j in 0..kernel_size {
967                let x = (i as isize - kernel_radius as isize) as Float;
968                let y = (j as isize - kernel_radius as isize) as Float;
969                let value = (-((x * x + y * y) / (2.0 * sigma * sigma))).exp();
970                kernel[i][j] = value;
971                kernel_sum += value;
972            }
973        }
974
975        // Normalize kernel
976        for i in 0..kernel_size {
977            for j in 0..kernel_size {
978                kernel[i][j] /= kernel_sum;
979            }
980        }
981
982        // Apply convolution
983        for row in kernel_radius..(shape.0 - kernel_radius) {
984            for col in kernel_radius..(shape.1 - kernel_radius) {
985                for channel in 0..shape.2 {
986                    let mut sum = 0.0;
987                    for ki in 0..kernel_size {
988                        for kj in 0..kernel_size {
989                            let img_row = row + ki - kernel_radius;
990                            let img_col = col + kj - kernel_radius;
991                            sum += image[[img_row, img_col, channel]] * kernel[ki][kj];
992                        }
993                    }
994                    result[[row, col, channel]] = sum;
995                }
996            }
997        }
998
999        Ok(result)
1000    }
1001
1002    /// Apply Sobel edge detection
1003    fn sobel_edge_detection(&self, image: &Array3<Float>) -> Result<Array3<Float>> {
1004        let shape = image.dim();
1005        let mut result = Array3::zeros(shape);
1006
1007        // Sobel kernels
1008        let sobel_x: [[Float; 3]; 3] = [[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]];
1009
1010        let sobel_y: [[Float; 3]; 3] = [[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]];
1011
1012        // Apply Sobel operators
1013        for row in 1..(shape.0 - 1) {
1014            for col in 1..(shape.1 - 1) {
1015                for channel in 0..shape.2 {
1016                    let mut gx = 0.0;
1017                    let mut gy = 0.0;
1018
1019                    // Apply 3x3 kernels
1020                    for i in 0..3 {
1021                        for j in 0..3 {
1022                            let pixel_val = image[[row + i - 1, col + j - 1, channel]];
1023                            gx += pixel_val * sobel_x[i][j];
1024                            gy += pixel_val * sobel_y[i][j];
1025                        }
1026                    }
1027
1028                    // Calculate gradient magnitude
1029                    let gradient_magnitude = (gx * gx + gy * gy).sqrt();
1030                    result[[row, col, channel]] = gradient_magnitude;
1031                }
1032            }
1033        }
1034
1035        Ok(result)
1036    }
1037
1038    /// Apply Laplacian edge detection
1039    fn laplacian_edge_detection(&self, image: &Array3<Float>) -> Result<Array3<Float>> {
1040        let shape = image.dim();
1041        let mut result = Array3::zeros(shape);
1042
1043        // Laplacian kernel
1044        let laplacian: [[Float; 3]; 3] = [[0.0, -1.0, 0.0], [-1.0, 4.0, -1.0], [0.0, -1.0, 0.0]];
1045
1046        // Apply Laplacian operator
1047        for row in 1..(shape.0 - 1) {
1048            for col in 1..(shape.1 - 1) {
1049                for channel in 0..shape.2 {
1050                    let mut sum = 0.0;
1051
1052                    // Apply 3x3 kernel
1053                    for i in 0..3 {
1054                        for j in 0..3 {
1055                            let pixel_val = image[[row + i - 1, col + j - 1, channel]];
1056                            sum += pixel_val * laplacian[i][j];
1057                        }
1058                    }
1059
1060                    result[[row, col, channel]] = sum.abs();
1061                }
1062            }
1063        }
1064
1065        Ok(result)
1066    }
1067
1068    /// Apply threshold to edge detection results
1069    fn apply_threshold(&self, image: &Array3<Float>, threshold: Float) -> Result<Array3<Float>> {
1070        let mut result = image.clone();
1071        result.mapv_inplace(|x| if x > threshold { 1.0 } else { 0.0 });
1072        Ok(result)
1073    }
1074}
1075
/// Extracts a flat vector of hand-crafted features (edge statistics, per-channel
/// histograms and per-channel moments) from an image.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ImageFeatureExtractor {
    /// Whether to emit the two edge features (edge density, mean edge strength).
    extract_edges: bool,
    /// Whether to emit a normalized histogram for every channel.
    extract_histograms: bool,
    /// Number of bins in each per-channel histogram.
    histogram_bins: usize,
    /// Whether to emit mean, variance, skewness and kurtosis for every channel.
    extract_moments: bool,
}
1085
1086impl ImageFeatureExtractor {
1087    /// Create a new image feature extractor
1088    pub fn new() -> Self {
1089        Self {
1090            extract_edges: true,
1091            extract_histograms: true,
1092            histogram_bins: 32,
1093            extract_moments: true,
1094        }
1095    }
1096
1097    /// Enable/disable edge feature extraction
1098    pub fn with_edge_features(mut self, enabled: bool) -> Self {
1099        self.extract_edges = enabled;
1100        self
1101    }
1102
1103    /// Enable/disable histogram feature extraction
1104    pub fn with_histogram_features(mut self, enabled: bool, bins: usize) -> Self {
1105        self.extract_histograms = enabled;
1106        self.histogram_bins = bins;
1107        self
1108    }
1109
1110    /// Enable/disable moment feature extraction
1111    pub fn with_moment_features(mut self, enabled: bool) -> Self {
1112        self.extract_moments = enabled;
1113        self
1114    }
1115}
1116
1117impl Transform<Array3<Float>, Vec<Float>> for ImageFeatureExtractor {
1118    fn transform(&self, x: &Array3<Float>) -> Result<Vec<Float>> {
1119        let mut features = Vec::new();
1120
1121        // Extract edge features
1122        if self.extract_edges {
1123            let edge_detector = EdgeDetector::new().with_method(EdgeDetectionMethod::Sobel);
1124            let edges = edge_detector.transform(x)?;
1125
1126            // Edge density (percentage of edge pixels)
1127            let total_pixels = edges.len();
1128            let edge_pixels = edges.iter().filter(|&&x| x > 0.1).count();
1129            features.push(edge_pixels as Float / total_pixels as Float);
1130
1131            // Mean edge strength
1132            let mean_edge_strength = edges.iter().sum::<Float>() / total_pixels as Float;
1133            features.push(mean_edge_strength);
1134        }
1135
1136        // Extract histogram features
1137        if self.extract_histograms {
1138            for channel in 0..x.dim().2 {
1139                let channel_data = x.index_axis(Axis(2), channel);
1140                let histogram = self.compute_histogram(&channel_data, self.histogram_bins)?;
1141                features.extend(histogram);
1142            }
1143        }
1144
1145        // Extract moment features
1146        if self.extract_moments {
1147            for channel in 0..x.dim().2 {
1148                let channel_data = x.index_axis(Axis(2), channel);
1149                let data_vec: Vec<Float> = channel_data.iter().copied().collect();
1150
1151                // First moment (mean)
1152                let mean = data_vec.iter().sum::<Float>() / data_vec.len() as Float;
1153                features.push(mean);
1154
1155                // Second moment (variance)
1156                let variance = data_vec.iter().map(|&x| (x - mean).powi(2)).sum::<Float>()
1157                    / data_vec.len() as Float;
1158                features.push(variance);
1159
1160                // Third moment (skewness approximation)
1161                let skewness = data_vec.iter().map(|&x| (x - mean).powi(3)).sum::<Float>()
1162                    / (data_vec.len() as Float * variance.powf(1.5));
1163                features.push(skewness);
1164
1165                // Fourth moment (kurtosis approximation)
1166                let kurtosis = data_vec.iter().map(|&x| (x - mean).powi(4)).sum::<Float>()
1167                    / (data_vec.len() as Float * variance.powi(2));
1168                features.push(kurtosis);
1169            }
1170        }
1171
1172        Ok(features)
1173    }
1174}
1175
1176impl ImageFeatureExtractor {
1177    /// Compute histogram for a 2D array
1178    fn compute_histogram(
1179        &self,
1180        data: &scirs2_core::ndarray::ArrayView2<Float>,
1181        bins: usize,
1182    ) -> Result<Vec<Float>> {
1183        let min_val = data.iter().fold(Float::INFINITY, |a, &b| a.min(b));
1184        let max_val = data.iter().fold(Float::NEG_INFINITY, |a, &b| a.max(b));
1185
1186        if (max_val - min_val).abs() < 1e-10 {
1187            return Ok(vec![0.0; bins]);
1188        }
1189
1190        let mut histogram = vec![0.0; bins];
1191        let bin_width = (max_val - min_val) / bins as Float;
1192
1193        for &value in data.iter() {
1194            let bin_index = ((value - min_val) / bin_width).floor() as usize;
1195            let bin_index = bin_index.min(bins - 1);
1196            histogram[bin_index] += 1.0;
1197        }
1198
1199        // Normalize histogram
1200        let total_count = data.len() as Float;
1201        for bin in &mut histogram {
1202            *bin /= total_count;
1203        }
1204
1205        Ok(histogram)
1206    }
1207}
1208
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    //! Unit tests for normalization, augmentation, colour conversion, resizing,
    //! edge detection and feature extraction.
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::arr3;

    #[test]
    fn test_image_normalizer_minmax() -> Result<()> {
        let image = arr3(&[
            [[100.0, 50.0, 200.0], [150.0, 75.0, 250.0]],
            [[200.0, 100.0, 255.0], [50.0, 25.0, 100.0]],
        ]);

        let normalizer = ImageNormalizer::new()
            .with_strategy(ImageNormalizationStrategy::MinMax)
            .with_channel_wise(true);

        let fitted = normalizer.fit(&image, &())?;
        let normalized = fitted.transform(&image)?;

        // Check that each channel is normalized to [0, 1]
        for channel in 0..3 {
            let channel_data = normalized.index_axis(Axis(2), channel);
            let min_val = channel_data.iter().fold(Float::INFINITY, |a, &b| a.min(b));
            let max_val = channel_data
                .iter()
                .fold(Float::NEG_INFINITY, |a, &b| a.max(b));

            assert_abs_diff_eq!(min_val, 0.0, epsilon = 1e-10);
            assert_abs_diff_eq!(max_val, 1.0, epsilon = 1e-10);
        }

        Ok(())
    }

    #[test]
    fn test_image_normalizer_standard_score() -> Result<()> {
        let image = arr3(&[
            [[100.0, 50.0, 200.0], [150.0, 75.0, 250.0]],
            [[200.0, 100.0, 255.0], [50.0, 25.0, 100.0]],
        ]);

        let normalizer = ImageNormalizer::new()
            .with_strategy(ImageNormalizationStrategy::StandardScore)
            .with_channel_wise(true);

        let fitted = normalizer.fit(&image, &())?;
        let normalized = fitted.transform(&image)?;

        // Check that each channel is standardized (approximately zero mean, unit std)
        for channel in 0..3 {
            let channel_data = normalized.index_axis(Axis(2), channel);
            let data_vec: Vec<Float> = channel_data.iter().copied().collect();

            let mean = data_vec.iter().sum::<Float>() / data_vec.len() as Float;
            let std = (data_vec.iter().map(|&x| (x - mean).powi(2)).sum::<Float>()
                / (data_vec.len() - 1) as Float)
                .sqrt();

            assert_abs_diff_eq!(mean, 0.0, epsilon = 1e-10);
            assert_abs_diff_eq!(std, 1.0, epsilon = 1e-10);
        }

        Ok(())
    }

    #[test]
    fn test_image_augmenter_horizontal_flip() -> Result<()> {
        let image = arr3(&[
            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
            [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]],
        ]);

        let augmenter = ImageAugmenter::new()
            .with_horizontal_flip(true)
            .with_random_seed(42); // For deterministic testing

        let flipped = augmenter.horizontal_flip(&image)?;

        // Check that columns are flipped
        assert_eq!(flipped[[0, 0, 0]], image[[0, 1, 0]]);
        assert_eq!(flipped[[0, 1, 0]], image[[0, 0, 0]]);

        Ok(())
    }

    #[test]
    fn test_color_space_rgb_to_hsv() -> Result<()> {
        let rgb_image = arr3(&[
            [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], // Red, Green
            [[0.0, 0.0, 1.0], [1.0, 1.0, 1.0]], // Blue, White
        ]);

        let transformer = ColorSpaceTransformer::new().from_rgb().to_hsv();

        let hsv_image = transformer.transform(&rgb_image)?;

        // Red should have H=0, S=1, V=1
        assert_abs_diff_eq!(hsv_image[[0, 0, 0]], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(hsv_image[[0, 0, 1]], 1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(hsv_image[[0, 0, 2]], 1.0, epsilon = 1e-6);

        // White should have S=0, V=1
        assert_abs_diff_eq!(hsv_image[[1, 1, 1]], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(hsv_image[[1, 1, 2]], 1.0, epsilon = 1e-6);

        Ok(())
    }

    #[test]
    fn test_rgb_to_grayscale() -> Result<()> {
        let rgb_image = arr3(&[
            [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], // Red, Green
            [[0.0, 0.0, 1.0], [1.0, 1.0, 1.0]], // Blue, White
        ]);

        let transformer = ColorSpaceTransformer::new().from_rgb().to_grayscale();

        let gray_image = transformer.transform(&rgb_image)?;

        // Check output shape
        assert_eq!(gray_image.dim().2, 1);

        // Red should be approximately 0.299
        assert_abs_diff_eq!(gray_image[[0, 0, 0]], 0.299, epsilon = 1e-6);

        // Green should be approximately 0.587
        assert_abs_diff_eq!(gray_image[[0, 1, 0]], 0.587, epsilon = 1e-6);

        // Blue should be approximately 0.114
        assert_abs_diff_eq!(gray_image[[1, 0, 0]], 0.114, epsilon = 1e-6);

        // White should be 1.0
        assert_abs_diff_eq!(gray_image[[1, 1, 0]], 1.0, epsilon = 1e-6);

        Ok(())
    }

    #[test]
    fn test_image_resizer_nearest() -> Result<()> {
        let image = arr3(&[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]);

        let resizer = ImageResizer::new((4, 4)).with_method(InterpolationMethod::Nearest);

        let resized = resizer.transform(&image)?;

        assert_eq!(resized.dim(), (4, 4, 2));

        // Check some values (nearest neighbor should replicate pixels)
        assert_eq!(resized[[0, 0, 0]], image[[0, 0, 0]]);
        assert_eq!(resized[[3, 3, 0]], image[[1, 1, 0]]);

        Ok(())
    }

    #[test]
    fn test_image_resizer_bilinear() -> Result<()> {
        let image = arr3(&[[[0.0, 0.0], [1.0, 1.0]], [[0.0, 0.0], [1.0, 1.0]]]);

        let resizer = ImageResizer::new((3, 3)).with_method(InterpolationMethod::Bilinear);

        let resized = resizer.transform(&image)?;

        assert_eq!(resized.dim(), (3, 3, 2));

        // Center pixel should be interpolated
        assert!(resized[[1, 1, 0]] > 0.0 && resized[[1, 1, 0]] < 1.0);

        Ok(())
    }

    #[test]
    fn test_edge_detector_sobel() -> Result<()> {
        // Create a larger test image with clear edges (RGB format - 4x4)
        let image = arr3(&[
            [
                [0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0],
                [1.0, 1.0, 1.0],
                [1.0, 1.0, 1.0],
            ],
            [
                [0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0],
                [1.0, 1.0, 1.0],
                [1.0, 1.0, 1.0],
            ],
            [
                [0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0],
                [1.0, 1.0, 1.0],
                [1.0, 1.0, 1.0],
            ],
            [
                [0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0],
                [1.0, 1.0, 1.0],
                [1.0, 1.0, 1.0],
            ],
        ]);

        let edge_detector = EdgeDetector::new().with_method(EdgeDetectionMethod::Sobel);

        let edges = edge_detector.transform(&image)?;

        // Check output shape (should be grayscale)
        assert_eq!(edges.dim().2, 1);

        // Edges should have been detected (use lower threshold for small gradients)
        let max_edge = edges.iter().fold(0.0_f64, |acc, &x| acc.max(x));
        assert!(
            max_edge > 0.01,
            "Expected edge detection to produce values > 0.01, got max: {}",
            max_edge
        );

        Ok(())
    }

    #[test]
    fn test_edge_detector_laplacian() -> Result<()> {
        // Create a larger test image with clear edges (4x4)
        let image = arr3(&[
            [
                [0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0],
                [1.0, 1.0, 1.0],
                [1.0, 1.0, 1.0],
            ],
            [
                [0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0],
                [1.0, 1.0, 1.0],
                [1.0, 1.0, 1.0],
            ],
            [
                [0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0],
                [1.0, 1.0, 1.0],
                [1.0, 1.0, 1.0],
            ],
            [
                [0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0],
                [1.0, 1.0, 1.0],
                [1.0, 1.0, 1.0],
            ],
        ]);

        let edge_detector = EdgeDetector::new().with_method(EdgeDetectionMethod::Laplacian);

        let edges = edge_detector.transform(&image)?;

        // Check output shape
        assert_eq!(edges.dim().2, 1);

        // Should detect edges (use lower threshold for small gradients)
        let max_edge = edges.iter().fold(0.0_f64, |acc, &x| acc.max(x));
        assert!(
            max_edge > 0.01,
            "Expected edge detection to produce values > 0.01, got max: {}",
            max_edge
        );

        Ok(())
    }

    #[test]
    fn test_edge_detector_with_threshold() -> Result<()> {
        // 4x4 image with a vertical step edge. A 2x2 image has no interior pixels, so the
        // Sobel output would be all zeros and the binary check below would pass vacuously.
        let image = arr3(&[
            [
                [0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0],
                [1.0, 1.0, 1.0],
                [1.0, 1.0, 1.0],
            ],
            [
                [0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0],
                [1.0, 1.0, 1.0],
                [1.0, 1.0, 1.0],
            ],
            [
                [0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0],
                [1.0, 1.0, 1.0],
                [1.0, 1.0, 1.0],
            ],
            [
                [0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0],
                [1.0, 1.0, 1.0],
                [1.0, 1.0, 1.0],
            ],
        ]);

        let edge_detector = EdgeDetector::new()
            .with_method(EdgeDetectionMethod::Canny)
            .with_threshold(0.3);

        let edges = edge_detector.transform(&image)?;

        // Check output shape
        assert_eq!(edges.dim().2, 1);

        // Values should be binary (0.0 or 1.0) due to thresholding
        let all_binary = edges.iter().all(|&x| x == 0.0 || x == 1.0);
        assert!(all_binary);

        // The step edge must survive thresholding
        assert!(edges.iter().any(|&x| x == 1.0));

        Ok(())
    }

    #[test]
    fn test_image_feature_extractor() -> Result<()> {
        let image = arr3(&[
            [[0.0, 0.5, 1.0], [0.2, 0.7, 0.9]],
            [[0.1, 0.6, 0.8], [0.3, 0.4, 0.6]],
        ]);

        let feature_extractor = ImageFeatureExtractor::new()
            .with_edge_features(true)
            .with_histogram_features(true, 4)
            .with_moment_features(true);

        let features = feature_extractor.transform(&image)?;

        // Should extract features
        assert!(!features.is_empty());

        // Should have edge features (2), histogram features (4 bins * 3 channels = 12),
        // and moment features (4 moments * 3 channels = 12)
        // Total: 2 + 12 + 12 = 26 features
        assert_eq!(features.len(), 26);

        // All features should be finite
        assert!(features.iter().all(|&x| x.is_finite()));

        Ok(())
    }

    #[test]
    fn test_image_feature_extractor_selective_features() -> Result<()> {
        let image = arr3(&[
            [[0.0, 0.5, 0.2], [0.2, 0.7, 0.1]],
            [[0.1, 0.6, 0.3], [0.3, 0.4, 0.2]],
        ]);

        // Only extract edge features
        let feature_extractor = ImageFeatureExtractor::new()
            .with_edge_features(true)
            .with_histogram_features(false, 4)
            .with_moment_features(false);

        let features = feature_extractor.transform(&image)?;

        // Should only have 2 edge features
        assert_eq!(features.len(), 2);

        // Features should be meaningful (non-negative for density and strength)
        assert!(features[0] >= 0.0); // Edge density
        assert!(features[1] >= 0.0); // Mean edge strength

        Ok(())
    }

    #[test]
    fn test_gaussian_blur() -> Result<()> {
        // Create a larger test image with clear edges for blur testing (6x6)
        let image = arr3(&[
            [[0.0], [0.0], [0.0], [1.0], [1.0], [1.0]],
            [[0.0], [0.0], [0.0], [1.0], [1.0], [1.0]],
            [[0.0], [0.0], [0.0], [1.0], [1.0], [1.0]],
            [[0.0], [0.0], [0.0], [1.0], [1.0], [1.0]],
            [[0.0], [0.0], [0.0], [1.0], [1.0], [1.0]],
            [[0.0], [0.0], [0.0], [1.0], [1.0], [1.0]],
        ]);

        // Test blur indirectly through edge detection with blur preprocessing
        let edge_detector_without_blur =
            EdgeDetector::new().with_method(EdgeDetectionMethod::Sobel);
        let edge_detector_with_blur = EdgeDetector::new()
            .with_method(EdgeDetectionMethod::Sobel)
            .with_blur_sigma(2.0);

        let edges_without_blur = edge_detector_without_blur.transform(&image)?;
        let edges_with_blur = edge_detector_with_blur.transform(&image)?;

        // Count non-zero edge pixels
        let edge_count_without_blur = edges_without_blur.iter().filter(|&&x| x > 0.01).count();
        let edge_count_with_blur = edges_with_blur.iter().filter(|&&x| x > 0.01).count();

        // Blur should reduce the number of detected edges OR reduce their strength
        let max_edge_without_blur = edges_without_blur
            .iter()
            .fold(0.0_f64, |acc, &x| acc.max(x));
        let max_edge_with_blur = edges_with_blur.iter().fold(0.0_f64, |acc, &x| acc.max(x));

        // At least one of these should be true: fewer edges detected OR weaker max edge strength
        assert!(
            edge_count_with_blur <= edge_count_without_blur
                || max_edge_with_blur <= max_edge_without_blur,
            "Expected blur to reduce edge count ({} vs {}) or max edge strength ({:.6} vs {:.6})",
            edge_count_with_blur,
            edge_count_without_blur,
            max_edge_with_blur,
            max_edge_without_blur
        );

        Ok(())
    }
}